# Modular > Deploy fast and scalable GenAI inference This file contains all documentation content in a single document following the llmstxt.org standard. ## Attention mask An attention mask is a mechanism used in the [attention](attention.mdx) layers of a [transformer](transformer.mdx) model to indicate which tokens the model should ignore when computing attention scores. For example, attention masks can prevent the model from attending to [padding tokens](padding-tokens.mdx), which are added to make sequences in a batch the same length and thus offer no information for attention. Another common mask is a "causal mask" (or "look-ahead mask"), which prevents the [self-attention](self-attention) layer from looking at future tokens when predicting a new token, ensuring that it attends only to previous tokens in the sequence. Although it sounds absurd that it would even try to look at future tokens (because it's generating tokens one at a time, in order), the self-attention is designed for more general-purpose attention scoring. In its most basic form, self-attention is agnostic to token order—it looks at all tokens in the sequence equally, based on their embeddings, and calculates scores by looking both backward and ahead in the sequence. (For example, self-attention is used during [context encoding](context-encoding.mdx) to establish an understanding of the input text.) So instead of creating a different kind of attention mechanism for autoregressive inference, the causal mask instructs the self-attention layer to simply ignore all future tokens and only look backward when generating scores that help predict the next token. --- ## Attention A mechanism used in AI models such as [transformers](transformer.mdx) that enables the model to selectively focus on different parts of the input sequence when making predictions. 
Unlike traditional model architectures that process all input data with equal importance, models with attention assign different importance levels to different tokens (such as words or pixels). This allows the model to better understand the complete meaning of the input, especially when an accurate meaning depends on relationships between tokens that are far apart (such as between words that occur far apart in a sentence). Attention is crucial for large language models (LLMs) so they can capture long-range dependencies and contextual relationships in the given text. It allows LLMs to handle complex and nuanced language, enabling them to generate coherent and contextually relevant output even when the input text includes nuanced references to other parts of the text. Attention was introduced and refined in the papers [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473) (Bahdanau et al., 2014) and [Effective Approaches to Attention-based Neural Machine Translation ](https://arxiv.org/abs/1508.04025) (Luong et al., 2015). The most well-known form of attention is [self-attention](self-attention.mdx), in which each token gets its own attention score for every other token (each token "attends to" all other tokens), in order to determine the relative importance of each other token in that context. ## Implementation details The classic attention operation follows this general structure:
It consists of the following operations (`bmm` is short for batched matrix multiplication): * `bmm`: `Q x Transpose(K)` where `Q`, `K` both have shape `[batchSize, numHeads, S, d]` and `Q x K^t` has the shape `[batchSize, numHeads, S, S]` * `softmax` * `bmm`: `softmax(Q x K^t) x V` where V has the shape `[batchSize, numHeads, S, d]` `S` denotes the sequence length. Depending on the model, it can be as large as `O(10^3) - O(10^4)`. `d` is the size per head in multi-head attention. It’s usually a power of 2 like 64, 128, etc, and smaller than `S`. A limitation of the classic implementation is that it materializes an intermediate matrix of shape `[batchSize, numHeads, S, S]`. This introduces `O(S^2)` memory allocation and traffic. --- ## Autoregression Autoregression is a process by which an AI model iteratively predicts future values based on previous values in a sequence, using its own output as input to itself. Because each prediction depends on prior context, the process is sequential, which limits parallelization. Autoregression is a standard procedure in [transformer](transformer.mdx) models such as large language models (LLMs) and other models that perform time-series forecasting. This autoregressive process explains why AI chat bots like ChatGPT and Gemini stream the output one word at a time—although they sometimes run so fast that they appear to produce more than one word at a time. --- ## Batching Batching is the process of combining multiple inference requests into a single forward pass through the model, thus executing multiple requests simultaneously and improving computational efficiency. To account for requests with varying sequence lengths, it's common to add techniques such as [padding](padding-tokens.mdx) (to standardize lengths) or [ragged tensors](ragged-tensors.mdx) (to handle variable lengths directly). Batch sizes can be either static or dynamic. 
Whereas static batching uses a fixed batch size and thus waits until the system receives a specific number of inference requests before sending them into the model, dynamic batching uses a flexible batch size. For example, dynamic batching may send a batch of requests to the model as soon as the batch either reaches a certain number of requests (batch size limit) or it reaches a timeout threshold. Dynamic batching can get a lot more complicated than that with additional tricks that keep GPUs busy instead of waiting for one batch to finish before starting another. One such strategy for large language models (LLMs) is [continuous batching](continuous-batching.mdx). --- ## Context encoding Context encoding (also known as "prefill") is the first phase in a [transformer model](transformer.mdx) that converts input data into a cached numerical representation ([KV cache](kv-cache.mdx)) and predicts the first token. It occurs after the input has already been [tokenized](tokenization.mdx) (preprocessed). Context encoding is then followed by the [autoregressive](autoregression.mdx) token generation phase, which produces one token at a time. If it weren't for the KV cache built during context encoding, the model would have to recalculate the [self-attention](self-attention.mdx) score for each token in the original input, every time it starts to predict a new token. Context encoding is usually the most computationally expensive phase in an LLM, because it must calculate attention scores for every token in the input sequence. Although this process may be parallelized across thousands of GPU threads (because each token can be processed separately), it is still a significant latency factor for time-to-first-token (TTFT). The model can usually produce subsequent tokens much faster than the first one because each round of token generation needs to calculate an attention score for only one token (the new one). 
--- ## Continuous batching Continuous batching is a [batching](batching.mdx) technique that can continuously dispatch inference requests to the GPU for [token generation](token-generation.mdx) and dramatically improve GPU utilization. Continuous batching can start executing a new batch even before the previous batch finishes its pass through the model, because this batching technique schedules new processing at the "token level." That is, because large language models (LLMs) generate responses one token at a time, there is a repeated cycle during inference (the token generation phase) in which a new batch can jump in to utilize the GPU, even before a previous batch finishes its pass through the model. That's what it means to operate at the "token level"—the batch scheduler focuses on keeping the GPU busy with token generation at all times, instead of waiting for the previous batch to finish its complete forward pass. This is sometimes called "in-flight batching" in cases where context encoding and token generation requests are combined into the same batch. --- ## Embedding An embedding (also known as a "vector embedding") is a numerical representation of information in a high-dimensional vector space. For example, a token embedding (or word embedding) encodes the meaning of words for use in large language models (LLMs). Because artificial neural networks (AI models) are a sequence of mathematical operations, they require numerical structures as input. Vector embeddings are numerical structures that provide a way to express a wide range of complex concepts. They can be used to capture information about all sorts of things, including words, groups of words, sounds, images, and more. For example, [tokenizing](tokenization.mdx) a word like "bank" into a simple number can't encode the different meanings in "bank loan" and "river bank." 
By converting the token into a high-dimensional vector, we can encode (or "embed") a variety of word meanings that help the model understand word relationships using a notion of closeness along various vector dimensions (expressed through [euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)). In this way, when a model encounters the embedding for the word "bank," it can recognize the relationship it has with nearby words such as "loan" or "river," based on the closeness they each have to each other on different vector dimensions (perhaps a "finance" dimension vs a "geography" dimension that were learned during training). Although word embeddings are a type of static embedding that encode the meaning of individual words as input to an LLM, an LLM also builds its own embeddings that are hidden inside the model. For example, as an LLM tries to understand the relationship between each word from an input sequence, it compresses more information into each token embedding based on the attention scores computed in the [self-attention layer](self-attention.mdx). :::note Embedding models Whereas the token embeddings described above use a vector space to represent the meaning of individual tokens, the output from an embedding model uses a vector space to represent the meaning of the input data (a document) as a whole. In this way, an embedding model allows you to programmatically search and compare different documents by analyzing their corresponding embeddings, which can reveal nuanced meaning and semantics far beyond what a pure text comparison can achieve. ::: --- ## Flash attention Flash attention is an optimization technique to compute attention blocks in [transformer](transformer.mdx) models. Traditional [attention](attention.mdx) requires storing large intermediate activation tensors, leading to high memory overhead that slows execution because it requires frequent memory transfers between high-bandwidth memory (HBM) and faster SRAM on the GPU. 
Flash attention improves performance and reduces the memory footprint for attention layers. It reorders computations with techniques such as tiling to compute attention scores in blocks, and it keeps only small chunks of activations in the faster on-chip SRAM. This allows the model to process much longer sequences without running into memory limitations. By improving the efficiency of attention layers, flash attention enables LLMs to handle much longer contexts, improving their ability to understand and generate complex text. It's particularly beneficial for: * Large language models with long context windows * Vision transformers processing high-resolution images * Multi-modal models with large attention matrices * Fine-tuning large models on limited GPU memory ## Implementation details Flash attention optimizes the classic [attention](attention.mdx) mechanism by: 1. **Tiling the computation**: Breaking the `Q`, `K`, and `V` matrices into smaller blocks that fit in GPU shared memory, which is much faster than global memory. 2. **Fusing operations**: Combining softmax normalization with matrix multiplication for each tile into a single kernel. These help maximize the locality and reduce DRAM (global memory) traffic.
To see an implementation of [FlashAttention-2](https://arxiv.org/abs/2307.08691) as a fused operation, see [`fused_attention.mojo` on GitHub](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo). --- ## AI terms import MDXListing from '@site/src/components/Listing/MDXListing'; export const terms = [ '*.mdx' ] --- ## KV cache KV (key-value) cache is a memory structure used in [transformer](transformer.mdx) models to store key-value tensors output from [self-attention](self-attention.mdx) layers. The KV cache speeds up inference for transformer models such as large language models (LLMs) by avoiding the need to recompute the self-attention scores for all previous tokens in a sequence. For example, suppose an LLM is trying to complete the sentence, "The quick brown fox..." After the model predicts "jumps" and then begins to predict the next token, the model must know the attention score for every token in the sequence so far (including the one it just predicted). That is, for each step in the [autoregression](autoregression.mdx) cycle, it must process the entire sequence thus far: 1. "The quick brown fox..." 2. "The quick brown fox jumps..." 3. "The quick brown fox jumps over..." And so on. By storing the already-calculated attention scores for previous tokens in KV cache, the model simply reads the KV cache at each step, instead of recomputing those scores all over again. Once the model predicts the next token and calculates its self-attention, it adds it to the KV cache. As the sequence length grows during inference (as more words are generated), the KV cache becomes the dominant factor in a model's memory usage. The sequence length is always limited by the model's total context window length, which varies across models and can usually be configured. 
--- ## Padding tokens Padding tokens are extra tokens (usually zeros or special tokens) that are added to the input for a model so that the input matches the model's fixed input length or to ensure that all sequences in a [batch](batching.mdx) have the same length. In [transformer](transformer.mdx) models, padding tokens have been mostly replaced with [ragged tensors](ragged-tensors.mdx). --- ## PagedAttention PagedAttention is a memory management technique designed to improve GPU memory utilization during large language model (LLM) serving. Inspired by classical virtual memory and paging methods used in operating systems, PagedAttention divides the [KV cache](kv-cache.mdx) into fixed-size blocks, which are not necessarily stored contiguously in memory. This approach enables more efficient handling of dynamic states in LLMs, allowing the model to manage large context sizes while optimizing memory usage, as described in the 2023 paper [Efficient Memory Management for Large Language Model Serving with PagedAttention](https://arxiv.org/abs/2309.06180) (Kwon, et al., 2023). Also written as "paged attention." --- ## Prefill Prefill is the first phase of an AI model's forward pass in which the model processes the input and initializes a cache to accelerate predictions. Different model architectures may have their own version of a prefill, but it's primarily associated with large language models (LLMs), in which case it's also called [context encoding](context-encoding.mdx). --- ## Ragged tensors Ragged tensors is a method for batching multiple requests with differing sequence lengths without the need for [padding tokens](padding-tokens.mdx). Ragged tensors allow sequences of variable lengths to be processed together efficiently by storing them in a compact, non-uniform format. Also sometimes referred to as "packed tensors." 
--- ## Self-attention Self-attention is a mechanism in a [transformer](transformer.mdx) model that calculates the importance of different tokens (such as words) in a sequence, relative to each other. Each token is said to "attend to" all other tokens in the sequence by assigning an "attention score" to each one. In a large language model (LLM), self-attention allows the model to build an understanding of the whole text by evaluating how each word is relevant to all other words in the text, no matter how far they are from each other. The attention scores are computed using query, key, and value (QKV) vectors that pertain to each token: - The **query** is a vector that expresses what information a token is *looking for* among all the other tokens (like a search query). - The **key** is a vector that describes the information a token *offers* to other tokens (like an answer to a token's query). - The **value** is a vector that provides the **contextually-relevant information** about this token. After calculating attention scores by comparing the **query** and **key** vectors between tokens, self-attention uses the scores to apply weighted information from each token's **value** into a new [embedding](embedding.mdx) for each token. Thus, self-attention outputs a new token embedding for each token that carries information about its relationship with the other tokens in the sequence. The model also saves the calculated keys and values into the [KV cache](kv-cache.mdx) to avoid redundant recompute for the same tokens during the next [autoregression](autoregression.mdx) cycle. --- ## Tokenization Tokenization is the process of dividing the input for an AI model into discrete units that have numerical IDs called tokens. Depending on what the input is (such as text, audio, or an image) the tokens might be based on different words or subwords in text, or different slices/blocks of pixels in images. For example, consider the sentence, "The cat sat on the mat." 
A word-level tokenization might split this sentence into the following words: "The," "cat," "sat," "on," "the," "mat." Then it replaces each word with a token (a number). The token "vocabulary"—the mapping of words to numbers—is predetermined and may vary from model to model. But tokenizers in large language models (LLMs) are much more sophisticated than that. Among other things, they also tokenize punctuations (or combinations of words and punctuations) and break words into subwords that allow them to tokenize words they've never seen before. Because LLMs are trained on these tokens, they don't actually understand words and letters the way we do. They can only recognize and generate information based on the token vocabulary that they were trained upon. (Popular LLMs have a token vocabulary with over 100,000 tokens.) --- ## Transformer A transformer is a neural network architecture designed to perform complex tasks with sequential data (such as text, speech, and images) in a manner that can be efficiently parallelized on GPUs or other accelerator hardware. This makes them highly effective for natural language processing and other generative AI (GenAI) applications. The transformer model architecture was first introduced in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani, et al., 2017). This design emphasizes the use of [self-attention](self-attention.mdx) mechanisms instead of recurrent structures like recurrent neural networks (RNNs) or long short-term memory networks (LSTMs), which is what allows for the processing to be parallelized across separate compute cores instead of requiring the model to generate predictions synchronously. This design is currently the foundation for all major large language models (LLMs) such as GPT, Llama, Gemini, DeepSeek, and more. --- ## Block index In GPU programming, a block index uniquely identifies a subset of [threads](thread) that execute a [kernel](kernel.mdx) function on the GPU. 
Threads are grouped into units called [blocks](thread-block.mdx), and multiple blocks together form a larger structure known as a [grid](grid.mdx). Each block within the grid is assigned a unique block index, which can be represented across one, two, or three dimensions. This allows for flexible organization of threads to match the structure of the problem being solved. Within each block, individual threads have their own [thread index](thread-index.mdx), which, together with the block index, determines which part of the problem each thread should work on. This hierarchical structure of grids, blocks, and threads enables efficient workload distribution across the many processing cores of the GPU, maximizing parallel performance. Because a programmer can arrange thread blocks within a grid across one, two, or three dimensions, a block index is a 3-element vector of x, y, and z coordinates. For 2-dimensional arrangements, the z coordinate of all block indices is 0, and for 1-dimensional arrangements, both the y and z coordinates of all block indices are 0. --- ## Grid A grid is the top-level organizational structure of the threads executing a [kernel](kernel.mdx) function on a GPU. A grid consists of multiple [thread blocks](thread-block.mdx) (also known as *workgroups* on AMD GPUs), which are further divided into individual [threads](thread.mdx) (or *work units* on AMD GPUs) that execute the kernel function concurrently. The division of a grid into thread blocks serves multiple crucial purposes: - First, it breaks down the overall workload—managed by the grid—into smaller, more manageable portions that can be processed independently. This division allows for better resource utilization and scheduling flexibility across multiple [streaming multiprocessors](streaming-multiprocessor.mdx) (SMs) in the GPU (or *compute units* on AMD GPUs). 
- Second, thread blocks provide a scope for threads to collaborate through shared memory and synchronization primitives, enabling efficient parallel algorithms and data sharing patterns. - Finally, thread blocks help with scalability by allowing the same program to run efficiently across different GPU architectures, as the hardware can automatically distribute blocks based on available resources. The programmer specifies the number of thread blocks in a grid and how they are arranged across one, two, or three dimensions. Typically, the programmer determines the dimensions of the grid based on the dimensionality of the data to process. For example, a programmer might choose a 1-dimensional grid for processing large vectors, a 2-dimensional grid for processing matrices, and a 3-dimensional grid for processing the frames of a video. Each block within the grid is assigned a unique [block index](block-index.mdx) that determines its position within the grid. Similarly, the programmer also specifies the number of threads per thread block and how they are arranged across one, two, or three dimensions. Each thread within a block is assigned a unique [thread index](thread-index.mdx) that determines its position within the block. The combination of block index and thread index uniquely identify the position of a thread within the overall grid. --- ## GPU terms import MDXListing from '@site/src/components/Listing/MDXListing'; export const terms = [ '*.mdx' ] --- ## Kernel A kernel is a function that runs on a GPU, executing computations in parallel across a large number of [threads](thread.mdx). Kernels are a fundamental part of general-purpose GPU (GPGPU) programming and are designed to process large datasets efficiently by performing the same operation simultaneously on multiple data elements. 
--- ## GPU memory GPU memory consists of both on-chip memory and external dynamic random-access memory (DRAM), often referred to as *device memory* (in contrast to the *host memory* used by the CPU). On-chip memory includes: - A register file for each [streaming multiprocessor](streaming-multiprocessor.mdx) (SM), containing the [registers](register.mdx) used by threads executing on the SMs cores - An L1 cache for each SM to cache reads from global memory - Shared memory for each SM, containing data explicitly shared between the threads of a given [thread block](thread-block.mdx) executing on the SM - A read-only constant cache for each SM, which caches data read from the constant memory space in global memory - An L2 cache shared by all SMs that is used to cache accesses to local or global memory, including temporary register spills Device memory includes: - Global memory, which contains data accessible to all threads - Constant memory, which contains data explicitly identified as read-only by the programmer, and which is accessible to all threads - Local memory, which contains data private to an individual thread, such as statically allocated arrays, spilled registers, and other elements of the thread's call stack Data in global memory persists until explicitly freed, even across [kernel](kernel.mdx) functions. This means that one kernel can write data to global memory and then a subsequent kernel can read that data. --- ## Occupancy In GPU programming, occupancy is a measure of the efficiency of the GPU's compute resources. It is defined as the ratio of the number of active [warps](warp.mdx) to the maximum number of warps that can be active on a given [streaming multiprocessor](streaming-multiprocessor.mdx) (SM) at any one time. Higher occupancy can improve parallel execution and hide memory latency, but increasing occupancy does not always boost performance, as factors like memory bandwidth and instruction dependencies may create bottlenecks. 
The optimal occupancy level depends on the workload and GPU architecture. --- ## Register A GPU register is the fastest form of storage within a [streaming multiprocessor](streaming-multiprocessor.mdx) (SM). Registers store integer and floating point values used frequently by a [thread](thread.mdx), reducing reliance on slower [memory](memory.mdx) types (shared, global, or local memory). Registers are located within an SM in what is referred to as a *register file*. The number of registers depends on the GPU architecture, but modern GPUs support thousands of registers per SM. For each thread that it executes, the SM allocates a set of registers for the private use of that thread. The registers are associated with that thread throughout its lifetime, even if the thread is not currently executing on the SM's cores (for example, if it is blocked waiting for data from memory). A thread can't access registers assigned to a different thread, preventing data conflicts between threads. If the execution of a [kernel](kernel.mdx) function by a thread requires more registers than available, the compiler arranges to spill some register data to the thread's local [memory](memory.mdx). Because local memory access is slower than register access, programmers should try to design their kernels to avoid or limit the amount of spill. --- ## Streaming multiprocessor The basic building block of a GPU is called a *streaming multiprocessor* (SM) on NVIDIA GPUs or a *compute unit* (CU) on AMD GPUs (they're the same idea and we'll refer to them both as SM). SMs sit between the high-level GPU control logic and the individual execution units, acting as self-contained processing factories that can operate independently and in parallel. Multiple SMs are arranged on a single GPU chip, with each SM capable of handling multiple workloads simultaneously. 
The GPU's global scheduler assigns work to individual SMs, while the memory controller manages data flow between the SMs and various [memory](memory.mdx) hierarchies (global memory, L2 cache, etc.). The number of SMs in a GPU can vary significantly based on the model and intended use case, from a handful in entry-level GPUs to dozens or even hundreds in high-end professional cards. This scalable architecture enables GPUs to maintain excellent performance across different workload sizes and types. Each SM contains several essential components: - **CUDA Cores (NVIDIA)/Stream Processors (AMD):** These are the basic arithmetic logic units (ALUs) that perform integer and floating-point calculations. A single SM can contain dozens or hundreds of these cores. - **Tensor Cores (NVIDIA)/Matrix Cores (AMD):** Specialized units optimized for matrix multiplication and convolution operations. - **Special Function Units (SFUs):** Handle complex mathematical operations like trigonometry, square roots, and exponential functions. - **[Register](register.mdx) Files:** Ultra-fast storage that holds intermediate results and thread-specific data. Modern SMs can have hundreds of kilobytes of register space shared among active [threads](thread.mdx). - **Shared Memory/L1 Cache:** A programmable, low-latency memory space that enables data sharing between threads. This memory is typically configurable between shared memory and L1 cache functions. - **Load/Store Units:** Manage data movement between different memory spaces, handling memory access requests from threads. --- ## Thread block In GPU programming, a thread block (also known as *workgroup* on AMD GPUs) is a subset of threads within a [grid](grid.mdx), which is the top-level organizational structure of the [threads](thread.mdx) executing a [kernel](kernel.mdx) function. 
As the primary building block for workload distribution, thread blocks serve multiple crucial purposes: - First, they break down the overall workload — managed by the grid — of a kernel function into smaller, more manageable portions that can be processed independently. This division allows for better resource utilization and scheduling flexibility across multiple [streaming multiprocessors](streaming-multiprocessor.mdx) (SMs) in the GPU. - Second, thread blocks provide a scope for threads to collaborate through shared memory and synchronization primitives, enabling efficient parallel algorithms and data sharing patterns. - Finally, thread blocks help with scalability by allowing the same program to run efficiently across different GPU architectures, as the hardware can automatically distribute blocks based on available resources. The programmer specifies the number of thread blocks in a grid and how they are arranged across one, two, or three dimensions. Each block within the grid is assigned a unique [block index](block-index.mdx) that determines its position within the grid. Similarly, the programmer also specifies the number of threads per thread block and how they are arranged across one, two, or three dimensions. Each thread within a block is assigned a unique [thread index](thread-index.mdx) that determines its position within the block. The GPU assigns each thread block within the grid to a streaming multiprocessor (SM) for execution. The SM groups the threads within a block into fixed-size subsets called [warps](warp.mdx), consisting of either 32 or 64 threads each depending on the particular GPU architecture. The SM's warp scheduler manages the execution of warps on the SM's cores. Threads within a block can share data through [shared memory](memory.mdx) and synchronize using built-in mechanisms, but they cannot directly communicate with threads in other blocks. 
--- ## Thread index In GPU programming, a thread index uniquely identifies the position of a [thread](thread.mdx) within a particular [thread block](thread-block.mdx) executing a [kernel](kernel.mdx) function on the GPU. A thread block is a subset of threads in a [grid](grid.mdx), which is the top-level organizational structure of the threads executing a kernel function. Each block within the grid is also assigned a unique block index, which identifies the block's position within the grid. The combination of block index and thread index uniquely identifies the thread's overall position within the grid, and is used to determine which part of the problem each thread should work on. Because a programmer can arrange threads within a thread block across one, two, or three dimensions, a thread index is a 3-element vector of x, y, and z coordinates. For 2-dimensional arrangements, the z coordinate of all thread indices is 0, and for 1-dimensional arrangements, both the y and z coordinates of all thread indices are 0. --- ## Thread In GPU programming, a thread (also known as a *work unit* on AMD GPUs) is the smallest unit of execution within a [kernel](kernel.mdx) function. Threads are grouped into [thread blocks](thread-block.mdx) (or *workgroups* on AMD GPUs), which are further organized into a [grid](grid.mdx). The programmer specifies the number of thread blocks in a grid and how they are arranged across one, two, or three dimensions. Each block within the grid is assigned a unique [block index](block-index.mdx) that determines its position within the grid. Similarly, the programmer also specifies the number of threads per thread block and how they are arranged across one, two, or three dimensions. Each thread within a block is assigned a unique [thread index](thread-index.mdx) that determines its position within the block. The GPU assigns each thread block within the grid to a [streaming multiprocessor](streaming-multiprocessor.mdx) (SM) for execution. 
The SM groups the threads within a block into fixed-size subsets called [warps](warp.mdx), consisting of either 32 or 64 threads each depending on the particular GPU architecture. The SM's warp scheduler manages the execution of warps on the SM's cores. The SM allocates a set of [registers](register.mdx) for each thread to store and process values private to that thread. The registers are associated with that thread throughout its lifetime, even if the thread is not currently executing on the SM's cores (for example, if it is blocked waiting for data from memory). Each thread also has access to [local memory](memory.mdx) to store statically allocated arrays, spilled registers, and other elements of the thread's call stack. Threads within a block can share data through shared memory and synchronize using built-in mechanisms, but they cannot directly communicate with threads in other blocks. --- ## Warp In GPU programming, a warp (also known as a *wavefront* on AMD GPUs) is a subset of [threads](thread.mdx) from a [thread block](thread-block.mdx) that execute together in lockstep. When a GPU assigns a thread block to execute on a [streaming multiprocessor](streaming-multiprocessor.mdx) (SM), the SM divides the thread block into warps of 32 or 64 threads, with the exact size depending on the GPU architecture. If a thread block contains a number of threads not evenly divisible by the warp size, the SM creates a partially filled final warp that still consumes the full warp's resources. For example, if a thread block has 100 threads and the warp size is 32, the SM creates: - 3 full warps of 32 threads each (96 threads total) - 1 partial warp with only 4 active threads but still occupying a full warp's worth of resources (32 thread slots) The SM effectively disables the unused thread slots in partial warps, but these slots still consume hardware resources. 
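The 100-thread example above is just ceiling division plus a remainder. A plain-Python sketch of that arithmetic (illustrative; not part of any GPU API):

```python
import math

def warp_usage(block_size: int, warp_size: int = 32) -> tuple[int, int]:
    """Return (warps allocated, idle thread slots in the final partial warp)."""
    warps = math.ceil(block_size / warp_size)
    idle_slots = warps * warp_size - block_size
    return warps, idle_slots

# A 100-thread block with 32-wide warps: 4 warps, 28 disabled slots.
print(warp_usage(100))  # (4, 28)
# A 128-thread block fills its warps exactly: 4 warps, 0 idle slots.
print(warp_usage(128))  # (4, 0)
```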
For this reason, developers generally should make thread block sizes a multiple of the warp size to optimize resource usage. Each thread in a warp executes the same instruction at the same time on different data, following the single instruction, multiple threads (SIMT) execution model. If threads within a warp take different execution paths (called *warp divergence*), the warp serially executes each branch path taken, disabling threads that are not on that path. This means that optimal performance is achieved when all threads in a warp follow the same execution path. An SM can actively manage multiple warps from different thread blocks simultaneously, helping keep execution units busy. For example, the warp scheduler can quickly switch to another ready warp if the current warp's threads must wait for memory access. Warps deliver several key performance advantages: - The hardware needs to manage only warps instead of individual threads, reducing scheduling overhead - Threads in a warp can access contiguous memory locations efficiently through memory coalescing - The hardware automatically synchronizes threads within a warp, eliminating the need for explicit synchronization - The warp scheduler can hide memory latency by switching between warps, maximizing compute resource utilization --- ## Glossary import MDXListing from '@site/src/components/Listing/MDXListing'; Explanations for some terms and concepts you'll encounter in the Modular docs. 
## GPU terms export const gpuTerms = [ 'gpu/*.mdx' ] ## AI terms export const aiTerms = [ 'ai/*.mdx' ] --- ## Modular Documentation import Homepage, { GetStartedButton } from "@site/src/components/Homepage"; import CodeNote from "@site/src/components/Homepage/CodeNote"; import { ArrowTransfer } from "@site/src/shared/Svgs/ArrowTransfer"; import { ArrowCloud } from "@site/src/shared/Svgs/ArrowCloud"; import { DesktopCode } from "@site/src/shared/Svgs/DesktopCode"; import { AIChip } from "@site/src/shared/Svgs/AIChip"; import { RecipesIcon } from "@site/src/shared/Svgs/RecipesIcon"; import { OpenBook } from "@site/src/shared/Svgs/OpenBook"; import { PuzzleIcon } from "@site/src/shared/Svgs/PuzzleIcon"; ## Modular Documentation The Modular Platform accelerates AI inference and abstracts hardware complexity. Using our Docker container, you can deploy a GenAI model from Hugging Face with an OpenAI-compatible endpoint on a wide range of hardware. And if you need to customize the model or tune a GPU kernel, Modular provides a depth of model extensibility and GPU programmability that you won't find anywhere else. 
```python title="python" from openai import OpenAI client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY") completion = client.chat.completions.create( model="google/gemma-3-27b-it", messages=[ {"role": "user", "content": "Who won the world series in 2020?"} ], ) print(completion.choices[0].message.content) ``` export const sectionCards = [ { title: "Serving", description: "Modular’s serving library is compatible with OpenAI APIs, so you can own your endpoint with minimal client-side code changes.", to: "/max/container", icon: , }, { title: "Deploying", description: "Deploy your GenAI models to the cloud and scale your deployments across heterogeneous GPU clusters.", to: "/mammoth/", icon: , }, { title: "Developing", description: "The Modular platform provides full extensibility, so you can write custom ops, hardware-agnostic GPU kernels, and more.", to: "/max/develop/", icon: , }, { title: "Programming with Mojo🔥", description: "Mojo is a Python-style programming language that allows you to write code for both CPUs and GPUs.", to: "/mojo/manual/", icon: , }, ]; export const learningToolCards = [ { title: "Agentic Cookbook", description: "Turn-key applications that use GenAI models with the Modular platform.", href: "https://modul.ar/cookbook", icon: , }, { title: "GPU Puzzles", description: "A hands-on guide to mastering GPU programming with Mojo.", href: "https://builds.modular.com/puzzles", icon: , }, { title: "Build an LLM with MAX", description: "Learn to build an LLM from scratch with MAX.", href: "https://llm.modular.com/", icon: , }, ]; --- ## Disaggregated inference import ContactSection from '@site/src/components/ContactSection'; Disaggregated inference is a serving architecture pattern designed for large language models (LLMs), particularly decoder-only transformer models like those in the LLaMA or GPT model families. In decoder-only transformers, the process of generating model output is divided into two distinct phases: prefill and decode. 
With disaggregated inference, these phases are executed on different hardware resources. You might see this technique referred to by several names, including disaggregated inference, disaggregated prefill, or disaggregated serving. All of these describe the same core idea: separating the model's inference phases and providing each phase with dedicated resources optimized to improve performance and scalability. :::note Mammoth is the technology behind advanced features like disaggregated inference, routing, and scaling in Modular's Dedicated Endpoint and Enterprise [editions](https://www.modular.com/pricing). [Get in touch](https://www.modular.com/request-demo) to learn how Mammoth enables more efficient, large-scale model serving. ::: ## When to use disaggregated inference Disaggregated inference is particularly valuable if your priority is minimizing latency. Since the prefill stage is compute-intensive and the decode stage is more memory-bound, isolating the two stages and allocating them to different GPUs or GPU nodes reduces resource contention and helps achieve both faster time-to-first-token and smoother token streaming. Because disaggregated inference gives you separate levers to manage the prefill and decode phases independently, it is especially effective for improving tail latency, such as P95, which measures how long it takes to complete 95% of requests. By optimizing tail latency, you reduce delays for the slowest requests and can improve overall responsiveness. Disaggregation itself doesn't directly increase throughput, but it enables more granular parallelism strategies and resource allocation, which can increase processing capacity. This flexibility allows you to optimize each phase appropriately and scale prefill and decode nodes independently as needed, improving GPU utilization and overall efficiency without over-provisioning capacity just to handle peak workloads. 
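As a concrete illustration of the P95 metric mentioned above, here is a minimal nearest-rank percentile over made-up request latencies (plain Python; a real deployment would pull these numbers from its monitoring stack):

```python
import math

def percentile(samples: list[float], pct: float) -> float:
    """Nearest-rank percentile: smallest sample >= pct% of all samples."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]

# End-to-end request latencies in milliseconds (illustrative values).
latencies_ms = [120, 95, 110, 480, 105, 130, 115, 990, 100, 125]
print(percentile(latencies_ms, 50))  # 115: the typical request is fast
print(percentile(latencies_ms, 95))  # 990: one slow request sets the tail
```

This is why tail latency deserves separate attention: the mean here is around 240 ms, yet 1 request in 20 takes nearly a second.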
Additionally, disaggregated inference offers flexibility in heterogeneous or resource-constrained environments. You can match each phase with hardware that suits its specific demands. ## How disaggregated inference works LLM inference involves two distinct phases known as prefill and decode, each with unique performance characteristics that affect how systems should allocate and optimize resources.
A simplified illustration of the separate prefill and decode nodes used in a disaggregated inference serving architecture.
Prefill, also known as context encoding, is the initial phase where the model processes the entire input prompt. During this phase, the model performs a full forward pass to initialize its key-value (KV) cache and predict the first output token. This cache stores the intermediate attention states necessary for generating subsequent tokens. The prefill phase is compute-intensive, especially in the case of long user prompts, as it involves large-scale matrix operations that demand high floating-point throughput. The metric associated with this phase is often referred to as Time-to-First-Token (TTFT), indicating the duration from receiving the input prompt to producing the first output token. Following prefill, the model enters the decode phase, or token generation. In this phase, the model generates output tokens one at a time, using the KV cache initialized during prefill. By leveraging this cache, the model can quickly access previously computed information without reprocessing the full input each time. As a result, the decoding phase is less compute-intensive per token but becomes memory-bound, relying heavily on efficient access to the cached data. The key performance metric here is Inter-Token Latency (ITL), which measures the time taken to generate each subsequent token after the first. Disaggregated inference involves separating these two phases onto different GPUs or GPU nodes. By doing so, each phase can be optimized independently. Prefill workloads can be routed to hardware with high compute throughput to handle the intensive matrix operations required to process long input prompts. Meanwhile, decode workloads can be assigned to hardware with fast memory access, which is better suited for the sequential, cache-dependent nature of token generation. This separation reduces contention between compute-bound and memory-bound tasks, improves resource utilization, and allows for more scalable and predictable inference performance. 
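Both metrics fall out directly from per-token timestamps. A minimal sketch with made-up numbers (the function and values are illustrative, not part of any Modular API):

```python
def ttft_and_itl(token_times: list[float]) -> tuple[float, float]:
    """TTFT and mean inter-token latency, from timestamps (seconds since
    the request arrived) at which each output token was emitted."""
    ttft = token_times[0]
    gaps = [b - a for a, b in zip(token_times, token_times[1:])]
    return ttft, sum(gaps) / len(gaps)

# Prefill dominates the first token; decode then streams tokens steadily.
ttft, itl = ttft_and_itl([0.80, 0.85, 0.90, 0.95, 1.00])
print(ttft, round(itl, 3))  # 0.8 0.05
```

In this toy trace, prefill accounts for 800 ms of compute-bound work before any output appears, while each decode step costs only 50 ms of mostly memory-bound work; that imbalance is exactly what disaggregation exploits.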
## Become a design partner If you're exploring disaggregated inference for your deployments, start by analyzing your workload to spot any imbalances between prompt processing and token generation. Check whether your GPUs are underutilized during either phase. If you're encountering these challenges, feel free to reach out to talk to an AI expert. --- ## Scale your GenAI deployments import ContactSection from '@site/src/components/ContactSection'; The [Modular Platform](/max/intro) provides dedicated endpoints and enterprise-grade scaling for inference workloads. This scaling logic is powered by Mammoth, a Kubernetes-native distributed AI serving tool that makes it easier to run and manage LLMs at scale using MAX as a backend for optimal model performance. It's designed to maximize hardware efficiency with minimal configuration, even when running multiple models across thousands of nodes.
Figure 1. A simplified diagram of how the Modular Platform scales your GenAI deployment.
The Mammoth control plane automatically selects the best available hardware to meet performance targets when deploying a model and supports both manual and automatic scaling. Mammoth's built-in orchestrator intelligently routes traffic, taking into account hardware load, GPU memory, and caching states. You can deploy and serve multiple models simultaneously across different hardware types or versions without complex setup or duplication of infrastructure. :::note Mammoth powers advanced routing and scaling capabilities behind the scenes for Modular's Dedicated Endpoint and Enterprise [editions](https://www.modular.com/pricing). [Get in touch](https://www.modular.com/request-demo) to learn more about how Mammoth can support your workloads at scale. ::: ## Access to Mammoth If you need to serve one or more LLMs at scale with high performance and minimal operational overhead, you can do so with Modular's Dedicated Endpoint or Enterprise [editions](https://www.modular.com/pricing), which use Mammoth to power routing and scaling capabilities. Mammoth makes a difference when: - You're running inference across heterogeneous GPU clusters (NVIDIA and AMD) and need optimized, vendor-agnostic orchestration. - You want a self-hosted, low-configuration deployment experience that works out of the box, regardless of hardware or cloud provider. - You need to dynamically scale workloads based on traffic and resource availability, with fine-grained control over model placement and scheduling. - You're managing fleets of models and want a unified serving layer without duplicating infrastructure. - You're working in a Kubernetes environment and want native integration that's easy to operate and extend. - You want to optimize total cost of ownership with cluster-level efficiency features like disaggregated inference and KV cache-aware routing. 
Additionally, because Mammoth is built on the MAX framework, you can use its APIs and tools to customize and optimize every layer of the stack, from high-level orchestration down to GPU kernels written in Mojo. ## How Mammoth works Mammoth consists of a lightweight control plane, an intelligent [orchestrator](/mammoth/orchestrator), and advanced optimizations such as [disaggregated inference](/mammoth/disaggregated-inference), all working together to efficiently deploy and run models across diverse hardware environments.
Figure 2. An overview of the Mammoth components, including the control plane, orchestrator, and disaggregated inference on separate prefill and decode nodes.
At the heart of Mammoth is its control plane, which takes care of setting up, running, and scaling models automatically. Just provide the model ID (such as `modularai/Llama-3.1-8B-Instruct`) or a path to the model on an external storage provider like S3, and the control plane handles the rest. You can interact with the control plane for: - Model deployment: Launch models with a single command. - Model management: Modify or delete deployed models. - Multi-model orchestration: Run multiple models efficiently across shared infrastructure. - Scaling: Adjust replicas manually or let Mammoth autoscale intelligently. - Resource allocation: Automatically allocate GPU resources to model deployment. The Mammoth control plane extends the Kubernetes API with [custom resource](https://kubernetes.io/docs/concepts/extend-kubernetes/api-extension/custom-resources/) definitions (CRDs) and controls those resources with an [operator](https://kubernetes.io/docs/concepts/extend-kubernetes/operator/). When you create, update, or delete a resource, the control plane provisions infrastructure, deploys or reconfigures models, and cleans up resources as needed. ### Deploy models With Mammoth running behind the scenes, deploying models in Modular's Dedicated Endpoint and Enterprise editions is designed to be simple. You choose the model you want to serve and define your resource requirements, and Mammoth's control plane takes care of the rest. It automatically discovers available NVIDIA or AMD GPUs, schedules the workload across the cluster, and scales as needed. Whether you're serving a single large model or multiple models at once, Mammoth handles orchestration and optimization so you can focus on your application rather than infrastructure. ### Scale deployments The control plane adjusts the deployment to the desired number of replicas and allocates resources accordingly. For production use, intelligent autoscaling is built in and configurable. 
### Allocate resources You can fine-tune resource allocation for each deployment. For example, with [disaggregated inference](/mammoth/disaggregated-inference), you can assign separate GPU resources to nodes that handle prefill and decode stages independently. ## Become a design partner Mammoth is currently only available through Modular's early access program where we're actively partnering with select organizations as design partners. Design partners get early access to new features and share feedback to help shape the future of Mammoth. Talk to an AI expert to learn more about how Mammoth can support your use case and help you scale with confidence. --- ## Routing and orchestration import ContactSection from '@site/src/components/ContactSection'; The orchestrator is responsible for distributing incoming inference requests to the appropriate worker node in a cluster. This orchestration layer plays a critical role in performance, load balancing, memory optimization, and user experience. Rather than simply forwarding requests to the next available worker, the orchestrator uses configurable routing strategies to intelligently direct traffic. Each routing strategy has trade-offs, and the ideal strategy depends on the characteristics of your workload. ## How the orchestrator works The orchestrator routes inference requests across distributed workers. The orchestrator receives a prompt from an HTTP server, then analyzes the request to extract information relevant to the specific routing strategy. The orchestrator then selects a worker based on the specified routing algorithm and current cluster state, proxies the request to the relevant worker, and streams the response back to the user. :::note Orchestration with Mammoth is still in preview and some aspects may change as we refine the implementation. Expect ongoing improvements and potential adjustments based on feedback and performance optimizations. :::
An overview of steps taken by the Mammoth orchestrator.
## Routing options You can configure the routing strategy based on your deployment goals. For stateless requests and broad load balancing, the round robin or least request routing options work well. If you're optimizing for cache reuse or continuity in conversation, prefix-aware, sticky sessions, or KV cache-aware routing may offer significant performance gains. We also provide a random routing algorithm for benchmarking or experimental purposes. | Name | Strategy | Use case | |-------------------|--------------------------------------------------------------------------|----------------------------------------------------------------| | KV cache-aware | Routes based on shared tokens or document chunks in the KV cache | Repeated prompts in chatbots, agents, or RAG-style workflows | | Least request | Sends requests to the worker with the fewest active requests | Mixed workloads with variable size or latency requirements | | Prefix-aware | Uses consistent hashing on prompt prefixes to group similar requests | Prompts with shared templates or recurring task descriptions | | Random | Selects a backend worker at random | Benchmarking and exposing latency variability | | Round robin | Distributes requests evenly across all workers in sequential order | Stateless, uniform tasks without caching needs | | Sticky session | Routes requests with the same session ID to the same worker | Session-based chat or apps needing memory and continuity | ### KV cache-aware KV cache-aware routing manages requests based on the contents of the KV cache on each worker. You might use KV cache-aware routing if you're running a retrieval-augmented generation (RAG) system where most queries share common document chunks or similar inputs, but not identical prefixes. KV cache-aware routing is especially useful in the following scenarios: - For high-throughput workloads with many repeating or similar tokens across queries. 
- When you want to minimize redundant computation across diverse, overlapping queries. ### Least request Least request routing sends new inference requests to the worker currently handling the fewest active requests. This helps balance load dynamically and reduces the chance of overloading any single worker. You might use least request routing when serving a mix of both small and large generation tasks in order to avoid piling multiple large requests on the same node. Least request routing is especially useful in the following situations: - When some workers receive heavier workloads or respond more slowly. - For variable-length or unpredictable inference tasks. - When you're optimizing for low tail latency. ### Prefix-aware Prefix-aware routing, also known as consistent hashing, examines the prompt prefix in an incoming request and routes it to the worker handling requests with the same prefix. For example, if a support chatbot frequently receives the prefix `{"role": "system", "content": "You are a helpful assistant."}` followed by user-specific questions, prefix-aware routing keeps that common prefix cached on a single node. When a worker becomes saturated with requests for a popular prefix, the orchestrator automatically distributes the load by spilling over to additional workers, maintaining partial cache locality while balancing traffic. Prefix-aware routing is especially useful in the following situations: - When many users send queries that start with the same instructions or template. - If users frequently issue similar or identical prompts, like a recurring task description or persona. - In multi-turn conversations where session stickiness isn't enabled. ### Random Random routing selects a backend worker at random from the pool of available endpoints for each incoming request. Random routing is useful when you want to eliminate routing bias and observe average worker performance under distributed load. 
It can help identify variability in latency or behavior across nodes without favoring specific ones. Random routing is especially useful for benchmarking use cases. ### Round robin The round robin routing algorithm distributes incoming requests evenly across all available workers in sequential order. Once the orchestrator reaches the last worker in the list, it cycles back to the first. You might use round robin routing if you're running a batch of isolated tasks that don't require any request context or caching. Round robin routing is especially useful in the following situations: - For stateless or homogeneous workloads where each request is independent. - For testing environments or basic load distribution. ### Sticky session Sticky session routing sends a user's requests to the same worker node for the duration of their session. A session is identified by checking for a session ID value in the request HTTP header. If this header is not present, the orchestrator falls back to round robin routing. You might use sticky session routing for a chatbot with user interaction, where keeping their requests on the same worker node avoids reloading context repeatedly. Sticky session routing is especially useful in the following situations: - When in-flight session state (e.g., conversational memory) is maintained on the server. - For chatbots or streaming applications where continuity is important. - When memory locality is key to performance. ## Become a design partner To get the most out of prefix-aware routing and other advanced strategies, you can explore [prefix caching](/max/serve/prefix-caching) and other serving layer optimizations in MAX. To get started with Mammoth's cluster-based deployments and optimize request routing for your specific use case, you can reach out to us and talk to an AI expert. --- ## Common ```c #include "max/c/common.h" ``` **Functions:** ### `M_version()` > const char \*M\_version() Gets the MAX Engine version. 
* **Returns:** A string containing the semantic version of the MAX Engine. ### `M_newStatus()` > [M\_Status](types.md#_CPPv48M_Status) \*M\_newStatus() Creates a new status object. This is required as an argument for several functions, such as [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175) and [`M_compileModel()`](model.md#model_8h_1a88afca26a64b945885e1e1a0d09b5750). They will update the status object and you can check for errors with [`M_isError()`](#common_8h_1adb7a61f1c8f9c5e7964e8788cd437468) and get the status message with [`M_getError()`](#common_8h_1aa294beac43a0884cef8386e69a6bfc1b). For example: ```c M_Status *status = M_newStatus(); M_RuntimeConfig *runtimeConfig = M_newRuntimeConfig(); M_RuntimeContext *context = M_newRuntimeContext(runtimeConfig, status); if (M_isError(status)) { logError(M_getError(status)); return EXIT_FAILURE; } ``` * **Returns:** A pointer to the new status object. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeStatus()`](#common_8h_1ab5067fd51a5696b3679f7f629d3329c4). ### `M_getError()` > const char \*M\_getError(const [M\_Status](types.md#_CPPv48M_Status) \*status) Gets an error message from the `M_Status` parameter. You should call this only if [`M_isError()`](#common_8h_1adb7a61f1c8f9c5e7964e8788cd437468) is true. * **Parameters:** status – The status object for reporting errors and other messages. * **Returns:** A pointer to a null-terminated string containing the error message. ### `M_isError()` > int M\_isError(const [M\_Status](types.md#_CPPv48M_Status) \*status) Checks if status holds an error value. * **Parameters:** status – The status object for reporting errors and other messages. * **Returns:** `0` if there is no error, `1` otherwise. ### `M_freeStatus()` > void M\_freeStatus([M\_Status](types.md#_CPPv48M_Status) \*status) Deallocates the memory for the status object. No-op if `status` is `NULL`. 
* **Parameters:** status – The status object for reporting errors and other messages. --- ## Context ```c #include "max/c/context.h" ``` **Functions:** ### `M_newRuntimeConfig()` > [M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*M\_newRuntimeConfig() Creates a new runtime config. This configures runtime details such as the number of threads and log level. By default, the config object’s number of threads will be set to `0`, which is internally used to refer to the number of physical processors in the first socket in the system. You can change this with `M_setNumThreads()`. You need this as an argument for [`M_newRuntimeContext()`](#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175). * **Returns:** A pointer to the new runtime config. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeRuntimeConfig()`](#context_8h_1a47f7e22f7f71da9ab5fb3a1886911610). ### `M_freeRuntimeConfig()` > void M\_freeRuntimeConfig([M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*config) Deallocates the memory for a runtime config. No-op if `config` is `NULL`. * **Parameters:** config – The runtime config. ### `M_runtimeConfigAddDevice()` > void M\_runtimeConfigAddDevice([M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*config, [M\_Device](types.md#_CPPv48M_Device) \*device) Adds a device to be accessible from the runtime. * **Parameters:** * config – The runtime config. * device – The device to add to the runtime config. ### `M_newRuntimeContext()` > [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*M\_newRuntimeContext(const [M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*config, [M\_Status](types.md#_CPPv48M_Status) \*status) Creates a runtime context. The context is an application-level object that sets up various resources such as threadpool and allocators during inference. You need this before you can call [`M_compileModel()`](model.md#model_8h_1a88afca26a64b945885e1e1a0d09b5750). 
It’s expected that there’s only one runtime context active in an inference session at a time. We recommend you create one context and use it throughout your application. For example: ```c M_Status *status = M_newStatus(); M_RuntimeConfig *runtimeConfig = M_newRuntimeConfig(); M_RuntimeContext *context = M_newRuntimeContext(runtimeConfig, status); if (M_isError(status)) { logError(M_getError(status)); return EXIT_FAILURE; } ``` * **Parameters:** * config – The runtime config, from [`M_newRuntimeConfig()`](#context_8h_1a963f1d4eefd812ba8691acf516007cfc). * status – The status object for reporting errors. It is filled with an error message if construction of the runtime context fails. * **Returns:** A pointer to the runtime context object. On success, this is a valid pointer. On failure, this is a `NULL` pointer with an error message in the status. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeRuntimeContext()`](#context_8h_1a2434a11d8d65890c66f6b5516243a730). ### `M_freeRuntimeContext()` > void M\_freeRuntimeContext([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context) Deallocates the memory for a runtime context. No-op if `context` is `NULL`. * **Parameters:** context – The runtime context. ### `M_setDebugPrintOptions()` > void M\_setDebugPrintOptions([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_ResultOutputStyle](types.md#_CPPv419M_ResultOutputStyle) style, unsigned int precision, const char \*directory) Sets the options for debug printing of tensors when executing a model. * **Parameters:** * context – The runtime context. * style – The way the data will be printed. * precision – The floating-point printout precision. * directory – The directory to store binary output. 
### `M_setMojoDefineBool()` > void M\_setMojoDefineBool([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const char \*key, bool value) Sets a Mojo compile-time define with a boolean value. * **Parameters:** * context – The runtime context. * key – The name of the define. * value – The boolean to set the define to. ### `M_setMojoDefineInt()` > void M\_setMojoDefineInt([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const char \*key, int value) Sets a Mojo compile-time define with an integer value. * **Parameters:** * context – The runtime context. * key – The name of the define. * value – The integer to set the define to. ### `M_setMojoDefineString()` > void M\_setMojoDefineString([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const char \*key, const char \*value) Sets a Mojo compile-time define with a string value. * **Parameters:** * context – The runtime context. * key – The name of the define. * value – The string to set the define to. --- ## C API You can use the following C APIs to run inference with MAX Engine. 
## API headers Each of the following pages represents one of the C API header files: * [Common](common.md) * [`M_version()`](common.md#_CPPv49M_versionv) * [`M_newStatus()`](common.md#_CPPv411M_newStatusv) * [`M_getError()`](common.md#_CPPv410M_getErrorPK8M_Status) * [`M_isError()`](common.md#_CPPv49M_isErrorPK8M_Status) * [`M_freeStatus()`](common.md#_CPPv412M_freeStatusP8M_Status) * [Context](context.md) * [`M_newRuntimeConfig()`](context.md#_CPPv418M_newRuntimeConfigv) * [`M_freeRuntimeConfig()`](context.md#_CPPv419M_freeRuntimeConfigP15M_RuntimeConfig) * [`M_runtimeConfigAddDevice()`](context.md#_CPPv424M_runtimeConfigAddDeviceP15M_RuntimeConfigP8M_Device) * [`M_newRuntimeContext()`](context.md#_CPPv419M_newRuntimeContextPK15M_RuntimeConfigP8M_Status) * [`M_freeRuntimeContext()`](context.md#_CPPv420M_freeRuntimeContextP16M_RuntimeContext) * [`M_setDebugPrintOptions()`](context.md#_CPPv422M_setDebugPrintOptionsP16M_RuntimeContext19M_ResultOutputStylejPKc) * [`M_setMojoDefineBool()`](context.md#_CPPv419M_setMojoDefineBoolP16M_RuntimeContextPKcb) * [`M_setMojoDefineInt()`](context.md#_CPPv418M_setMojoDefineIntP16M_RuntimeContextPKci) * [`M_setMojoDefineString()`](context.md#_CPPv421M_setMojoDefineStringP16M_RuntimeContextPKcPKc) * [Model](model.md) * [`M_newCompileConfig()`](model.md#_CPPv418M_newCompileConfigv) * [`M_setModelPath()`](model.md#_CPPv414M_setModelPathP15M_CompileConfigPKc) * [`M_compileModel()`](model.md#_CPPv414M_compileModelPK16M_RuntimeContextPP15M_CompileConfigP8M_Status) * [`M_waitForCompilation()`](model.md#_CPPv420M_waitForCompilationP20M_AsyncCompiledModelP8M_Status) * [`M_compileModelSync()`](model.md#_CPPv418M_compileModelSyncPK16M_RuntimeContextPP15M_CompileConfigP8M_Status) * [`M_initModel()`](model.md#_CPPv411M_initModelPK16M_RuntimeContextPK20M_AsyncCompiledModelPK17M_WeightsRegistryP8M_Status) * [`M_waitForModel()`](model.md#_CPPv414M_waitForModelP12M_AsyncModelP8M_Status) * 
[`M_executeModelSync()`](model.md#_CPPv418M_executeModelSyncPK16M_RuntimeContextP12M_AsyncModelP16M_AsyncTensorMapP8M_Status) * [`M_freeModel()`](model.md#_CPPv411M_freeModelP12M_AsyncModel) * [`M_freeCompiledModel()`](model.md#_CPPv419M_freeCompiledModelP20M_AsyncCompiledModel) * [`M_freeCompileConfig()`](model.md#_CPPv419M_freeCompileConfigP15M_CompileConfig) * [Tensor](tensor.md) * [`M_newTensorSpec()`](tensor.md#_CPPv415M_newTensorSpecPK7int64_t7int64_t7M_DtypePKcPK8M_Device) * [`M_isDynamicRanked()`](tensor.md#_CPPv417M_isDynamicRankedPK12M_TensorSpec) * [`M_getDimAt()`](tensor.md#_CPPv410M_getDimAtPK12M_TensorSpec6size_t) * [`M_getRank()`](tensor.md#_CPPv49M_getRankPK12M_TensorSpec) * [`M_getDtype()`](tensor.md#_CPPv410M_getDtypePK12M_TensorSpec) * [`M_getName()`](tensor.md#_CPPv49M_getNameP12M_TensorSpec) * [`M_newAsyncTensorMap()`](tensor.md#_CPPv419M_newAsyncTensorMapPK16M_RuntimeContext) * [`M_borrowTensorInto()`](tensor.md#_CPPv418M_borrowTensorIntoP16M_AsyncTensorMapPvPK12M_TensorSpecP8M_Status) * [`M_getTensorByNameFrom()`](tensor.md#_CPPv421M_getTensorByNameFromP16M_AsyncTensorMapPKcP8M_Status) * [`M_getTensorNumElements()`](tensor.md#_CPPv422M_getTensorNumElementsPK13M_AsyncTensor) * [`M_getTensorType()`](tensor.md#_CPPv415M_getTensorTypePK13M_AsyncTensor) * [`M_getTensorData()`](tensor.md#_CPPv415M_getTensorDataPK13M_AsyncTensor) * [`M_getTensorSpec()`](tensor.md#_CPPv415M_getTensorSpecPK13M_AsyncTensor) * [`M_getDeviceTypeFromSpec()`](tensor.md#_CPPv423M_getDeviceTypeFromSpecPK12M_TensorSpec) * [`M_getDeviceIdFromSpec()`](tensor.md#_CPPv421M_getDeviceIdFromSpecPK12M_TensorSpec) * [`M_getTensorDevice()`](tensor.md#_CPPv417M_getTensorDevicePK13M_AsyncTensor) * [`M_copyTensorToDevice()`](tensor.md#_CPPv420M_copyTensorToDeviceP13M_AsyncTensorP8M_DeviceP8M_Status) * [`M_freeTensor()`](tensor.md#_CPPv412M_freeTensorP13M_AsyncTensor) * [`M_freeTensorNameArray()`](tensor.md#_CPPv421M_freeTensorNameArrayP17M_TensorNameArray) * 
[`M_freeTensorSpec()`](tensor.md#_CPPv416M_freeTensorSpecP12M_TensorSpec) * [`M_freeAsyncTensorMap()`](tensor.md#_CPPv420M_freeAsyncTensorMapP16M_AsyncTensorMap) * [Types](types.md) * [`M_Status`](types.md#_CPPv48M_Status) * [`M_RuntimeConfig`](types.md#_CPPv415M_RuntimeConfig) * [`M_RuntimeContext`](types.md#_CPPv416M_RuntimeContext) * [`M_CompileConfig`](types.md#_CPPv415M_CompileConfig) * [`M_AsyncCompiledModel`](types.md#_CPPv420M_AsyncCompiledModel) * [`M_AsyncModel`](types.md#_CPPv412M_AsyncModel) * [`M_AsyncTensor`](types.md#_CPPv413M_AsyncTensor) * [`M_TensorNameArray`](types.md#_CPPv417M_TensorNameArray) * [`M_TensorSpec`](types.md#_CPPv412M_TensorSpec) * [`M_AsyncTensorMap`](types.md#_CPPv416M_AsyncTensorMap) * [`M_WeightsRegistry`](types.md#_CPPv417M_WeightsRegistry) * [`M_Device`](types.md#_CPPv48M_Device) * [`M_Dtype`](types.md#_CPPv47M_Dtype) * [`M_UNKNOWN`](types.md#_CPPv4N7M_Dtype9M_UNKNOWNE) * [`mIsInteger`](types.md#_CPPv4N7M_Dtype10mIsIntegerE) * [`mIsFloat`](types.md#_CPPv4N7M_Dtype8mIsFloatE) * [`mIsComplex`](types.md#_CPPv4N7M_Dtype10mIsComplexE) * [`mIsSigned`](types.md#_CPPv4N7M_Dtype9mIsSignedE) * [`kIntWidthShift`](types.md#_CPPv4N7M_Dtype14kIntWidthShiftE) * [`M_INT1`](types.md#_CPPv4N7M_Dtype6M_INT1E) * [`M_UINT1`](types.md#_CPPv4N7M_Dtype7M_UINT1E) * [`M_INT2`](types.md#_CPPv4N7M_Dtype6M_INT2E) * [`M_UINT2`](types.md#_CPPv4N7M_Dtype7M_UINT2E) * [`M_INT4`](types.md#_CPPv4N7M_Dtype6M_INT4E) * [`M_UINT4`](types.md#_CPPv4N7M_Dtype7M_UINT4E) * [`M_INT8`](types.md#_CPPv4N7M_Dtype6M_INT8E) * [`M_UINT8`](types.md#_CPPv4N7M_Dtype7M_UINT8E) * [`M_INT16`](types.md#_CPPv4N7M_Dtype7M_INT16E) * [`M_UINT16`](types.md#_CPPv4N7M_Dtype8M_UINT16E) * [`M_INT32`](types.md#_CPPv4N7M_Dtype7M_INT32E) * [`M_UINT32`](types.md#_CPPv4N7M_Dtype8M_UINT32E) * [`M_INT64`](types.md#_CPPv4N7M_Dtype7M_INT64E) * [`M_UINT64`](types.md#_CPPv4N7M_Dtype8M_UINT64E) * [`M_INT128`](types.md#_CPPv4N7M_Dtype8M_INT128E) * [`M_UINT128`](types.md#_CPPv4N7M_Dtype9M_UINT128E) * 
[`M_FLOAT4_E2M1FN`](types.md#_CPPv4N7M_Dtype15M_FLOAT4_E2M1FNE) * [`M_FLOAT8_E8M0FNU`](types.md#_CPPv4N7M_Dtype16M_FLOAT8_E8M0FNUE) * [`M_FLOAT8_E3M4`](types.md#_CPPv4N7M_Dtype13M_FLOAT8_E3M4E) * [`M_FLOAT8_E4M3FN`](types.md#_CPPv4N7M_Dtype15M_FLOAT8_E4M3FNE) * [`M_FLOAT8_E4M3FNUZ`](types.md#_CPPv4N7M_Dtype17M_FLOAT8_E4M3FNUZE) * [`M_FLOAT8_E5M2`](types.md#_CPPv4N7M_Dtype13M_FLOAT8_E5M2E) * [`M_FLOAT8_E5M2FNUZ`](types.md#_CPPv4N7M_Dtype17M_FLOAT8_E5M2FNUZE) * [`M_FLOAT16`](types.md#_CPPv4N7M_Dtype9M_FLOAT16E) * [`M_BFLOAT16`](types.md#_CPPv4N7M_Dtype10M_BFLOAT16E) * [`M_FLOAT32`](types.md#_CPPv4N7M_Dtype9M_FLOAT32E) * [`M_FLOAT64`](types.md#_CPPv4N7M_Dtype9M_FLOAT64E) * [`M_BOOL`](types.md#_CPPv4N7M_Dtype6M_BOOLE) * [`M_AllocatorType`](types.md#_CPPv415M_AllocatorType) * [`kSystem`](types.md#_CPPv4N15M_AllocatorType7kSystemE) * [`kCaching`](types.md#_CPPv4N15M_AllocatorType8kCachingE) * [`M_ValueType`](types.md#_CPPv411M_ValueType) * [`M_STRING_VALUE`](types.md#_CPPv4N11M_ValueType14M_STRING_VALUEE) * [`M_DOUBLE_VALUE`](types.md#_CPPv4N11M_ValueType14M_DOUBLE_VALUEE) * [`M_LONG_VALUE`](types.md#_CPPv4N11M_ValueType12M_LONG_VALUEE) * [`M_BOOL_VALUE`](types.md#_CPPv4N11M_ValueType12M_BOOL_VALUEE) * [`M_TENSOR_VALUE`](types.md#_CPPv4N11M_ValueType14M_TENSOR_VALUEE) * [`M_LIST_VALUE`](types.md#_CPPv4N11M_ValueType12M_LIST_VALUEE) * [`M_TUPLE_VALUE`](types.md#_CPPv4N11M_ValueType13M_TUPLE_VALUEE) * [`M_DICT_VALUE`](types.md#_CPPv4N11M_ValueType12M_DICT_VALUEE) * [`M_NONE_VALUE`](types.md#_CPPv4N11M_ValueType12M_NONE_VALUEE) * [`M_UNKNOWN_VALUE`](types.md#_CPPv4N11M_ValueType15M_UNKNOWN_VALUEE) * [`M_MOJO_VALUE`](types.md#_CPPv4N11M_ValueType12M_MOJO_VALUEE) * [`M_PYTHON_MOJO_VALUE`](types.md#_CPPv4N11M_ValueType19M_PYTHON_MOJO_VALUEE) * [`M_DeviceType`](types.md#_CPPv412M_DeviceType) * [`M_HOST`](types.md#_CPPv4N12M_DeviceType6M_HOSTE) * [`M_ACCELERATOR`](types.md#_CPPv4N12M_DeviceType13M_ACCELERATORE) * [`M_ResultOutputStyle`](types.md#_CPPv419M_ResultOutputStyle) * 
[`M_COMPACT`](types.md#_CPPv4N19M_ResultOutputStyle9M_COMPACTE) * [`M_FULL`](types.md#_CPPv4N19M_ResultOutputStyle6M_FULLE) * [`M_BINARY`](types.md#_CPPv4N19M_ResultOutputStyle8M_BINARYE) * [`M_BINARY_MAX_CHECKPOINT`](types.md#_CPPv4N19M_ResultOutputStyle23M_BINARY_MAX_CHECKPOINTE) * [`M_NONE`](types.md#_CPPv4N19M_ResultOutputStyle6M_NONEE) ## Async API usage Our C API allows for compiling and executing models asynchronously. Effective use of asynchronous APIs can be challenging, but it pays off in performance. To help with this, we explain some important concepts and mental models to keep in mind when using the API. Our APIs are asynchronous unless stated otherwise, typically indicated by `Sync` in the function name. For example, we have `M_executeModel` and [`M_executeModelSync()`](model.md#_CPPv418M_executeModelSyncPK16M_RuntimeContextP12M_AsyncModelP16M_AsyncTensorMapP8M_Status). ### Types Our API describes the underlying async-holding types with a “value or error” concept. Conceptually, this means that the type is in one of three states: * `Constructed` - the value is not yet there, but there is no error * `Available` - the value is there and ready for use * `Error` - the value is not there and there is an error ### Synchronization points When using async APIs, keep in mind the synchronization point APIs listed below. They let you discern between the `Constructed` and `Available` states described above: after a synchronization point returns, the input is never in the `Constructed` state; it has resolved to either `Available` or `Error`. * [`M_waitForCompilation()`](model.md#_CPPv420M_waitForCompilationP20M_AsyncCompiledModelP8M_Status) * [`M_waitForModel()`](model.md#_CPPv414M_waitForModelP12M_AsyncModelP8M_Status) * `M_waitForTensors` ### Errors Errors surface immediately when using our synchronous APIs. 
Otherwise, in the case of async APIs, errors will not surface until the next synchronization point. You can query the error message by calling [`M_getError()`](common.md#_CPPv410M_getErrorPK8M_Status). --- ## Model ```c #include "max/c/model.h" ``` **Functions:** ### `M_newCompileConfig()` > [M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*M\_newCompileConfig() Creates an object you can use to configure model compilation. You need `M_CompileConfig` as an argument for several functions, including [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7) and [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). * **Returns:** A pointer to a new compilation configuration. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeCompileConfig()`](#model_8h_1abbf74b13adaf5bc8a0bb4d46c40688d9). This compilation configuration can only be used for a single compilation call. Any subsequent compilations must be passed a new `M_CompileConfig` (created by calling [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665) again). ### `M_setModelPath()` > void M\_setModelPath([M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*compileConfig, const char \*path) Sets the path to a model. You must call this before you call [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). Otherwise, [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750) returns an error in `status`. * **Parameters:** * compileConfig – The compilation configuration for your model, from [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665). * path – The path to your model. The model does not need to exist on the filesystem at this point. This follows the same semantics and expectations as `std::filesystem::path`. 
### `M_compileModel()` > [M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*M\_compileModel(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*\*compileConfig, [M\_Status](types.md#_CPPv48M_Status) \*status) Compiles a model. This immediately returns an `M_AsyncCompiledModel`, with compilation happening asynchronously. If you need to block to await compilation, you can then call [`M_waitForCompilation()`](#model_8h_1a8040a6488596f863c205d769d92ad013). You must call [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7) before you call this. For example: ```c M_CompileConfig *compileConfig = M_newCompileConfig(); M_setModelPath(compileConfig, modelPath); M_AsyncCompiledModel *compiledModel = M_compileModel(context, &compileConfig, status); if (M_isError(status)) { logError(M_getError(status)); return EXIT_FAILURE; } ``` The `M_AsyncCompiledModel` returned here is not ready for inference yet. You need to then initialize the model with [`M_initModel()`](#model_8h_1a2dcb9570ae117602579182d8faed494a). * **Parameters:** * context – The runtime context, from [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175). * compileConfig – Address of the compilation configuration for your model, created with [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665), and with the model set via [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7). Ownership of the configuration is handed over to the API. * status – The status used to report errors in the case of failures during model compilation. * **Returns:** A pointer to an `M_AsyncCompiledModel`. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeCompiledModel()`](#model_8h_1a5b6846eb4d47d445eb65c305b1c81b1c). If the config is invalid, it returns a `NULL` pointer. 
If the model compilation fails, the pointer is `NULL` and the `status` parameter contains an error message. `compileConfig` will be reset to `NULL` after this call irrespective of status and cannot be reused, and any subsequent calls must take a new `M_CompileConfig`. ### `M_waitForCompilation()` > void M\_waitForCompilation([M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*compiledModel, [M\_Status](types.md#_CPPv48M_Status) \*status) Blocks execution until the model is compiled. This waits for the async compiled model to be complete after calling [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). When this function returns, the model is resolved to either a compiled model or an error. * **Parameters:** * compiledModel – The model received from [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). * status – The status used to report errors in the case of failures. ### `M_compileModelSync()` > [M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*M\_compileModelSync(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*\*compileConfig, [M\_Status](types.md#_CPPv48M_Status) \*status) Synchronously compiles a model. Unlike [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750), this blocks until model compilation is complete. It returns an `M_AsyncCompiledModel` without needing to call [`M_waitForCompilation()`](#model_8h_1a8040a6488596f863c205d769d92ad013). All other setup and usage is identical to [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). * **Parameters:** * context – The runtime context, from [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175). 
* compileConfig – Address of the compilation configuration for your model, created with [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665), and with the model set via [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7). Ownership of the configuration is handed over to the API. * status – The status used to report errors in the case of failures during model compilation. * **Returns:** A pointer to an `M_AsyncCompiledModel`. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeCompiledModel()`](#model_8h_1a5b6846eb4d47d445eb65c305b1c81b1c). If the config is invalid, it returns a `NULL` pointer. If the model compilation fails, the pointer is `NULL` and the `status` parameter contains an error message. `compileConfig` will be reset to `NULL` after this call irrespective of status and cannot be reused, and any subsequent calls must take a new `M_CompileConfig`. ### `M_initModel()` > [M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*M\_initModel(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const [M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*compiledModel, const [M\_WeightsRegistry](types.md#_CPPv417M_WeightsRegistry) \*weightsRegistry, [M\_Status](types.md#_CPPv48M_Status) \*status) Sets up a model for execution. You can call this immediately after [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750)—you don’t need to wait for the async compilation. This function also returns immediately with model initialization happening asynchronously. 
For example: ```c M_AsyncModel *model = M_initModel( context, compiledModel, weightsRegistry, status); if (M_isError(status)) { logError(M_getError(status)); return EXIT_FAILURE; } ``` If you want to block until `M_AsyncModel` is initialized, you can call [`M_waitForModel()`](#model_8h_1a852bec3f80cebb5c06911091d5cab349), but that’s not necessary and you can immediately call [`M_executeModelSync()`](#model_8h_1a2ced4683834a77d0b943a6bc72d846d5). * **Parameters:** * context – The runtime context, from [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175). * compiledModel – The compiled model, from [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). * weightsRegistry – A mapping from weights’ names to their data. The weights registry is used to update weights or otherwise pass weights to the model init block at runtime, without recompiling the model graph. If the model doesn’t use the weights registry, it is safe to pass `NULL`. * status – The status used to report errors in the case of failures. The status contains an error only if the given context or compiled model is invalid. Other errors will not surface until the next synchronization point. * **Returns:** A pointer to an `M_AsyncModel` that holds an async value to a compiled model. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeModel()`](#model_8h_1a4094fa8e414f8b6a6563474f8840d33c). If model initialization fails, the `status` parameter contains an error message. ### `M_waitForModel()` > void M\_waitForModel([M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*model, [M\_Status](types.md#_CPPv48M_Status) \*status) Blocks execution until the model is initialized. This waits for the model setup to finish in [`M_initModel()`](#model_8h_1a2dcb9570ae117602579182d8faed494a). * **Parameters:** * model – The model. * status – The status used to report errors in the case of failures. 
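Putting the asynchronous compile and initialization steps together with their synchronization points, a sketch (with error handling abbreviated, and `context`, `compileConfig`, and `status` assumed to be set up as shown earlier) might look like:

```c
// Compile asynchronously, then block until the compiled model resolves
// to either `Available` or `Error`.
M_AsyncCompiledModel *compiledModel =
    M_compileModel(context, &compileConfig, status);
M_waitForCompilation(compiledModel, status);
if (M_isError(status)) { /* handle compilation failure */ }

// Initialize the model. Waiting here is optional: you can instead call
// M_executeModelSync() directly, which synchronizes internally.
M_AsyncModel *model =
    M_initModel(context, compiledModel, /*weightsRegistry=*/NULL, status);
M_waitForModel(model, status);
if (M_isError(status)) { /* handle initialization failure */ }
```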
### `M_executeModelSync()` > [M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*M\_executeModelSync(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*initializedModel, [M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*inputs, [M\_Status](types.md#_CPPv48M_Status) \*status) Executes a model synchronously. The inputs and outputs are `M_AsyncTensorMap` objects to allow chaining of inference. This operation is blocking and waits until the output results are ready. * **Parameters:** * context – The runtime context. * initializedModel – The model to execute, from [`M_initModel()`](#model_8h_1a2dcb9570ae117602579182d8faed494a). Although that function is async, you can pass the `M_AsyncModel` here immediately. * inputs – The tensor inputs. * status – The status used to report errors in the case of failures. This includes failures encountered while running the model; there is no need for an explicit synchronization point. * **Returns:** A pointer to an `M_AsyncTensorMap` that holds the output tensors. These tensors are in a resolved state. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeAsyncTensorMap()`](tensor.md#tensor_8h_1a0ac9628dcba39c9977b7f7ff95d8781e). In the case that executing the model fails, the `status` parameter contains an error message. ### `M_freeModel()` > void M\_freeModel([M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*model) Deallocates the memory for the model. No-op if `model` is `NULL`. * **Parameters:** model – The model to deallocate. ### `M_freeCompiledModel()` > void M\_freeCompiledModel([M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*model) Deallocates the memory for the compiled model. No-op if `model` is `NULL`. * **Parameters:** model – The compiled model to deallocate. 
### `M_freeCompileConfig()` > void M\_freeCompileConfig([M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*config) Deallocates the memory for the compile config. No-op if `config` is `NULL`. * **Parameters:** config – The compilation configuration to deallocate. --- ## Tensor ```c #include "max/c/tensor.h" ``` **Functions:** ### `M_newTensorSpec()` > [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*M\_newTensorSpec(const int64\_t \*shape, int64\_t rankSize, [M\_Dtype](types.md#_CPPv47M_Dtype) dtype, const char \*tensorName, const [M\_Device](types.md#_CPPv48M_Device) \*device) Creates a tensor specification. You need this in order to set the input tensors with [`M_borrowTensorInto()`](#tensor_8h_1a58a1646cfa1726b047b020c89eb7345c). When storing tensor data in memory, we always use a diminishing stride size. That is, earlier dimensions in the shape have larger strides than later dimensions. For example, a C array declared as `int arr[1][2][3]` would have a shape specified as `{1, 2, 3}`. * **Parameters:** * shape – The shape of the tensor. * rankSize – The rank size of the tensor. * dtype – The datatype for the tensor. * tensorName – The name for the tensor. This string gets copied as part of the operation of `M_newTensorSpec`, so your original string need not remain valid after the completion of this call. * device – The device on which the tensor resides. * **Returns:** A pointer to the tensor spec. You are responsible for the memory associated with the pointer returned. The memory can be deallocated by calling [`M_freeTensorSpec()`](#tensor_8h_1af0b957daeba1760134c3f24079b53026). ### `M_isDynamicRanked()` > int M\_isDynamicRanked(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Returns whether the given spec has a dynamic rank. * **Parameters:** spec – The tensor spec. * **Returns:** `1` if the rank is dynamic. `0` otherwise. 
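For instance, the `int arr[1][2][3]` example above could be described like this (a sketch: `device` is assumed to be a valid `M_Device` handle obtained elsewhere, and the tensor name `"input0"` is hypothetical):

```c
// Describe a rank-3 int32 tensor shaped {1, 2, 3}.
int64_t shape[] = {1, 2, 3};
M_TensorSpec *spec =
    M_newTensorSpec(shape, /*rankSize=*/3, M_INT32, "input0", device);

// M_newTensorSpec copies the name, so the original string need not
// outlive this call. Free the spec when you no longer need it.
M_freeTensorSpec(spec);
```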
### `M_getDimAt()` > int64\_t M\_getDimAt(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec, size\_t axis) Gets the dimension at a particular axis. * **Parameters:** * spec – The tensor spec. * axis – The requested axis. * **Returns:** The dimension at the requested axis if the spec and axis are valid and the spec has a static rank. Otherwise, `0`. A dimension equal to `kDynamicDimensionValue` indicates a dynamic dimension, such as the batch size of a model that expects batched tensors. ### `M_getRank()` > int64\_t M\_getRank(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Gets the rank from the tensor spec. * **Parameters:** spec – The tensor spec. * **Returns:** The number of dimensions in the tensor spec if the spec is static and valid, or `kDynamicRankValue` if the rank is dynamic. Otherwise, `0`. ### `M_getDtype()` > [M\_Dtype](types.md#_CPPv47M_Dtype) M\_getDtype(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Gets the datatype from the tensor spec. * **Parameters:** spec – The tensor spec. * **Returns:** The element type from the tensor spec if the tensor spec is valid. Otherwise, `M_UNKNOWN`. ### `M_getName()` > const char \*M\_getName([M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Gets the name of the tensor from the tensor spec. * **Parameters:** spec – The tensor spec. * **Returns:** A null-terminated string containing the name of the tensor if the `spec` is valid. Otherwise, `NULL`. The memory associated with the returned string is owned by `spec`. ### `M_newAsyncTensorMap()` > [M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*M\_newAsyncTensorMap(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context) Creates a map of tensor names to async tensors. * **Parameters:** context – The runtime context. * **Returns:** A pointer to the tensor map. You are responsible for the memory associated with the pointer returned. The memory can be deallocated by calling [`M_freeAsyncTensorMap()`](#tensor_8h_1a0ac9628dcba39c9977b7f7ff95d8781e). 
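The spec accessors above can be combined to inspect a tensor's shape and type. A sketch, assuming `spec` is a valid, statically ranked `M_TensorSpec`:

```c
#include <stdio.h>

// Walk the dimensions of a tensor spec.
int64_t rank = M_getRank(spec);
for (int64_t i = 0; i < rank; ++i) {
  // A value of kDynamicDimensionValue marks a dynamic dimension,
  // such as the batch size.
  printf("dim %lld = %lld\n",
         (long long)i, (long long)M_getDimAt(spec, (size_t)i));
}
if (M_getDtype(spec) == M_FLOAT32) {
  printf("tensor '%s' holds float32 data\n", M_getName(spec));
}
```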
### `M_borrowTensorInto()` > void M\_borrowTensorInto([M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*tensors, void \*input, const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*tensorSpec, [M\_Status](types.md#_CPPv48M_Status) \*status) Adds a tensor to the tensor map. You are responsible for the lifetime of the input tensor data. Its data gets “borrowed” into the Tensor Map. * **Parameters:** * tensors – The tensor map, from [`M_newAsyncTensorMap()`](#tensor_8h_1a18039c6e6c1769b947120b27178306eb). * input – The input tensor data. * tensorSpec – The tensor spec, from [`M_newTensorSpec()`](#tensor_8h_1ab7546d4d0a22ae82134d200272e8f8f4). This gets copied as part of the operation of `M_borrowTensorInto`, so your original tensorSpec need not exist through the lifetime of the tensor map. * status – The status object for reporting errors. ### `M_getTensorByNameFrom()` > [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*M\_getTensorByNameFrom([M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*tensorMap, const char \*name, [M\_Status](types.md#_CPPv48M_Status) \*status) Gets a tensor from the tensor map by name. * **Parameters:** * tensorMap – The tensor map. * name – The name of the tensor. * status – The status object for reporting errors. * **Returns:** A pointer to the tensor. You are responsible for the memory associated with the pointer returned. The memory can be deallocated by calling [`M_freeTensor()`](#tensor_8h_1a339008df4a10af5e8c01ae970598765c). The held tensor inside the return value is simply borrowed from the corresponding input `M_AsyncTensorMap`. If the tensor map or name are invalid, a `NULL` pointer is returned and the `status` parameter contains an error message. ### `M_getTensorNumElements()` > size\_t M\_getTensorNumElements(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor) Gets the number of elements for the tensor. * **Parameters:** tensor – The tensor which must not be `NULL`. 
* **Returns:** The number of elements for the given tensor. ### `M_getTensorType()` > [M\_Dtype](types.md#_CPPv47M_Dtype) M\_getTensorType(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor) Gets the corresponding `M_Dtype` for the tensor. * **Parameters:** tensor – The tensor which must not be `NULL`. * **Returns:** The corresponding `M_Dtype` for the tensor. ### `M_getTensorData()` > const void \*M\_getTensorData(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor) Gets a pointer to underlying data of the tensor. * **Parameters:** tensor – The tensor which must not be `NULL`. * **Returns:** A pointer to the underlying data of the tensor. This pointer is valid for the lifetime of the underlying tensor. ### `M_getTensorSpec()` > [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*M\_getTensorSpec(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor) Gets a Tensor Spec for the tensor. * **Parameters:** tensor – The tensor. * **Returns:** The tensor spec for the tensor if the tensor is valid. Otherwise, `NULL`. ### `M_getDeviceTypeFromSpec()` > [M\_DeviceType](types.md#_CPPv412M_DeviceType) M\_getDeviceTypeFromSpec(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Gets the device type from a tensor specification. * **Parameters:** spec – The tensor spec. * **Returns:** The device type (CPU or GPU). ### `M_getDeviceIdFromSpec()` > int M\_getDeviceIdFromSpec(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Gets the device ID from a tensor specification. * **Parameters:** spec – The tensor spec. * **Returns:** The device ID. Returns `0` if the spec is invalid. ### `M_getTensorDevice()` > [M\_Device](types.md#_CPPv48M_Device) \*M\_getTensorDevice(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor) Gets the device on which a tensor resides. * **Parameters:** tensor – The tensor. * **Returns:** The device on which the tensor resides, or `NULL` if the tensor is invalid. 
The caller owns the returned device and must free it with `M_freeDevice()`. ### `M_copyTensorToDevice()` > [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*M\_copyTensorToDevice([M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor, [M\_Device](types.md#_CPPv48M_Device) \*device, [M\_Status](types.md#_CPPv48M_Status) \*status) Copies a tensor to a different device. Creates a copy of the tensor on the specified device. * **Parameters:** * tensor – The tensor to copy. * device – The target device. * status – The status object for reporting errors. * **Returns:** A pointer to the tensor on the target device. The caller owns the returned memory and must deallocate it by calling [`M_freeTensor()`](#tensor_8h_1a339008df4a10af5e8c01ae970598765c). Returns `NULL` if the operation fails, with an error message in the status. ### `M_freeTensor()` > void M\_freeTensor([M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor) Deallocates the memory for the tensor. No-op if `tensor` is NULL. * **Parameters:** tensor – The tensor to deallocate. ### `M_freeTensorNameArray()` > void M\_freeTensorNameArray([M\_TensorNameArray](types.md#_CPPv417M_TensorNameArray) \*names) Deallocates the memory for the array of tensor names. No-op if `names` is `NULL`. * **Parameters:** names – The tensor names to deallocate. ### `M_freeTensorSpec()` > void M\_freeTensorSpec([M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec) Deallocates the memory for the tensor spec. No-op if `spec` is `NULL`. * **Parameters:** spec – The tensor spec to deallocate. ### `M_freeAsyncTensorMap()` > void M\_freeAsyncTensorMap([M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*tensorMap) Deallocates the memory for the tensor map. No-op if `tensorMap` is `NULL`. * **Parameters:** tensorMap – The tensor map to deallocate. --- ## Types ```c #include "max/c/types.h" ``` **Typedefs:** ### `M_Status` > typedef struct [M\_Status](#_CPPv48M_Status) M\_Status Contains the success or failure of an API call. 
In general, any API that may fail accepts a `M_Status` argument that is filled in with a meaningful error message on failure. You can create this with [`M_newStatus()`](common.md#common_8h_1adb1ef3fc2e0bcdc8eb17cac3ce91835b). When you’re done, call [`M_freeStatus()`](common.md#common_8h_1ab5067fd51a5696b3679f7f629d3329c4). ### `M_RuntimeConfig` > typedef struct [M\_RuntimeConfig](#_CPPv415M_RuntimeConfig) M\_RuntimeConfig Specifies the MAX Engine configuration. Configuration properties include the number of threads, artifact path, etc. You can create this with [`M_newRuntimeConfig()`](context.md#context_8h_1a963f1d4eefd812ba8691acf516007cfc). When you’re done, call [`M_freeRuntimeConfig()`](context.md#context_8h_1a47f7e22f7f71da9ab5fb3a1886911610). ### `M_RuntimeContext` > typedef struct [M\_RuntimeContext](#_CPPv416M_RuntimeContext) M\_RuntimeContext Contains information that needs to be shared between APIs. You can create this with [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175). When you’re done, call [`M_freeRuntimeContext()`](context.md#context_8h_1a2434a11d8d65890c66f6b5516243a730). ### `M_CompileConfig` > typedef struct [M\_CompileConfig](#_CPPv415M_CompileConfig) M\_CompileConfig Specifies the configuration required for model compilation. You can create this with [`M_newCompileConfig()`](model.md#model_8h_1a417e7a581c096ca26c36a1875163b665). When you’re done, call [`M_freeCompileConfig()`](model.md#model_8h_1abbf74b13adaf5bc8a0bb4d46c40688d9). ### `M_AsyncCompiledModel` > typedef struct [M\_AsyncCompiledModel](#_CPPv420M_AsyncCompiledModel) M\_AsyncCompiledModel Contains an async value to a compiled model. `M_AsyncCompiledModel` can be passed to other APIs that accept compiled models as a function parameter. This async value will eventually resolve to a compiled model or an error in the case of compilation failure. You can create this with [`M_compileModel()`](model.md#model_8h_1a88afca26a64b945885e1e1a0d09b5750). 
When you’re done, call [`M_freeCompiledModel()`](model.md#model_8h_1a5b6846eb4d47d445eb65c305b1c81b1c). ### `M_AsyncModel` > typedef struct [M\_AsyncModel](#_CPPv412M_AsyncModel) M\_AsyncModel Contains a future used for inference. The future will resolve to a model that’s ready for inference. You can create this with [`M_initModel()`](model.md#model_8h_1a2dcb9570ae117602579182d8faed494a). When you’re done, call [`M_freeModel()`](model.md#model_8h_1a4094fa8e414f8b6a6563474f8840d33c). ### `M_AsyncTensor` > typedef struct [M\_AsyncTensor](#_CPPv413M_AsyncTensor) M\_AsyncTensor Contains an async value to a tensor for inference. You can get this from [`M_getTensorByNameFrom()`](tensor.md#tensor_8h_1a9522ad955454dbd2d044066dea2cad95). When you’re done, call [`M_freeTensor()`](tensor.md#tensor_8h_1a339008df4a10af5e8c01ae970598765c). ### `M_TensorNameArray` > typedef struct [M\_TensorNameArray](#_CPPv417M_TensorNameArray) M\_TensorNameArray Contains an array of tensor names of model inputs or outputs. You can get this from `M_getInputNames()` and `M_getOutputNames()`. When you’re done, call [`M_freeTensorNameArray()`](tensor.md#tensor_8h_1a7fa5d2aff7f89143ae1905fc29b5b112). ### `M_TensorSpec` > typedef struct [M\_TensorSpec](#_CPPv412M_TensorSpec) M\_TensorSpec Contains the representation of a shape and an element type. You can create this with [`M_newTensorSpec()`](tensor.md#tensor_8h_1ab7546d4d0a22ae82134d200272e8f8f4). When you’re done, call [`M_freeTensorSpec()`](tensor.md#tensor_8h_1af0b957daeba1760134c3f24079b53026). ### `M_AsyncTensorMap` > typedef struct [M\_AsyncTensorMap](#_CPPv416M_AsyncTensorMap) M\_AsyncTensorMap Contains a collection of tensors. The collection of tensors is used to represent inputs and outputs when executing a model. You can create this with [`M_newAsyncTensorMap()`](tensor.md#tensor_8h_1a18039c6e6c1769b947120b27178306eb). When you’re done, call [`M_freeAsyncTensorMap()`](tensor.md#tensor_8h_1a0ac9628dcba39c9977b7f7ff95d8781e). 
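These opaque types all follow the same create/use/free ownership pattern; for example, with `M_Status`:

```c
#include <stdio.h>

// Every opaque type pairs a constructor with a matching free function.
M_Status *status = M_newStatus();

// ...pass `status` to fallible APIs; on failure, inspect the message:
if (M_isError(status)) {
  fprintf(stderr, "error: %s\n", M_getError(status));
}

M_freeStatus(status);
```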
### `M_WeightsRegistry` > typedef struct [M\_WeightsRegistry](#_CPPv417M_WeightsRegistry) M\_WeightsRegistry Maps unique weight names to their backing data. ### `M_Device` > typedef struct [M\_Device](#_CPPv48M_Device) M\_Device Contains a device handle. A device represents a computational unit (CPU or GPU) that can execute operations and hold tensors. You can create this with `M_newDevice()`. When you’re done, call `M_freeDevice()`. **Enums:** ### `M_Dtype` > enum M\_Dtype Represents all data types supported by the framework. Values: #### `M_UNKNOWN` > enumerator M\_UNKNOWN #### `mIsInteger` > enumerator mIsInteger #### `mIsFloat` > enumerator mIsFloat #### `mIsComplex` > enumerator mIsComplex #### `mIsSigned` > enumerator mIsSigned Bit 0 encodes “isSigned”. #### `kIntWidthShift` > enumerator kIntWidthShift #### `M_INT1` > enumerator M\_INT1 #### `M_UINT1` > enumerator M\_UINT1 #### `M_INT2` > enumerator M\_INT2 #### `M_UINT2` > enumerator M\_UINT2 #### `M_INT4` > enumerator M\_INT4 #### `M_UINT4` > enumerator M\_UINT4 #### `M_INT8` > enumerator M\_INT8 #### `M_UINT8` > enumerator M\_UINT8 #### `M_INT16` > enumerator M\_INT16 #### `M_UINT16` > enumerator M\_UINT16 #### `M_INT32` > enumerator M\_INT32 #### `M_UINT32` > enumerator M\_UINT32 #### `M_INT64` > enumerator M\_INT64 #### `M_UINT64` > enumerator M\_UINT64 #### `M_INT128` > enumerator M\_INT128 #### `M_UINT128` > enumerator M\_UINT128 #### `M_FLOAT4_E2M1FN` > enumerator M\_FLOAT4\_E2M1FN Bits 0 through 3 indicate the kind of FP value. #### `M_FLOAT8_E8M0FNU` > enumerator M\_FLOAT8\_E8M0FNU Some slots are left blank here to enable us to support more lower precision types in the future. 
#### `M_FLOAT8_E3M4` > enumerator M\_FLOAT8\_E3M4 #### `M_FLOAT8_E4M3FN` > enumerator M\_FLOAT8\_E4M3FN #### `M_FLOAT8_E4M3FNUZ` > enumerator M\_FLOAT8\_E4M3FNUZ #### `M_FLOAT8_E5M2` > enumerator M\_FLOAT8\_E5M2 #### `M_FLOAT8_E5M2FNUZ` > enumerator M\_FLOAT8\_E5M2FNUZ #### `M_FLOAT16` > enumerator M\_FLOAT16 #### `M_BFLOAT16` > enumerator M\_BFLOAT16 #### `M_FLOAT32` > enumerator M\_FLOAT32 #### `M_FLOAT64` > enumerator M\_FLOAT64 #### `M_BOOL` > enumerator M\_BOOL ### `M_AllocatorType` > enum M\_AllocatorType Contains an `AllocatorType`. You can choose between `kCaching` and `kSystem`: `kCaching` trades off higher memory usage for better performance, while `kSystem` uses the default system allocator. Values: #### `kSystem` > enumerator kSystem #### `kCaching` > enumerator kCaching ### `M_ValueType` > enum M\_ValueType Represents the type of a value. Values: #### `M_STRING_VALUE` > enumerator M\_STRING\_VALUE #### `M_DOUBLE_VALUE` > enumerator M\_DOUBLE\_VALUE #### `M_LONG_VALUE` > enumerator M\_LONG\_VALUE #### `M_BOOL_VALUE` > enumerator M\_BOOL\_VALUE #### `M_TENSOR_VALUE` > enumerator M\_TENSOR\_VALUE #### `M_LIST_VALUE` > enumerator M\_LIST\_VALUE #### `M_TUPLE_VALUE` > enumerator M\_TUPLE\_VALUE #### `M_DICT_VALUE` > enumerator M\_DICT\_VALUE #### `M_NONE_VALUE` > enumerator M\_NONE\_VALUE #### `M_UNKNOWN_VALUE` > enumerator M\_UNKNOWN\_VALUE #### `M_MOJO_VALUE` > enumerator M\_MOJO\_VALUE #### `M_PYTHON_MOJO_VALUE` > enumerator M\_PYTHON\_MOJO\_VALUE ### `M_DeviceType` > enum M\_DeviceType Represents the type of device. Values: #### `M_HOST` > enumerator M\_HOST #### `M_ACCELERATOR` > enumerator M\_ACCELERATOR ### `M_ResultOutputStyle` > enum M\_ResultOutputStyle Represents the result output style for debug printing. 
Values: #### `M_COMPACT` > enumerator M\_COMPACT #### `M_FULL` > enumerator M\_FULL #### `M_BINARY` > enumerator M\_BINARY #### `M_BINARY_MAX_CHECKPOINT` > enumerator M\_BINARY\_MAX\_CHECKPOINT #### `M_NONE` > enumerator M\_NONE --- ## API references import ListingCards from '@site/src/components/Listing/ListingCards'; export const cards = [ { title: 'Python', url: '/max/api/python', description: 'The Python library API reference.' }, { title: 'Mojo', url: '/mojo/lib', description: 'The Mojo library API reference.' }, { title: 'REST', url: '/max/api/serve', description: 'The MAX serving REST API reference.' } ] --- ## BackgroundRecorder ## `BackgroundRecorder` {#max.diagnostics.gpu.BackgroundRecorder} > class max.diagnostics.gpu.BackgroundRecorder Asynchronous GPU metrics collection and data export capabilities. The `BackgroundRecorder` enables continuous monitoring of GPU performance metrics without blocking the main application thread. It automatically samples GPU statistics at one-second intervals in a separate process, making it ideal for profiling long-running inference sessions or training workloads. When used as a context manager, the recorder starts background collection upon entry and stops collection upon exit. The collected statistics are then available through the stats property as a time-series of GPU measurements. 
```python from max.diagnostics.gpu import BackgroundRecorder with BackgroundRecorder() as recorder: # Run your GPU workload here run_inference_session() for i, snapshot in enumerate(recorder.stats): print(f"Sample {i}: {len(snapshot)} GPUs detected") for gpu_id, gpu_stats in snapshot.items(): print(f" {gpu_id}: {gpu_stats.memory.used_bytes} bytes used") ``` ### `stats` {#max.diagnostics.gpu.BackgroundRecorder.stats} > property stats: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [GPUStats](GPUStats.md#max.diagnostics.gpu.GPUStats)]] Time-series of GPU statistics collected during background recording.
**Returns:**
A list of dictionaries, where each dictionary represents GPU statistics at a specific point in time. Each dictionary maps GPU identifiers to their corresponding [`GPUStats`](GPUStats.md#max.diagnostics.gpu.GPUStats) objects.
**Raises:**
[RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If accessed before the recorder context has exited.
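The `stats` time-series lends itself to simple post-processing, such as finding the peak memory use per GPU across all samples. Below is a minimal sketch of that aggregation; it uses plain dictionaries with a `used_bytes` key as stand-ins for the `GPUStats` objects a real recorder returns (with real data you would read `gpu_stats.memory.used_bytes` instead):

```python
# Sketch: post-process a BackgroundRecorder-style time-series to find the
# peak memory use per GPU. Plain dicts stand in for GPUStats objects.

def peak_used_bytes(stats):
    """Return the maximum used_bytes observed for each GPU id."""
    peaks = {}
    for snapshot in stats:              # one snapshot per sampling interval
        for gpu_id, gpu in snapshot.items():
            peaks[gpu_id] = max(gpu["used_bytes"], peaks.get(gpu_id, 0))
    return peaks

# Three one-second samples for a single GPU:
samples = [
    {"nv0": {"used_bytes": 1_000}},
    {"nv0": {"used_bytes": 4_000}},
    {"nv0": {"used_bytes": 2_000}},
]
print(peak_used_bytes(samples))  # {'nv0': 4000}
```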
--- ## GPUDiagContext ## `GPUDiagContext` {#max.diagnostics.gpu.GPUDiagContext} > class max.diagnostics.gpu.GPUDiagContext Context manager providing unified access to GPU diagnostic information across NVIDIA and AMD hardware. This class automatically detects and initializes supported GPU vendor libraries (NVML for NVIDIA, ROCm SMI for AMD) and provides a unified interface for collecting diagnostic statistics from all available GPUs in the system. ```python from max.diagnostics.gpu import GPUDiagContext with GPUDiagContext() as ctx: stats = ctx.get_stats() for gpu_id, gpu_stats in stats.items(): print(f"GPU {gpu_id}: {gpu_stats.memory.used_bytes} bytes used") ``` ### `get_stats()` {#max.diagnostics.gpu.GPUDiagContext.get_stats} > get\_stats() Retrieve current GPU statistics for all detected GPUs in the system.
**Returns:**
A dictionary mapping GPU identifiers to their current statistics. NVIDIA GPUs are prefixed with `nv` (e.g., `nv0`, `nv1`) and AMD GPUs are prefixed with `amd` (e.g., `amd0`, `amd1`).
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [GPUStats](GPUStats.md#max.diagnostics.gpu.GPUStats)]
--- ## GPUStats ## `GPUStats` {#max.diagnostics.gpu.GPUStats} > class max.diagnostics.gpu.GPUStats(memory, utilization) Comprehensive GPU state snapshot containing memory and utilization metrics. This class provides a complete view of a GPU’s current state, including detailed memory usage statistics and utilization percentages. It serves as the primary data structure returned by GPU diagnostic queries.
**Parameters:**
* memory ([MemoryStats](MemoryStats.md#max.diagnostics.gpu.MemoryStats)) * utilization ([UtilizationStats](UtilizationStats.md#max.diagnostics.gpu.UtilizationStats))
### `memory` {#max.diagnostics.gpu.GPUStats.memory} > memory: [MemoryStats](MemoryStats.md#max.diagnostics.gpu.MemoryStats) Current GPU memory usage statistics. --- ## MemoryStats ## `MemoryStats` {#max.diagnostics.gpu.MemoryStats} > class max.diagnostics.gpu.MemoryStats(total\_bytes, free\_bytes, used\_bytes, reserved\_bytes) Detailed GPU memory usage statistics including total, free, used, and reserved memory. This class provides comprehensive memory information for a GPU, allowing developers to monitor memory consumption and identify potential memory bottlenecks during model inference or training.
**Parameters:**
* total\_bytes ([int](https://docs.python.org/3/library/functions.html#int)) * free\_bytes ([int](https://docs.python.org/3/library/functions.html#int)) * used\_bytes ([int](https://docs.python.org/3/library/functions.html#int)) * reserved\_bytes ([int](https://docs.python.org/3/library/functions.html#int) | None)
### `free_bytes` {#max.diagnostics.gpu.MemoryStats.free_bytes} > free\_bytes: [int](https://docs.python.org/3/library/functions.html#int) Currently free GPU memory in bytes. ### `reserved_bytes` {#max.diagnostics.gpu.MemoryStats.reserved_bytes} > reserved\_bytes: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) Memory reserved by the driver, if available from the GPU vendor. ### `total_bytes` {#max.diagnostics.gpu.MemoryStats.total_bytes} > total\_bytes: [int](https://docs.python.org/3/library/functions.html#int) Total GPU memory in bytes. ### `used_bytes` {#max.diagnostics.gpu.MemoryStats.used_bytes} > used\_bytes: [int](https://docs.python.org/3/library/functions.html#int) Currently allocated GPU memory in bytes. --- ## UtilizationStats ## `UtilizationStats` {#max.diagnostics.gpu.UtilizationStats} > class max.diagnostics.gpu.UtilizationStats(gpu\_usage\_percent, memory\_activity\_percent) GPU compute and memory activity utilization percentages. This class captures the current utilization levels of a GPU’s compute units and memory subsystem, providing insights into how effectively the GPU resources are being utilized during workload execution.
**Parameters:**
* gpu\_usage\_percent ([int](https://docs.python.org/3/library/functions.html#int)) * memory\_activity\_percent ([int](https://docs.python.org/3/library/functions.html#int) | None)
### `gpu_usage_percent` {#max.diagnostics.gpu.UtilizationStats.gpu_usage_percent} > gpu\_usage\_percent: [int](https://docs.python.org/3/library/functions.html#int) Current GPU compute utilization as a percentage (0-100). ### `memory_activity_percent` {#max.diagnostics.gpu.UtilizationStats.memory_activity_percent} > memory\_activity\_percent: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) Memory controller activity percentage, if available from the GPU vendor. --- ## gpu Real-time GPU monitoring and diagnostic capabilities for NVIDIA and AMD graphics hardware. The GPU diagnostics module provides comprehensive tools for monitoring graphics processing unit performance, memory usage, and utilization metrics. It supports both NVIDIA GPUs through NVML and AMD GPUs through ROCm SMI, offering unified access to hardware statistics regardless of vendor. The API enables both synchronous queries for immediate metrics and asynchronous background collection for continuous monitoring during long-running inference sessions. ## Classes * [`BackgroundRecorder`](/max/api/python/diagnostics/gpu/BackgroundRecorder): Asynchronous GPU metrics collection. * [`GPUDiagContext`](/max/api/python/diagnostics/gpu/GPUDiagContext): Context manager providing unified access to GPU diagnostic information across NVIDIA and AMD hardware. * [`GPUStats`](/max/api/python/diagnostics/gpu/GPUStats): Comprehensive GPU state snapshot containing memory and utilization statistics. * [`MemoryStats`](/max/api/python/diagnostics/gpu/MemoryStats): Detailed GPU memory usage statistics including total, free, used, and reserved memory. * [`UtilizationStats`](/max/api/python/diagnostics/gpu/UtilizationStats): GPU compute and memory activity utilization percentages. --- ## driver Exposes APIs for interacting with hardware, such as allocating tensors on a GPU and moving tensors between the CPU and GPU. 
It provides interfaces for memory management, device properties, and hardware monitoring. Through these APIs, you can control data placement, track resource utilization, and configure device settings for optimal performance. For example, the following code selects an accelerator if one is available and otherwise falls back to the CPU: ```python from max import driver device = driver.CPU() if driver.accelerator_count() == 0 else driver.Accelerator() print(f"Using {device} device") ``` ## `Accelerator` {#max.driver.Accelerator} > class max.driver.Accelerator(\*args, \*\*kwargs) ## `Buffer` {#max.driver.Buffer} > class max.driver.Buffer(\*args, \*\*kwargs) Device-resident buffer representation. Allocates memory onto a given device with the provided shape and dtype. Buffers can be sliced to provide strided views of the underlying memory, but any buffers input into model execution must be contiguous. Supports numpy-style slicing but does not currently support setting items across multiple indices. ```python from max import driver from max.dtype import DType cpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.float32) # Create a buffer on GPU gpu = driver.Accelerator() gpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.float32, device=gpu) ```
**Parameters:**
* dtype ([DType](dtype.md#max.dtype.DType)) – Data type of buffer elements. * shape (Sequence\[[int](https://docs.python.org/3/library/functions.html#int)]) – Tuple of positive, non-zero integers denoting the buffer shape. * device ([Device](#max.driver.Device), optional) – Device to allocate buffer onto. Defaults to the CPU. * pinned ([bool](https://docs.python.org/3/library/functions.html#bool), optional) – If True, memory is page-locked (pinned). Defaults to False. * stream ([DeviceStream](#max.driver.DeviceStream), optional) – Stream to associate the buffer with.
### `contiguous()` {#max.driver.Buffer.contiguous} > contiguous() Creates a contiguous copy of the parent buffer.
**Parameters:**
self ([Buffer](#max.driver.Buffer))
**Return type:**
[Buffer](#max.driver.Buffer)
### `copy()` {#max.driver.Buffer.copy} > copy(self, stream: [max.driver.DeviceStream](#max.driver.DeviceStream)) → [max.driver.Buffer](#max.driver.Buffer) > copy(self, device: [max.driver.Device](#max.driver.Device) | [None](https://docs.python.org/3/library/constants.html#None) = None) → [max.driver.Buffer](#max.driver.Buffer) Overloaded function. 1. `copy(self, stream: max.driver.DeviceStream) -> max.driver.Buffer` > Creates a deep copy on the device associated with the stream. > Args: > : stream (DeviceStream): The stream to associate the new buffer with. > Returns: > : Buffer: A new buffer that is a copy of this buffer. 2. `copy(self, device: max.driver.Device | None = None) -> max.driver.Buffer` > Creates a deep copy on an optionally given device. > If device is None (default), a copy is created on the same device. > > ```python > from max import driver > from max.dtype import DType > ​ > cpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.bfloat16, device=driver.CPU()) > cpu_copy = cpu_buffer.copy() > ​ > # Copy to GPU > gpu = driver.Accelerator() > gpu_copy = cpu_buffer.copy(device=gpu) > ``` > Args: > : device (Device, optional): The device to create the copy on. > : Defaults to None (same device). > Returns: > : Buffer: A new buffer that is a copy of this buffer. ### `device` {#max.driver.Buffer.device} > property device Device on which tensor is resident. ### `disable_auto_sync()` {#max.driver.Buffer.disable_auto_sync} > disable\_auto\_sync(self) → [None](https://docs.python.org/3/library/constants.html#None) Disables automatic synchronization for asynchronous operations on this buffer. :::caution Caution This is an experimental feature that may be unstable. It also requires special care from the user to ensure proper synchronization. ::: By default, certain operations on buffers cause synchronization, such as when trying to access a buffer on the host through to\_numpy. 
However, the default synchronization is quite conservative and often waits on more than is strictly needed. This function disables the default synchronization method and enables mark\_as\_ready(), which allows finer control over what is waited on when a buffer needs to be synchronized. ```python # Assuming we have three buffers of the same size: a, b, and c # Default case with auto-synchronization a.to(b) # 1 a.to(c) # 2 # Will wait on 1 and 2 b.to_numpy() # Disabled synchronization a.disable_auto_sync() a.to(b) # 1 a.to(c) # 2 # Doesn't wait on 1 or 2, data in b could be invalid b.to_numpy() # Disabled synchronization with mark_as_ready a.disable_auto_sync() a.to(b) # 1 b.mark_as_ready() a.to(c) # 2 # Wait on 1 but not on 2 b.to_numpy() ``` ### `dtype` {#max.driver.Buffer.dtype} > property dtype DType of constituent elements in tensor. ### `element_size` {#max.driver.Buffer.element_size} > property element\_size Return the size of the element type in bytes. ### `from_dlpack()` {#max.driver.Buffer.from_dlpack} > from\_dlpack(\*, copy=None) Create a buffer from an object implementing the dlpack protocol. This usually does not result in a copy, and the producer of the object retains ownership of the underlying memory.
**Parameters:**
* array ([Any](https://docs.python.org/3/library/typing.html#typing.Any)) * copy ([bool](https://docs.python.org/3/library/functions.html#bool) | None)
**Return type:**
[Buffer](#max.driver.Buffer)
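`from_dlpack()` follows the standard DLPack exchange protocol, and the zero-copy behavior it describes can be observed with NumPy alone, since NumPy (1.22+) also implements the protocol on both the producer and consumer side. This sketch uses NumPy, not a MAX Buffer:

```python
import numpy as np

# np.from_dlpack consumes any object implementing the DLPack protocol.
# Here NumPy is both producer and consumer, so the exchange is zero-copy
# and the producer retains ownership of the memory.
producer = np.arange(6, dtype=np.float32)
view = np.from_dlpack(producer)

producer[0] = 42.0        # mutations through the producer...
print(view[0])            # ...are visible through the consumer: 42.0
```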
### `from_numpy()` {#max.driver.Buffer.from_numpy} > from\_numpy() Creates a buffer from a provided numpy array on the host device. The underlying data is not copied unless the array is noncontiguous. If it is, a contiguous copy will be returned.
**Parameters:**
arr ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]])
**Return type:**
[Buffer](#max.driver.Buffer)
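The "copy only if noncontiguous" behavior mirrors NumPy's own contiguity rules, which you can inspect directly on the array before wrapping it. For example, a transposed array is a strided, noncontiguous view and would therefore trigger a copy (this sketch uses plain NumPy to illustrate the condition):

```python
import numpy as np

a = np.zeros((2, 3), dtype=np.float32)
print(a.flags["C_CONTIGUOUS"])   # True: would be wrapped without a copy

t = a.T                          # transpose: a strided, noncontiguous view
print(t.flags["C_CONTIGUOUS"])   # False: a contiguous copy would be made

c = np.ascontiguousarray(t)      # the kind of copy that restores contiguity
print(c.flags["C_CONTIGUOUS"])   # True
```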
### `inplace_copy_from()` {#max.driver.Buffer.inplace_copy_from} > inplace\_copy\_from(src) Copy the contents of another buffer into this one. These buffers may be on different devices. Requires that both buffers are contiguous and have the same size.
**Parameters:**
* self ([Buffer](#max.driver.Buffer)) * src ([Buffer](#max.driver.Buffer))
**Return type:**
None
### `is_contiguous` {#max.driver.Buffer.is_contiguous} > property is\_contiguous Whether or not buffer is contiguously allocated in memory. Returns false if the buffer is a non-contiguous slice. Currently, we consider certain situations that are contiguous as non-contiguous for the purposes of our engine, such as when a buffer has negative steps. ### `is_host` {#max.driver.Buffer.is_host} > property is\_host Whether or not buffer is host-resident. Returns false for GPU buffers, true for CPU buffers. ```python from max import driver from max.dtype import DType cpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.bfloat16, device=driver.CPU()) print(cpu_buffer.is_host) ``` ### `item()` {#max.driver.Buffer.item} > item(self) → [Any](https://docs.python.org/3/library/typing.html#typing.Any) Returns the scalar value at a given location. Currently implemented only for zero-rank buffers. The return type is converted to a Python built-in type. ### `mark_as_ready()` {#max.driver.Buffer.mark_as_ready} > mark\_as\_ready(self) → [None](https://docs.python.org/3/library/constants.html#None) Establishes a synchronization point for buffers with disabled auto-sync. :::caution Caution This is an experimental feature that may be unstable. It also requires special care from the user to ensure proper synchronization. ::: This method can only be called on buffers with disabled synchronization through disable\_auto\_sync(). It instructs max that whenever it needs to wait on this buffer it should only wait to the point where this was called. It can be called multiple times, but it will override a previous synchronization point with the new one. Refer to the disable\_auto\_sync() documentation for more details and examples. ### `mmap()` {#max.driver.Buffer.mmap} > mmap(dtype, shape, mode='copyonwrite', offset=0)
**Parameters:**
* filename (PathLike\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [str](https://docs.python.org/3/library/stdtypes.html#str)) * dtype ([DType](dtype.md#max.dtype.DType)) * shape (ShapeType | [int](https://docs.python.org/3/library/functions.html#int)) * mode (np.\_MemMapModeKind) * offset ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[Buffer](#max.driver.Buffer)
### `num_elements` {#max.driver.Buffer.num_elements} > property num\_elements Returns the number of elements in this buffer. Rank-0 buffers have 1 element by convention. ### `pinned` {#max.driver.Buffer.pinned} > property pinned Whether or not the underlying memory is pinned (page-locked). ### `rank` {#max.driver.Buffer.rank} > property rank Buffer rank. ### `scalar` {#max.driver.Buffer.scalar} > scalar = \ ### `shape` {#max.driver.Buffer.shape} > property shape Shape of buffer. ### `stream` {#max.driver.Buffer.stream} > property stream Stream to which tensor is bound. ### `to()` {#max.driver.Buffer.to} > to(self, device: [max.driver.Device](#max.driver.Device)) → [max.driver.Buffer](#max.driver.Buffer) > to(self, stream: [max.driver.DeviceStream](#max.driver.DeviceStream)) → [max.driver.Buffer](#max.driver.Buffer) > to(self, devices: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[max.driver.Device](#max.driver.Device)]) → [list](https://docs.python.org/3/library/stdtypes.html#list)\[[max.driver.Buffer](#max.driver.Buffer)] > to(self, streams: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[max.driver.DeviceStream](#max.driver.DeviceStream)]) → [list](https://docs.python.org/3/library/stdtypes.html#list)\[[max.driver.Buffer](#max.driver.Buffer)] Overloaded function. 1. `to(self, device: max.driver.Device) -> max.driver.Buffer` > Return a buffer that’s guaranteed to be on the given device. > The buffer is only copied if the requested device is different from the > device upon which the buffer is already resident. 2. `to(self, stream: max.driver.DeviceStream) -> max.driver.Buffer` > Return a buffer that’s guaranteed to be on the given device and associated > with the given stream. > The buffer is only copied if the requested device is different from the > device upon which the buffer is already resident. 
If the destination > stream is on the same device, then a new reference to the same buffer is > returned. 3. `to(self, devices: collections.abc.Sequence[max.driver.Device]) -> list[max.driver.Buffer]` > Return a list of buffers that are guaranteed to be on the given devices. > The buffers are only copied if the requested devices are different from the > device upon which the buffer is already resident. 4. `to(self, streams: collections.abc.Sequence[max.driver.DeviceStream]) -> list[max.driver.Buffer]` > Return a list of buffers that are guaranteed to be on the given streams. > The buffers are only copied if the requested streams are different from the > stream upon which the buffer is already resident. ### `to_numpy()` {#max.driver.Buffer.to_numpy} > to\_numpy() Converts the buffer to a numpy array. If the buffer is not on the host, a copy will be issued.
**Parameters:**
self ([Buffer](#max.driver.Buffer))
**Return type:**
[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
### `view()` {#max.driver.Buffer.view} > view(dtype, shape=None) Return a new buffer with the given type and shape that shares the underlying memory. If the shape is not given, it will be deduced if possible, or a ValueError is raised.
**Parameters:**
* self ([Buffer](#max.driver.Buffer)) * dtype ([DType](dtype.md#max.dtype.DType)) * shape ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)] | None)
**Return type:**
[Buffer](#max.driver.Buffer)
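The shape deduction in `view()` follows ordinary reinterpret-cast arithmetic: the buffer's byte size must divide evenly by the new element size. NumPy's `ndarray.view` behaves the same way and can serve as a mental model (this sketch uses NumPy, not a MAX Buffer):

```python
import numpy as np

a = np.zeros(4, dtype=np.float32)   # 16 bytes of storage
b = a.view(np.int16)                # reinterpreted: 16 bytes / 2 = 8 elements
print(b.shape)                      # (8,)

a[:] = 1.0                          # writes through `a` are visible in `b`
print(bool(b.any()))                # True: both share the same memory
```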
### `zeros` {#max.driver.Buffer.zeros} > zeros = \ ## `CPU` {#max.driver.CPU} > class max.driver.CPU(\*args, \*\*kwargs) ## `DLPackArray` {#max.driver.DLPackArray} > class max.driver.DLPackArray(\*args, \*\*kwargs) ## `Device` {#max.driver.Device} > class max.driver.Device ### `api` {#max.driver.Device.api} > property api Returns the API used to program the device. Possible values are: * `cpu` for host devices. * `cuda` for NVIDIA GPUs. * `hip` for AMD GPUs. ```python from max import driver device = driver.CPU() device.api ``` ### `architecture_name` {#max.driver.Device.architecture_name} > property architecture\_name Returns the architecture name of the device. Examples of possible values: * `gfx90a`, `gfx942` for AMD GPUs. * `sm_80`, `sm_86` for NVIDIA GPUs. * CPU devices raise an exception. ```python from max import driver device = driver.Accelerator() device.architecture_name ``` ### `can_access()` {#max.driver.Device.can_access} > can\_access(self, other: [max.driver.Device](#max.driver.Device)) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if this device can directly access memory of another device. ```python from max import driver gpu0 = driver.Accelerator(id=0) gpu1 = driver.Accelerator(id=1) if gpu0.can_access(gpu1): print("GPU0 can directly access GPU1 memory.") ```
**Parameters:**
other ([Device](#max.driver.Device)) – The other device to check peer access against.
**Returns:**
True if peer access is possible, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `cpu` {#max.driver.Device.cpu} > cpu = \ ### `default_stream` {#max.driver.Device.default_stream} > property default\_stream Returns the default stream for this device. The default stream is initialized when the device object is created.
**Returns:**
The default execution stream for this device.
**Return type:**
[DeviceStream](#max.driver.DeviceStream)
### `id` {#max.driver.Device.id} > property id Returns a zero-based device id. For a CPU device this is always 0. For GPU accelerators this is the id of the device relative to this host. Along with the `label`, an id can uniquely identify a device, e.g. `gpu:0`, `gpu:1`. ```python from max import driver device = driver.Accelerator() device_id = device.id ```
**Returns:**
The device ID.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `is_compatible` {#max.driver.Device.is_compatible} > property is\_compatible Returns whether this device is compatible with MAX.
**Returns:**
True if the device is compatible with MAX, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `is_host` {#max.driver.Device.is_host} > property is\_host Whether this device is the CPU (host) device. ```python from max import driver device = driver.CPU() device.is_host ``` ### `label` {#max.driver.Device.label} > property label Returns device label. Possible values are: * `cpu` for host devices. * `gpu` for accelerators. ```python from max import driver device = driver.CPU() device.label ``` ### `stats` {#max.driver.Device.stats} > property stats Returns utilization data for the device. ```python from max import driver device = driver.CPU() stats = device.stats ```
**Returns:**
A dictionary containing device utilization statistics.
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)
### `synchronize()` {#max.driver.Device.synchronize} > synchronize(self) → [None](https://docs.python.org/3/library/constants.html#None) Ensures all operations on this device complete before returning.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If any enqueued operations had an internal error.
## `DeviceSpec` {#max.driver.DeviceSpec} > class max.driver.DeviceSpec(id, device\_type='cpu') Specification for a device, containing its ID and type. This class provides a way to specify device parameters like ID and type (CPU/GPU) for creating Device instances.
**Parameters:**
* id ([int](https://docs.python.org/3/library/functions.html#int)) * device\_type ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['cpu', 'gpu'])
### `accelerator()` {#max.driver.DeviceSpec.accelerator} > static accelerator(id=0) Creates an accelerator (GPU) device specification.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
### `cpu()` {#max.driver.DeviceSpec.cpu} > static cpu(id=-1) Creates a CPU device specification.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
### `device_type` {#max.driver.DeviceSpec.device_type} > device\_type: [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['cpu', 'gpu'] = 'cpu' Type of specified device. ### `id` {#max.driver.DeviceSpec.id} > id: [int](https://docs.python.org/3/library/functions.html#int) Provided id for this device. ## `DeviceStream` {#max.driver.DeviceStream} > class max.driver.DeviceStream(\*args, \*\*kwargs) Provides access to a stream of execution on a device. A stream represents a sequence of operations that will be executed in order. Multiple streams on the same device can execute concurrently. ```python from max import driver # Create a default accelerator device device = driver.Accelerator() # Get the default stream for the device stream = device.default_stream # Create a new stream of execution on the device new_stream = driver.DeviceStream(device) ``` ### `device` {#max.driver.DeviceStream.device} > property device The device this stream is executing on. ### `synchronize()` {#max.driver.DeviceStream.synchronize} > synchronize(self) → [None](https://docs.python.org/3/library/constants.html#None) Ensures all operations on this stream complete before returning.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If any enqueued operations had an internal error.
### `wait_for()` {#max.driver.DeviceStream.wait_for} > wait\_for(self, stream: [max.driver.DeviceStream](#max.driver.DeviceStream)) → [None](https://docs.python.org/3/library/constants.html#None) > wait\_for(self, device: [max.driver.Device](#max.driver.Device)) → [None](https://docs.python.org/3/library/constants.html#None) Overloaded function. 1. `wait_for(self, stream: max.driver.DeviceStream) -> None` > Ensures all operations on the other stream complete before future work > submitted to this stream is scheduled. > Args: > : stream (DeviceStream): The stream to wait for. 2. `wait_for(self, device: max.driver.Device) -> None` > Ensures all operations on device’s default stream complete before > future work submitted to this stream is scheduled. > Args: > : device (Device): The device whose default stream to wait for. ## `accelerator_api()` {#max.driver.accelerator_api} > max.driver.accelerator\_api() Returns the API used to program the accelerator.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
## `accelerator_architecture_name()` {#max.driver.accelerator_architecture_name} > max.driver.accelerator\_architecture\_name() Returns the architecture name of the accelerator device.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
## `calculate_virtual_device_count()` {#max.driver.calculate_virtual_device_count} > max.driver.calculate\_virtual\_device\_count(\*device\_spec\_lists) Calculate the minimum virtual device count needed for the given device specs.
**Parameters:**
\*device\_spec\_lists ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](#max.driver.DeviceSpec)]) – One or more lists of DeviceSpec objects (e.g., main devices and draft devices).
**Returns:**
The minimum number of virtual devices needed (max GPU ID + 1), or 1 if no GPUs
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
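The counting rule above (maximum GPU ID plus one, or 1 when no GPUs are present) can be sketched in plain Python. This is an illustration of the documented behavior, not the library's implementation; the dict-shaped specs stand in for real `DeviceSpec` objects.

```python
def virtual_device_count(*device_spec_lists):
    """Sketch of the documented rule: max GPU id + 1, or 1 if no GPUs."""
    gpu_ids = [
        spec["id"]
        for specs in device_spec_lists
        for spec in specs
        if spec["device_type"] == "gpu"
    ]
    return max(gpu_ids) + 1 if gpu_ids else 1

# Hypothetical specs shaped like DeviceSpec(device_type, id):
main = [{"device_type": "gpu", "id": 0}, {"device_type": "gpu", "id": 2}]
draft = [{"device_type": "gpu", "id": 1}]
print(virtual_device_count(main, draft))  # 3 (max GPU id is 2, plus one)
print(virtual_device_count([{"device_type": "cpu", "id": 0}]))  # 1
```

Note that the result is driven by the largest ID seen across all lists, not by how many specs were passed.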
## `calculate_virtual_device_count_from_cli()` {#max.driver.calculate_virtual_device_count_from_cli} > max.driver.calculate\_virtual\_device\_count\_from\_cli(\*device\_inputs) Calculate virtual device count from raw CLI inputs (before parsing). This helper works with the raw device input strings or lists before they’re parsed into DeviceSpec objects. Used when virtual device mode needs to be enabled before device validation occurs.
**Parameters:**
\*device\_inputs ([str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – One or more raw device inputs - either strings like “gpu:0,1,2” or lists of integers like \[0, 1, 2]
**Returns:**
The minimum number of virtual devices needed (max GPU ID + 1), or 1 if no GPUs
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
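The accepted raw input shapes (strings like "gpu:0,1,2" or lists of integers) can be illustrated with a small sketch of the documented behavior; this is not the library's parser, and treating a non-`gpu:` string as contributing no GPUs is an assumption.

```python
def count_from_cli(*device_inputs):
    """Sketch: parse raw CLI inputs, then apply max-GPU-id-plus-one (1 if none)."""
    gpu_ids = []
    for inp in device_inputs:
        if isinstance(inp, str):
            if inp.startswith("gpu:"):
                # "gpu:0,1,2" -> [0, 1, 2]
                gpu_ids += [int(i) for i in inp[len("gpu:"):].split(",")]
        else:
            # A list of integers is taken directly as GPU ids.
            gpu_ids += list(inp)
    return max(gpu_ids) + 1 if gpu_ids else 1

print(count_from_cli("gpu:0,1,2"))      # 3
print(count_from_cli([0, 4], "gpu:1"))  # 5
print(count_from_cli("cpu"))            # 1
```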
## `devices_exist()` {#max.driver.devices_exist} > max.driver.devices\_exist(devices) Checks whether the specified devices exist.
**Parameters:**
devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](#max.driver.DeviceSpec)])
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
## `load_devices()` {#max.driver.load_devices} > max.driver.load\_devices(device\_specs) Initialize and return a list of devices, given a list of device specs.
**Parameters:**
device\_specs ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[DeviceSpec](#max.driver.DeviceSpec)])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](#max.driver.Device)]
## `load_max_buffer()` {#max.driver.load_max_buffer} > max.driver.load\_max\_buffer(path) Experimental method for loading serialized MAX buffers. Max buffers can be exported by creating a graph and calling Value.print() with the BINARY\_MAX\_CHECKPOINT option.
**Parameters:**
path ([PathLike](https://docs.python.org/3/library/os.html#os.PathLike)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) – Path to buffer (should end with .max)
**Returns:**
A Buffer created from the path. The shape and dtype are read from the file.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the file format is not the MAX checkpoint format.
**Return type:**
[Buffer](#max.driver.Buffer)
## `scan_available_devices()` {#max.driver.scan_available_devices} > max.driver.scan\_available\_devices() Returns all available accelerators, or the CPU if none are found.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](#max.driver.DeviceSpec)]
## `accelerator_count()` {#max.driver.accelerator_count} > max.driver.accelerator\_count() → [int](https://docs.python.org/3/library/functions.html#int) Returns the number of available accelerator devices. --- ## dtype Provides data type definitions for tensors in MAX Engine. These data types are essential for defining the precision and memory layout of tensor data when working with machine learning models. This module defines the [`DType`](#max.dtype.DType) enum, which represents all supported tensor data types in MAX Engine, including: * Integer types (signed and unsigned): `int8` | `uint8` | `int16` | `uint16` | `int32` | `uint32` | `int64` | `uint64` * Floating-point types: `float16` | `bfloat16` | `float32` | `float64`, plus reduced-precision `float8` and `float4` variants * Boolean type: `bool` The module also provides utilities for converting between MAX Engine data types and [NumPy dtypes](https://numpy.org/doc/stable/user/basics.types.html), making it easy to interoperate with the NumPy ecosystem. ```python import numpy as np from max.dtype import DType tensor = np.zeros((2, 3), dtype=DType.float32.to_numpy()) # Convert NumPy dtype to MAX DType array = np.ones((4, 4), dtype=np.float16) max_dtype = DType.from_numpy(array.dtype) # Check properties of data types is_float = DType.float32.is_float() # True is_int = DType.int64.is_integral() # True size = DType.float64.size_in_bytes # 8 ``` ## `DType` {#max.dtype.DType} > class max.dtype.DType(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) The tensor data type. ### `align` {#max.dtype.DType.align} > property align Returns the alignment requirement of the data type in bytes. The alignment specifies the memory boundary that values of this data type must be aligned to for optimal performance and correctness.
### `bfloat16` {#max.dtype.DType.bfloat16} > bfloat16 = 80 ### `bool` {#max.dtype.DType.bool} > bool = 1 ### `float16` {#max.dtype.DType.float16} > float16 = 79 ### `float32` {#max.dtype.DType.float32} > float32 = 81 ### `float4_e2m1fn` {#max.dtype.DType.float4_e2m1fn} > float4\_e2m1fn = 64 ### `float64` {#max.dtype.DType.float64} > float64 = 82 ### `float8_e4m3fn` {#max.dtype.DType.float8_e4m3fn} > float8\_e4m3fn = 75 ### `float8_e4m3fnuz` {#max.dtype.DType.float8_e4m3fnuz} > float8\_e4m3fnuz = 76 ### `float8_e5m2` {#max.dtype.DType.float8_e5m2} > float8\_e5m2 = 77 ### `float8_e5m2fnuz` {#max.dtype.DType.float8_e5m2fnuz} > float8\_e5m2fnuz = 78 ### `float8_e8m0fnu` {#max.dtype.DType.float8_e8m0fnu} > float8\_e8m0fnu = 73 ### `from_numpy()` {#max.dtype.DType.from_numpy} > from\_numpy(dtype) Converts a NumPy dtype to the corresponding DType.
**Parameters:**
dtype (np.dtype) – The NumPy dtype to convert.
**Returns:**
The corresponding DType enum value.
**Return type:**
[DType](#max.dtype.DType)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input dtype is not supported.
### `from_torch()` {#max.dtype.DType.from_torch} > from\_torch(dtype, \_error=None)
**Parameters:**
* dtype (dtype) * \_error ([Exception](https://docs.python.org/3/library/exceptions.html#Exception) | None)
**Return type:**
[DType](#max.dtype.DType)
### `int16` {#max.dtype.DType.int16} > int16 = 137 ### `int32` {#max.dtype.DType.int32} > int32 = 139 ### `int64` {#max.dtype.DType.int64} > int64 = 141 ### `int8` {#max.dtype.DType.int8} > int8 = 135 ### `is_float()` {#max.dtype.DType.is_float} > is\_float(self) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if the data type is a floating-point type. ### `is_float8()` {#max.dtype.DType.is_float8} > is\_float8(self) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if the data type is an 8-bit floating-point type. ### `is_half()` {#max.dtype.DType.is_half} > is\_half(self) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if the data type is a half-precision floating-point type. ### `is_integral()` {#max.dtype.DType.is_integral} > is\_integral(self) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if the data type is an integer type. ### `is_signed_integral()` {#max.dtype.DType.is_signed_integral} > is\_signed\_integral(self) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if the data type is a signed integer type. ### `is_unsigned_integral()` {#max.dtype.DType.is_unsigned_integral} > is\_unsigned\_integral(self) → [bool](https://docs.python.org/3/library/functions.html#bool) Checks if the data type is an unsigned integer type. ### `size_in_bits` {#max.dtype.DType.size_in_bits} > property size\_in\_bits Returns the size of the data type in bits. This indicates how many bits are required to store a single value of this data type in memory. ### `size_in_bytes` {#max.dtype.DType.size_in_bytes} > property size\_in\_bytes Returns the size of the data type in bytes. This indicates how many bytes are required to store a single value of this data type in memory. ### `to_numpy()` {#max.dtype.DType.to_numpy} > to\_numpy() Converts this `DType` to the corresponding NumPy dtype.
**Returns:**
The corresponding NumPy dtype object.
**Return type:**
[np.dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the dtype is not supported.
**Parameters:**
self ([DType](#max.dtype.DType))
### `to_torch()` {#max.dtype.DType.to_torch} > to\_torch(\_error=None)
**Parameters:**
* dtype ([DType](#max.dtype.DType)) * \_error ([Exception](https://docs.python.org/3/library/exceptions.html#Exception) | None)
**Return type:**
torch.dtype
### `uint16` {#max.dtype.DType.uint16} > uint16 = 136 ### `uint32` {#max.dtype.DType.uint32} > uint32 = 138 ### `uint64` {#max.dtype.DType.uint64} > uint64 = 140 ### `uint8` {#max.dtype.DType.uint8} > uint8 = 134 ## `finfo` {#max.dtype.finfo} > class max.dtype.finfo(dtype) Numerical properties of a floating point `max.dtype.DType`. This is modeled after `torch.finfo`, providing `bits`, `eps`, `max`, `min`, `tiny`, `smallest_normal`, and `dtype` attributes for every MAX float dtype—including bfloat16, float8, and float4 types that numpy cannot represent.
**Parameters:**
dtype ([DType](#max.dtype.DType)) – A floating-point `DType` to query.
**Raises:**
[TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If dtype is not a floating-point type.
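For intuition, the `finfo` attributes follow the standard IEEE-754-style formulas for a format with `e` exponent bits and `m` fraction bits. The sketch below illustrates those formulas and is not the library's implementation; it does not cover the non-standard `fn`/`fnuz` float8 variants, whose ranges deviate from the IEEE pattern.

```python
def ieee_finfo(e: int, m: int):
    """IEEE-style properties for e exponent bits and m fraction bits
    (formats with inf/nan, such as float16 and float32)."""
    bias = 2 ** (e - 1) - 1
    return {
        "bits": 1 + e + m,                     # sign + exponent + fraction
        "eps": 2.0 ** -m,                      # gap between 1.0 and the next value
        "max": (2 - 2.0 ** -m) * 2.0 ** bias,  # largest finite value
        "tiny": 2.0 ** (1 - bias),             # smallest positive normal
    }

f16 = ieee_finfo(e=5, m=10)
print(f16["max"])   # 65504.0
print(f16["eps"])   # 0.0009765625 (2**-10)
print(f16["tiny"])  # 6.103515625e-05 (2**-14)
```

These match the values `finfo(DType.float16)` is expected to report.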
### `bits` {#max.dtype.finfo.bits} > bits: [int](https://docs.python.org/3/library/functions.html#int) ### `dtype` {#max.dtype.finfo.dtype} > dtype: [DType](#max.dtype.DType) ### `eps` {#max.dtype.finfo.eps} > eps: [float](https://docs.python.org/3/library/functions.html#float) ### `max` {#max.dtype.finfo.max} > max: [float](https://docs.python.org/3/library/functions.html#float) ### `min` {#max.dtype.finfo.min} > min: [float](https://docs.python.org/3/library/functions.html#float) ### `smallest_normal` {#max.dtype.finfo.smallest_normal} > property smallest\_normal: [float](https://docs.python.org/3/library/functions.html#float) Alias for `tiny` (`torch.finfo` compatibility). ### `tiny` {#max.dtype.finfo.tiny} > tiny: [float](https://docs.python.org/3/library/functions.html#float) --- ## engine The APIs in this module allow you to run inference with MAX Engine—a graph compiler and runtime that accelerates your AI models on a wide variety of hardware. ## `InferenceSession` {#max.engine.InferenceSession} > class max.engine.InferenceSession(devices, num\_threads=None, \*, custom\_extensions=None) Manages an inference session in which you can load and run models. You need an instance of this to load a model as a [`Model`](#max.engine.Model) object. For example: ```python session = engine.InferenceSession(devices=[CPU()]) model_path = Path('bert-base-uncased') model = session.load(model_path) ```
**Parameters:**
* devices (Iterable\[[Device](driver.md#max.driver.Device)]) * num\_threads ([int](https://docs.python.org/3/library/functions.html#int) | None) * custom\_extensions (CustomExtensionsType | None)
### `devices` {#max.engine.InferenceSession.devices} > property devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](driver.md#max.driver.Device)] A list of available devices. ### `gpu_profiling()` {#max.engine.InferenceSession.gpu_profiling} > gpu\_profiling(mode) Enables GPU profiling instrumentation for the session. This enables GPU profiling instrumentation that works with NVIDIA Nsight Systems and Nsight Compute. When enabled, the runtime adds CUDA driver calls and NVTX markers that allow profiling tools to correlate GPU kernel executions with host-side code. For example, to enable detailed profiling for Nsight Systems analysis, call `gpu_profiling()` before `load()`: ```python from max.engine import InferenceSession, GPUProfilingMode from max.driver import Accelerator session = InferenceSession(devices=[Accelerator()]) session.gpu_profiling(GPUProfilingMode.DETAILED) model = session.load(my_graph) ``` Then run it with `nsys`: ```bash nsys profile --trace=cuda,nvtx python example.py ``` Or, instead of calling `session.gpu_profiling()` in the code, you can set the `MODULAR_ENABLE_PROFILING` environment variable when you call `nsys profile`: ```bash MODULAR_ENABLE_PROFILING=detailed nsys profile --trace=cuda,nvtx python script.py ``` Beware that `gpu_profiling()` overrides the `MODULAR_ENABLE_PROFILING` environment variable if also used. :::note Note Profiling instrumentation adds runtime overhead and should be disabled for production deployments. :::
**Parameters:**
mode ([GPUProfilingMode](#max.engine.GPUProfilingMode)) – The profiling mode to set. One of: * [`GPUProfilingMode.OFF`](#max.engine.GPUProfilingMode.OFF): Disable profiling (default). * [`GPUProfilingMode.ON`](#max.engine.GPUProfilingMode.ON): Enable basic profiling with NVTX markers for kernel correlation. * [`GPUProfilingMode.DETAILED`](#max.engine.GPUProfilingMode.DETAILED): Enable detailed profiling with additional Python-level NVTX markers.
**Return type:**
None
:::note See also * [GPU profiling with Nsight Systems](/max/gpu-system-profiling) ::: ### `load()` {#max.engine.InferenceSession.load} > load(model, \*, custom\_extensions=None, weights\_registry=None) Loads a trained model and compiles it for inference.
**Parameters:**
* model ([str](https://docs.python.org/3/library/stdtypes.html#str) | Path | [Graph](graph/Graph.md#max.graph.Graph)) – Path to a model. * custom\_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to .mojopkg custom ops. * weights\_registry (Mapping\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](driver.md#max.driver.DLPackArray)] | None) – A mapping from model weights’ names to their values. The values are currently expected to be DLPack arrays. If an array is a read-only NumPy array, the user must ensure that its lifetime extends beyond the lifetime of the model.
**Returns:**
The loaded model, compiled and ready to execute.
**Raises:**
[RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the path provided is invalid.
**Return type:**
[Model](#max.engine.Model)
### `set_mojo_assert_level()` {#max.engine.InferenceSession.set_mojo_assert_level} > set\_mojo\_assert\_level(level) Sets which mojo asserts are kept in the compiled model.
**Parameters:**
level (AssertLevel)
**Return type:**
None
### `set_mojo_log_level()` {#max.engine.InferenceSession.set_mojo_log_level} > set\_mojo\_log\_level(level) Sets the verbosity of mojo logging in the compiled model.
**Parameters:**
level ([str](https://docs.python.org/3/library/stdtypes.html#str) | [LogLevel](#max.engine.LogLevel))
**Return type:**
None
### `set_split_k_reduction_precision()` {#max.engine.InferenceSession.set_split_k_reduction_precision} > set\_split\_k\_reduction\_precision(precision) Sets the accumulation precision for split k reductions in large matmuls.
**Parameters:**
precision ([str](https://docs.python.org/3/library/stdtypes.html#str) | SplitKReductionPrecision)
**Return type:**
None
### `use_old_top_k_kernel()` {#max.engine.InferenceSession.use_old_top_k_kernel} > use\_old\_top\_k\_kernel(mode) Enables the old top-k kernel. By default, the new top-k kernel is used, consistent with max/kernels/src/nn/topk.mojo.
**Parameters:**
mode ([str](https://docs.python.org/3/library/stdtypes.html#str)) – String to enable/disable. Accepts “false”, “off”, “no”, “0” to disable, any other value to enable.
**Return type:**
None
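The string handling described above amounts to a small predicate: only the listed strings disable, and any other value enables. A sketch of the documented rule, where case-insensitive matching is an assumption (the docs list only the lowercase forms):

```python
DISABLE_VALUES = {"false", "off", "no", "0"}

def old_top_k_enabled(mode: str) -> bool:
    # Documented: "false", "off", "no", "0" disable; any other value enables.
    # Lowercasing is an assumption for convenience.
    return mode.lower() not in DISABLE_VALUES

print(old_top_k_enabled("off"))   # False
print(old_top_k_enabled("true"))  # True
print(old_top_k_enabled("1"))     # True (only "0" disables, not "1")
```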
## `Model` {#max.engine.Model} > class max.engine.Model A loaded model that you can execute. Do not instantiate this class directly. Instead, create it with [`InferenceSession`](#max.engine.InferenceSession). ### `__call__()` {#max.engine.Model.__call} > \_\_call\_\_(\*args, \*\*kwargs) Call self as a function.
**Parameters:**
* self ([Model](#max.engine.Model)) * args ([DLPackArray](driver.md#max.driver.DLPackArray) | [Buffer](driver.md#max.driver.Buffer) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [bool](https://docs.python.org/3/library/functions.html#bool) | [generic](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic)) * kwargs ([DLPackArray](driver.md#max.driver.DLPackArray) | [Buffer](driver.md#max.driver.Buffer) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [bool](https://docs.python.org/3/library/functions.html#bool) | [generic](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](driver.md#max.driver.Buffer)]
### `capture()` {#max.engine.Model.capture} > capture(\*inputs) Capture execution into a device graph keyed by input shapes/dtypes. Capture is best-effort and model-dependent. If the model issues capture-unsafe operations (for example, host-device synchronization), graph capture may fail. Callers should choose capture-safe execution paths.
**Parameters:**
* self ([Model](#max.engine.Model)) * inputs ([Buffer](driver.md#max.driver.Buffer))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](driver.md#max.driver.Buffer)]
### `execute()` {#max.engine.Model.execute} > execute(\*args)
**Parameters:**
* self ([Model](#max.engine.Model)) * args ([DLPackArray](driver.md#max.driver.DLPackArray) | [Buffer](driver.md#max.driver.Buffer) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [bool](https://docs.python.org/3/library/functions.html#bool) | [generic](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](driver.md#max.driver.Buffer)]
### `input_metadata` {#max.engine.Model.input_metadata} > property input\_metadata Metadata about the model’s input tensors, as a list of [`TensorSpec`](#max.engine.TensorSpec) objects. For example, you can print the input tensor names, shapes, and dtypes: ```python for tensor in model.input_metadata: print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}') ``` ### `output_metadata` {#max.engine.Model.output_metadata} > property output\_metadata Metadata about the model’s output tensors, as a list of [`TensorSpec`](#max.engine.TensorSpec) objects. For example, you can print the output tensor names, shapes, and dtypes: ```python for tensor in model.output_metadata: print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}') ``` ### `replay()` {#max.engine.Model.replay} > replay(\*inputs) Replay the captured device graph for these inputs.
**Parameters:**
* self ([Model](#max.engine.Model)) * inputs ([Buffer](driver.md#max.driver.Buffer))
**Return type:**
None
## `GPUProfilingMode` {#max.engine.GPUProfilingMode} > class max.engine.GPUProfilingMode(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) The supported modes for GPU profiling. GPU profiling modes control the level of instrumentation when profiling MAX applications with NVIDIA Nsight Systems or Nsight Compute. Higher levels provide more detail but may introduce additional overhead. :::note See also [`InferenceSession.gpu_profiling()`](#max.engine.InferenceSession.gpu_profiling): Method to set the profiling mode. ::: ### `DETAILED` {#max.engine.GPUProfilingMode.DETAILED} > DETAILED = 'detailed' Enable detailed GPU profiling with additional NVTX markers from Python code. This mode provides the most visibility into which Python operations correspond to which GPU kernels, but has the highest overhead. ### `OFF` {#max.engine.GPUProfilingMode.OFF} > OFF = 'off' Disable GPU profiling instrumentation. This is the default mode and incurs no profiling overhead. ### `ON` {#max.engine.GPUProfilingMode.ON} > ON = 'on' Enable basic GPU profiling. Adds CUDA driver calls and NVTX markers for correlating kernel executions with host-side code. ## `LogLevel` {#max.engine.LogLevel} > class max.engine.LogLevel(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) The LogLevel specifies the log level used by the Mojo Ops. ### `CRITICAL` {#max.engine.LogLevel.CRITICAL} > CRITICAL = 'critical' ### `DEBUG` {#max.engine.LogLevel.DEBUG} > DEBUG = 'debug' ### `ERROR` {#max.engine.LogLevel.ERROR} > ERROR = 'error' ### `INFO` {#max.engine.LogLevel.INFO} > INFO = 'info' ### `NOTSET` {#max.engine.LogLevel.NOTSET} > NOTSET = 'notset' ### `TRACE` {#max.engine.LogLevel.TRACE} > TRACE = 'trace' ### `WARNING` {#max.engine.LogLevel.WARNING} > WARNING = 'warning' ## `TensorSpec` {#max.engine.TensorSpec} > class max.engine.TensorSpec Defines the properties of a tensor, including its name, shape and data type. 
For usage examples, see [`Model.input_metadata`](#max.engine.Model.input_metadata). ### `dtype` {#max.engine.TensorSpec.dtype} > property dtype A tensor data type. ### `name` {#max.engine.TensorSpec.name} > property name A tensor name. ### `shape` {#max.engine.TensorSpec.shape} > property shape The shape of the tensor as a list of integers. If a dimension size is unknown/dynamic (such as the batch size), its value is `None`. ## `CustomExtensionsType` {#max.engine.CustomExtensionsType} > max.engine.CustomExtensionsType = collections.abc.Sequence\[str | pathlib.\_local.Path] | str | pathlib.\_local.Path Represents a PEP 604 union type, e.g. `int | str`. --- ## entrypoints ## `LLM` {#max.entrypoints.llm.LLM} > class max.entrypoints.llm.LLM(pipeline\_config) A high-level interface for interacting with LLMs.
**Parameters:**
pipeline\_config ([PipelineConfig](pipelines/config.md#max.pipelines.lib.config.PipelineConfig))
### `generate()` {#max.entrypoints.llm.LLM.generate} > generate(prompts, max\_new\_tokens=100, use\_tqdm=True) Generates text completions for the given prompts. This method is thread safe and may be used on the same LLM instance from multiple threads concurrently with no external synchronization.
**Parameters:**
* prompts ([str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) – The input string or list of strings to generate completions for. * max\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) – The maximum number of tokens to generate in the response. * use\_tqdm ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to display a progress bar during generation.
**Returns:**
A list of generated text completions corresponding to each input prompt.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If prompts is empty or contains invalid data. * [RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the model fails to generate completions.
**Return type:**
[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]
--- ## functional Provides functional APIs for tensor operations. This module provides functional-style tensor operations that work seamlessly with both MAX Graph construction and eager Tensor execution. All operations are wrapped versions of the core graph operations that automatically handle different execution contexts. These operations can be used in both graph construction and eager execution. ## `CustomExtensionType` {#max.functional.CustomExtensionType} > max.functional.CustomExtensionType: [TypeAlias](https://docs.python.org/3/library/typing.html#typing.TypeAlias) = str | pathlib.\_local.Path Type alias for custom extensions paths, matching `engine.CustomExtensionsType`. ## `abs()` {#max.functional.abs} > max.functional.abs(x) Computes the absolute value element-wise. See [`max.graph.ops.abs()`](graph/ops.md#max.graph.ops.abs) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `add()` {#max.functional.add} > max.functional.add(lhs, rhs) Adds two tensors element-wise. See [`max.graph.ops.add()`](graph/ops.md#max.graph.ops.add) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `allgather()` {#max.functional.allgather} > max.functional.allgather(inputs, signal\_buffers, axis=0) Concatenate values from multiple devices. See [`max.graph.ops.allgather()`](graph/ops.md#max.graph.ops.allgather) for details.
**Parameters:**
* inputs ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)]) * signal\_buffers ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[BufferValue](graph/BufferValue.md#max.graph.BufferValue) | HasBufferValue]) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](graph/TensorValue.md#max.graph.TensorValue)]
## `allreduce_sum()` {#max.functional.allreduce_sum} > max.functional.allreduce\_sum(inputs, signal\_buffers) Sum values from multiple devices. See `max.graph.ops.allreduce.sum()` for details.
**Parameters:**
* inputs ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)]) * signal\_buffers ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[BufferValue](graph/BufferValue.md#max.graph.BufferValue) | HasBufferValue])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](graph/TensorValue.md#max.graph.TensorValue)]
## `arange()` {#max.functional.arange} > max.functional.arange(start, stop, step=1, out\_dim=None, \*, dtype, device) Creates a tensor with evenly spaced values. See [`max.graph.ops.range()`](graph/ops.md#max.graph.ops.range) for details.
**Parameters:**
* start (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * stop (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * step (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | 
[DLPackArray](driver.md#max.driver.DLPackArray)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None) * dtype ([DType](dtype.md#max.dtype.DType)) * device ([Device](driver.md#max.driver.Device) | [DeviceRef](graph/type.md#max.graph.type.DeviceRef))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `argmax()` {#max.functional.argmax} > max.functional.argmax(x, axis=-1) Returns the indices of the maximum values along an axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to find the maximum indices. If None, finds the index of the maximum across all elements (flattened).
**Returns:**
A tensor containing the indices of the maximum values.
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
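The axis semantics above can be illustrated with NumPy, whose `argmax` behaves the same way for the default axis and the `axis=None` (flattened) case. This is a sketch for intuition, not the MAX API itself:

```python
import numpy as np

# Indices of the maximum along an axis, mirroring argmax(x, axis=-1).
x = np.array([[1.0, 9.0, 3.0],
              [7.0, 2.0, 5.0]])

per_row = np.argmax(x, axis=-1)  # index of each row's max -> [1, 0]
flat = np.argmax(x, axis=None)   # axis=None: index into the flattened tensor -> 1
```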
## `argmin()` {#max.functional.argmin} > max.functional.argmin(x, axis=-1) Returns the indices of the minimum values along an axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to find the minimum indices. If None, finds the index of the minimum across all elements (flattened).
**Returns:**
A tensor containing the indices of the minimum values.
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `argsort()` {#max.functional.argsort} > max.functional.argsort(x, ascending=True) Returns the indices that would sort a tensor along an axis. See [`max.graph.ops.argsort()`](graph/ops.md#max.graph.ops.argsort) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue) * ascending ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
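The sorting-indices semantics can be sketched with NumPy's `argsort` (a reference for intuition; see the linked `max.graph.ops.argsort()` for the authoritative behavior):

```python
import numpy as np

# Indices that would sort a tensor, mirroring argsort(x, ascending=True).
x = np.array([3.0, 1.0, 2.0])
idx = np.argsort(x)        # ascending order -> [1, 2, 0]
sorted_vals = x[idx]       # gathering with idx yields the sorted values

# ascending=False can be emulated by negating the input.
idx_desc = np.argsort(-x)  # -> [0, 2, 1]
```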
## `as_interleaved_complex()` {#max.functional.as_interleaved_complex} > max.functional.as\_interleaved\_complex(x) Converts a tensor to interleaved complex representation. See [`max.graph.ops.as_interleaved_complex()`](graph/ops.md#max.graph.ops.as_interleaved_complex) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
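An interleaved complex layout stores (real, imaginary) pairs along the last axis. Assuming that convention (consult the linked `max.graph.ops.as_interleaved_complex()` for the authoritative semantics), the conversion can be sketched in NumPy:

```python
import numpy as np

# Last axis alternates real and imaginary parts: [re0, im0, re1, im1, ...]
x = np.array([1.0, 2.0, 3.0, 4.0])

pairs = x.reshape(*x.shape[:-1], -1, 2)  # group into (real, imag) pairs
z = pairs[..., 0] + 1j * pairs[..., 1]   # -> [1+2j, 3+4j]
```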
## `atanh()` {#max.functional.atanh} > max.functional.atanh(x) Computes the inverse hyperbolic tangent element-wise. See [`max.graph.ops.atanh()`](graph/ops.md#max.graph.ops.atanh) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `avg_pool2d()` {#max.functional.avg_pool2d} > max.functional.avg\_pool2d(input, kernel\_size, stride=1, dilation=1, padding=0, ceil\_mode=False, count\_boundary=True) Applies 2D average pooling. See [`max.graph.ops.avg_pool2d()`](graph/ops.md#max.graph.ops.avg_pool2d) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([int](https://docs.python.org/3/library/functions.html#int) | 
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * ceil\_mode ([bool](https://docs.python.org/3/library/functions.html#bool)) * count\_boundary ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `band_part()` {#max.functional.band_part} > max.functional.band\_part(x, num\_lower=None, num\_upper=None, exclude=False) Copies a tensor setting everything outside a central band to zero. See [`max.graph.ops.band_part()`](graph/ops.md#max.graph.ops.band_part) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * num\_lower ([int](https://docs.python.org/3/library/functions.html#int) | None) * num\_upper ([int](https://docs.python.org/3/library/functions.html#int) | None) * exclude ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
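Assuming the conventional band-part definition (keep element `(i, j)` when `i - j <= num_lower` and `j - i <= num_upper`, with `None` meaning unbounded on that side, and `exclude` inverting the mask; see the linked graph op for the authoritative semantics), a NumPy reference looks like this:

```python
import numpy as np

def band_part_ref(x, num_lower=None, num_upper=None, exclude=False):
    """Reference semantics for a central-band mask on the last two axes."""
    i, j = np.indices(x.shape[-2:])
    in_band = np.ones(x.shape[-2:], dtype=bool)
    if num_lower is not None:
        in_band &= (i - j) <= num_lower   # limit how far below the diagonal
    if num_upper is not None:
        in_band &= (j - i) <= num_upper   # limit how far above the diagonal
    if exclude:
        in_band = ~in_band                # zero the band instead of its complement
    return np.where(in_band, x, 0)

x = np.arange(1, 10, dtype=float).reshape(3, 3)
upper = band_part_ref(x, num_lower=0)                 # keeps the upper triangle
diag = band_part_ref(x, num_lower=0, num_upper=0)     # keeps only the diagonal
```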
## `broadcast_to()` {#max.functional.broadcast_to} > max.functional.broadcast\_to(x, shape, out\_dims=None) Broadcasts a tensor to a new shape. See [`max.graph.ops.broadcast_to()`](graph/ops.md#max.graph.ops.broadcast_to) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue) * shape ([TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * out\_dims ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None)
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
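Broadcasting repeats size-1 axes to match a target shape without copying data. NumPy's `broadcast_to` follows the same rule and serves as a quick illustration:

```python
import numpy as np

x = np.array([[1.0], [2.0]])    # shape (2, 1)
y = np.broadcast_to(x, (2, 3))  # shape (2, 3); each row repeats its single value
```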
## `buffer_store()` {#max.functional.buffer_store} > max.functional.buffer\_store(destination, source) Sets a tensor buffer to new values. See [`max.graph.ops.buffer_store()`](graph/ops.md#max.graph.ops.buffer_store) for details.
**Parameters:**
* destination ([BufferValue](graph/BufferValue.md#max.graph.BufferValue) | HasBufferValue) * source (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
None
## `buffer_store_slice()` {#max.functional.buffer_store_slice} > max.functional.buffer\_store\_slice(destination, source, indices) Sets a slice of a tensor buffer to new values. See [`max.graph.ops.buffer_store_slice()`](graph/ops.md#max.graph.ops.buffer_store_slice) for details.
**Parameters:**
* destination ([BufferValue](graph/BufferValue.md#max.graph.BufferValue) | HasBufferValue) * source (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * indices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [int](https://docs.python.org/3/library/functions.html#int) | [slice](https://docs.python.org/3/library/functions.html#slice) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[slice](https://docs.python.org/3/library/functions.html#slice), [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | builtins.ellipsis])
**Return type:**
None
## `cast()` {#max.functional.cast} > max.functional.cast(x, dtype) Casts a tensor to a different data type. See [`max.graph.ops.cast()`](graph/ops.md#max.graph.ops.cast) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue) * dtype ([DType](dtype.md#max.dtype.DType))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `chunk()` {#max.functional.chunk} > max.functional.chunk(x, chunks, axis=0) Splits a tensor into chunks along a dimension. See [`max.graph.ops.chunk()`](graph/ops.md#max.graph.ops.chunk) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * chunks ([int](https://docs.python.org/3/library/functions.html#int)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](graph/TensorValue.md#max.graph.TensorValue)]
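Chunking splits a tensor into a list of pieces along one axis, in the spirit of NumPy's `array_split` (behavior for sizes that don't divide evenly may differ; see the linked `max.graph.ops.chunk()` for the exact rules):

```python
import numpy as np

x = np.arange(6)
parts = np.array_split(x, 3)  # three chunks along axis 0
```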
## `complex_mul()` {#max.functional.complex_mul} > max.functional.complex\_mul(lhs, rhs) Multiply two complex-valued tensors. See `max.graph.ops.complex.mul()` for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `concat()` {#max.functional.concat} > max.functional.concat(original\_vals, axis=0) Concatenates a list of tensors along an axis. See [`max.graph.ops.concat()`](graph/ops.md#max.graph.ops.concat) for details.
**Parameters:**
* original\_vals ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)]) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `constant()` {#max.functional.constant} > max.functional.constant(value, dtype=None, device=None) Creates a constant tensor. See [`max.graph.ops.constant()`](graph/ops.md#max.graph.ops.constant) for details.
**Parameters:**
* value ([DLPackArray](driver.md#max.driver.DLPackArray) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Number | NestedArray]] | [float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) * dtype ([DType](dtype.md#max.dtype.DType) | None) * device ([Device](driver.md#max.driver.Device) | [DeviceRef](graph/type.md#max.graph.type.DeviceRef) | None)
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `constant_external()` {#max.functional.constant_external} > max.functional.constant\_external(name, type) Creates a constant tensor from external data. See [`max.graph.ops.constant_external()`](graph/ops.md#max.graph.ops.constant_external) for details.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * type ([TensorType](graph/type.md#max.graph.type.TensorType))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `conv2d()` {#max.functional.conv2d} > max.functional.conv2d(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF) Applies 2D convolution. See [`max.graph.ops.conv2d()`](graph/ops.md#max.graph.ops.conv2d) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * filter (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), 
[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * groups ([int](https://docs.python.org/3/library/functions.html#int)) * bias (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) * input\_layout ([ConvInputLayout](graph/type.md#max.graph.type.ConvInputLayout)) * filter\_layout ([FilterLayout](graph/type.md#max.graph.type.FilterLayout))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `conv2d_transpose()` {#max.functional.conv2d_transpose} > max.functional.conv2d\_transpose(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), output\_paddings=(0, 0), bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF) Applies 2D transposed convolution. See [`max.graph.ops.conv2d_transpose()`](graph/ops.md#max.graph.ops.conv2d_transpose) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * filter (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), 
[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * output\_paddings ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * bias (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) * input\_layout ([ConvInputLayout](graph/type.md#max.graph.type.ConvInputLayout)) * filter\_layout ([FilterLayout](graph/type.md#max.graph.type.FilterLayout))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `conv3d()` {#max.functional.conv3d} > max.functional.conv3d(x, filter, stride=(1, 1, 1), dilation=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.QRSCF) Applies 3D convolution. See [`max.graph.ops.conv3d()`](graph/ops.md#max.graph.ops.conv3d) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * filter (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding 
([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * groups ([int](https://docs.python.org/3/library/functions.html#int)) * bias (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) * input\_layout ([ConvInputLayout](graph/type.md#max.graph.type.ConvInputLayout)) * filter\_layout ([FilterLayout](graph/type.md#max.graph.type.FilterLayout))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `cos()` {#max.functional.cos} > max.functional.cos(x) Computes the cosine element-wise. See [`max.graph.ops.cos()`](graph/ops.md#max.graph.ops.cos) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `cumsum()` {#max.functional.cumsum} > max.functional.cumsum(x, axis=-1, exclusive=False, reverse=False) Computes the cumulative sum along an axis. See [`max.graph.ops.cumsum()`](graph/ops.md#max.graph.ops.cumsum) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int)) * exclusive ([bool](https://docs.python.org/3/library/functions.html#bool)) * reverse ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
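The `exclusive` and `reverse` flags modify the standard running sum. Assuming the usual definitions (exclusive sums strictly preceding elements; reverse accumulates from the end of the axis), a NumPy sketch:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

inclusive = np.cumsum(x)                                # [1, 3, 6, 10]
# exclusive=True: each output sums the strictly preceding elements.
exclusive = np.concatenate(([0.0], np.cumsum(x)[:-1]))  # [0, 1, 3, 6]
# reverse=True: accumulate from the end of the axis.
reverse = np.cumsum(x[::-1])[::-1]                      # [10, 9, 7, 4]
```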
## `custom()` {#max.functional.custom} > max.functional.custom(name, device, values, out\_types, parameters=None, custom\_extensions=None) Applies a custom operation with optional custom extension loading. Creates a node to execute a custom graph operation. The custom op should be registered by annotating a Mojo function with the `@compiler.register` decorator. This function extends [`max.graph.ops.custom()`](graph/ops.md#max.graph.ops.custom) with automatic loading of custom extension libraries, eliminating the need to manually import kernels before use. **Example:**

```python
from max import functional as F, Tensor
from max.dtype import DType
from max.driver import CPU

x = Tensor.full([10], 10, dtype=DType.float32, device=CPU())
y = Tensor.ones([10], dtype=DType.float32, device=CPU())
result = F.custom(
    "vector_sum",
    device=x.device,
    values=[x, y],
    out_types=[x.type],
    custom_extensions="ops.mojopkg",
)[0]
```
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`. * device ([driver.Device](driver.md#max.driver.Device) | [DeviceRef](graph/ops.md#max.graph.ops.DeviceRef)) – Device that the op is assigned to. This becomes a `target` parameter to the kernel. * values (Sequence\[[Value](graph/Value.md#max.graph.Value)\[Any]]) – The op function’s arguments. * out\_types (Sequence\[[Type](graph/type.md#max.graph.type.Type)\[Any]]) – The list of op function’s return types. * parameters (Mapping\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel. * custom\_extensions (CustomExtensionsType | None) – Paths to custom extension libraries (`.mojopkg` files or Mojo source directories). Extensions are automatically loaded into the current graph if not already present.
**Returns:**
Symbolic values representing the outputs of the op in the graph. These correspond 1:1 with the types passed as `out_types`.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Value](graph/Value.md#max.graph.Value)\[Any]]
:::note See also [`max.graph.ops.custom()`](graph/ops.md#max.graph.ops.custom): The underlying graph operation. [`inplace_custom()`](#max.functional.inplace_custom): For in-place custom operations. ::: ## `div()` {#max.functional.div} > max.functional.div(lhs, rhs) Divides two tensors element-wise. See [`max.graph.ops.div()`](graph/ops.md#max.graph.ops.div) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `equal()` {#max.functional.equal} > max.functional.equal(lhs, rhs) Computes element-wise equality comparison. See [`max.graph.ops.equal()`](graph/ops.md#max.graph.ops.equal) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `erf()` {#max.functional.erf} > max.functional.erf(x) Computes the error function element-wise. See [`max.graph.ops.erf()`](graph/ops.md#max.graph.ops.erf) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `exp()` {#max.functional.exp} > max.functional.exp(x) Computes the exponential element-wise. See [`max.graph.ops.exp()`](graph/ops.md#max.graph.ops.exp) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `flatten()` {#max.functional.flatten} > max.functional.flatten(x, start\_dim=0, end\_dim=-1) Flattens a tensor. See [`max.graph.ops.flatten()`](graph/ops.md#max.graph.ops.flatten) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * start\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * end\_dim ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `floor()` {#max.functional.floor} > max.functional.floor(x) Computes the floor element-wise. See [`max.graph.ops.floor()`](graph/ops.md#max.graph.ops.floor) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `fold()` {#max.functional.fold} > max.functional.fold(input, output\_size, kernel\_size, stride=1, dilation=1, padding=0) Performs tensor folding operation. See [`max.graph.ops.fold()`](graph/ops.md#max.graph.ops.fold) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * output\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | 
[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)])
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `functional()` {#max.functional.functional} > max.functional.functional(op) Decorator that converts a graph operation to support multiple tensor types.
**Parameters:**
op ([Callable](graph/ops.md#max.graph.ops.Callable)\[\[...], [Any](https://docs.python.org/3/library/typing.html#typing.Any)])
**Return type:**
[Callable](graph/ops.md#max.graph.ops.Callable)\[\[…], [Any](https://docs.python.org/3/library/typing.html#typing.Any)]
## `gather()` {#max.functional.gather} > max.functional.gather(input, indices, axis) Gathers values along an axis specified by indices. See [`max.graph.ops.gather()`](graph/ops.md#max.graph.ops.gather) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * indices (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `gather_nd()` {#max.functional.gather_nd} > max.functional.gather\_nd(input, indices, batch\_dims=0) Gathers values using multi-dimensional indices. See [`max.graph.ops.gather_nd()`](graph/ops.md#max.graph.ops.gather_nd) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * indices (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * batch\_dims ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `gelu()` {#max.functional.gelu} > max.functional.gelu(x, approximate='none') Applies the Gaussian Error Linear Unit (GELU) activation. See [`max.graph.ops.gelu()`](graph/ops.md#max.graph.ops.gelu) for details.
**Parameters:**
* x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue)) * approximate ([str](https://docs.python.org/3/library/stdtypes.html#str))
## `greater()` {#max.functional.greater} > max.functional.greater(lhs, rhs) Computes element-wise greater-than comparison. See [`max.graph.ops.greater()`](graph/ops.md#max.graph.ops.greater) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `greater_equal()` {#max.functional.greater_equal} > max.functional.greater\_equal(lhs, rhs) Computes element-wise greater-than-or-equal comparison. See [`max.graph.ops.greater_equal()`](graph/ops.md#max.graph.ops.greater_equal) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `hann_window()` {#max.functional.hann_window} > max.functional.hann\_window(window\_length, device, periodic=True, dtype=float32) Creates a Hann window. See [`max.graph.ops.hann_window()`](graph/ops.md#max.graph.ops.hann_window) for details.
**Parameters:**
* window\_length ([int](https://docs.python.org/3/library/functions.html#int)) * device ([DeviceRef](graph/type.md#max.graph.type.DeviceRef)) * periodic ([bool](https://docs.python.org/3/library/functions.html#bool)) * dtype ([DType](dtype.md#max.dtype.DType))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `in_graph_context()` {#max.functional.in_graph_context} > max.functional.in\_graph\_context() Checks whether the caller is inside a Graph context.
**Returns:**
True if inside a `with Graph(...):` block, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
## `inplace_custom()` {#max.functional.inplace_custom} > max.functional.inplace\_custom(name, device, values, out\_types=None, parameters=None, custom\_extensions=None) Applies an in-place custom operation with optional custom extension loading. Creates a node to execute an in-place custom graph operation. The custom op should be registered by annotating a Mojo function with the `@compiler.register` decorator. This function extends [`max.graph.ops.inplace_custom()`](graph/ops.md#max.graph.ops.inplace_custom) with automatic loading of custom extension libraries, eliminating the need to manually import kernels before use. **Example:** ```python from max import functional as F, Tensor from max.dtype import DType from max.driver import CPU # Create a buffer for in-place modification data = Tensor.zeros([10], dtype=DType.float32, device=CPU()) # Use in-place custom op with inline extension loading F.inplace_custom( "my_inplace_op", device=data.device, values=[data], custom_extensions="ops.mojopkg" ) ```
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`. * device ([driver.Device](driver.md#max.driver.Device) | [DeviceRef](graph/ops.md#max.graph.ops.DeviceRef)) – Device that the op is assigned to. This becomes a `target` parameter to the kernel. * values (Sequence\[[Value](graph/Value.md#max.graph.Value)\[Any]]) – The op function’s arguments. At least one must be a `BufferValue` or `_OpaqueValue`. * out\_types (Sequence\[[Type](graph/type.md#max.graph.type.Type)\[Any]] | None) – The list of op function’s return types. Can be None if the operation has no outputs. * parameters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel. * custom\_extensions (CustomExtensionsType | None) – Paths to custom extension libraries (`.mojopkg` files or Mojo source directories). Extensions are automatically loaded into the current graph if not already present.
**Returns:**
Symbolic values representing the outputs of the op in the graph.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Value](graph/Value.md#max.graph.Value)\[Any]]
:::note See also [`max.graph.ops.inplace_custom()`](graph/ops.md#max.graph.ops.inplace_custom): The underlying graph operation. [`custom()`](#max.functional.custom): For non-in-place custom operations. ::: ## `irfft()` {#max.functional.irfft} > max.functional.irfft(input\_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input\_is\_complex=False, buffer\_size\_mb=512) Computes the inverse real FFT. See [`max.graph.ops.irfft()`](graph/ops.md#max.graph.ops.irfft) for details.
**Parameters:**
* input\_tensor (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue) * n ([int](https://docs.python.org/3/library/functions.html#int) | None) * axis ([int](https://docs.python.org/3/library/functions.html#int)) * normalization (Normalization | [str](https://docs.python.org/3/library/stdtypes.html#str)) * input\_is\_complex ([bool](https://docs.python.org/3/library/functions.html#bool)) * buffer\_size\_mb ([int](https://docs.python.org/3/library/functions.html#int))
## `is_inf()` {#max.functional.is_inf} > max.functional.is\_inf(x) Checks for infinite values element-wise. See [`max.graph.ops.is_inf()`](graph/ops.md#max.graph.ops.is_inf) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `is_nan()` {#max.functional.is_nan} > max.functional.is\_nan(x) Checks for NaN values element-wise. See [`max.graph.ops.is_nan()`](graph/ops.md#max.graph.ops.is_nan) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `layer_norm()` {#max.functional.layer_norm} > max.functional.layer\_norm(input, gamma, beta, epsilon) Applies layer normalization. See [`max.graph.ops.layer_norm()`](graph/ops.md#max.graph.ops.layer_norm) for details.
**Parameters:**
* input ([TensorValue](graph/TensorValue.md#max.graph.TensorValue)) * gamma (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * beta (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * epsilon ([float](https://docs.python.org/3/library/functions.html#float))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `lazy()` {#max.functional.lazy} > max.functional.lazy() Context manager for lazy tensor evaluation. Within this context, tensor operations are recorded but not executed. Tensors remain unrealized until explicitly awaited via `await tensor.realize` or until their values are needed (e.g., by calling `.item()`). This is particularly useful for creating tensors which may not ever be used. Lazy tensors that aren’t used will never allocate memory or perform operations.
**Yields:**
None
**Return type:**
[Generator](https://docs.python.org/3/library/collections.abc.html#collections.abc.Generator)\[None]
```python from max import functional as F from max.tensor import Tensor from max.nn import Linear with F.lazy(): model = Linear(2, 3) print(model) # Lazy weights not initialized # Executing the model would be fine! The weights would be created # on first use. # output = model(Tensor.ones([5, 2])) # Load pretrained weights, never creating the original random weights weights = { "weight": Tensor.zeros([3, 2]), "bias": Tensor.zeros([3]), } model.load_state_dict(weights) ``` ## `log()` {#max.functional.log} > max.functional.log(x) Computes the natural logarithm element-wise. See [`max.graph.ops.log()`](graph/ops.md#max.graph.ops.log) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `log1p()` {#max.functional.log1p} > max.functional.log1p(x) Computes log(1 + x) element-wise. See [`max.graph.ops.log1p()`](graph/ops.md#max.graph.ops.log1p) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `logical_and()` {#max.functional.logical_and} > max.functional.logical\_and(lhs, rhs) Computes element-wise logical AND. See [`max.graph.ops.logical_and()`](graph/ops.md#max.graph.ops.logical_and) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `logical_not()` {#max.functional.logical_not} > max.functional.logical\_not(x) Computes element-wise logical NOT. See [`max.graph.ops.logical_not()`](graph/ops.md#max.graph.ops.logical_not) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `logical_or()` {#max.functional.logical_or} > max.functional.logical\_or(lhs, rhs) Computes element-wise logical OR. See [`max.graph.ops.logical_or()`](graph/ops.md#max.graph.ops.logical_or) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `logical_xor()` {#max.functional.logical_xor} > max.functional.logical\_xor(lhs, rhs) Computes element-wise logical XOR. See [`max.graph.ops.logical_xor()`](graph/ops.md#max.graph.ops.logical_xor) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `logsoftmax()` {#max.functional.logsoftmax} > max.functional.logsoftmax(value, axis=-1) Applies the log softmax function. See [`max.graph.ops.logsoftmax()`](graph/ops.md#max.graph.ops.logsoftmax) for details.
**Parameters:**
* value (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
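The log-softmax is typically computed in a numerically stable way by subtracting the per-axis maximum before exponentiating. A NumPy sketch of the math (illustrative only; the MAX function stages a graph op and returns a `TensorValue`):

```python
import numpy as np

def logsoftmax(x, axis=-1):
    # Numerically stable log-softmax: shift by the max before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

x = np.array([[1.0, 2.0, 3.0]])
out = logsoftmax(x)
# exp(out) sums to 1 along the softmax axis, as required.
```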
## `masked_scatter()` {#max.functional.masked_scatter} > max.functional.masked\_scatter(input, mask, updates, out\_dim) Scatters values according to a mask. See [`max.graph.ops.masked_scatter()`](graph/ops.md#max.graph.ops.masked_scatter) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * mask (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * updates (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)])
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
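A NumPy sketch of the masked-scatter semantics, under the assumption that values are consumed from `updates` in order wherever `mask` is true (the `out_dim` parameter, which names the symbolic output dimension in the graph, is omitted here; see the linked op docs for the authoritative behavior):

```python
import numpy as np

def masked_scatter(input, mask, updates):
    # Where mask is True, take the next value from the flattened updates
    # tensor (in order); elsewhere keep the original element.
    out = input.copy()
    out[mask] = updates.ravel()[: mask.sum()]
    return out

x = np.zeros(5)
mask = np.array([True, False, True, False, True])
result = masked_scatter(x, mask, np.array([1.0, 2.0, 3.0]))
# → array([1., 0., 2., 0., 3.])
```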
## `matmul()` {#max.functional.matmul} > max.functional.matmul(lhs, rhs) Performs matrix multiplication. See [`max.graph.ops.matmul()`](graph/ops.md#max.graph.ops.matmul) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
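The shape behavior matches familiar matrix-multiplication semantics: an `(m, k)` operand times a `(k, n)` operand yields `(m, n)`. A NumPy illustration (the MAX call itself builds a graph op on `TensorValue` operands):

```python
import numpy as np

a = np.arange(6, dtype=np.float64).reshape(2, 3)   # (2, 3)
b = np.arange(12, dtype=np.float64).reshape(3, 4)  # (3, 4)
c = a @ b                                          # (2, 4)
```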
## `max()` {#max.functional.max} > max.functional.max(x, y=None, /, axis=-1) Returns the maximum values along an axis, or the elementwise maximum of two tensors.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor. * y (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – Optional second tensor for elementwise maximum. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the maximum (only for reduction). If None, computes the maximum across all elements (flattened).
**Returns:**
A tensor containing the maximum values.
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
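The two modes described above can be illustrated with NumPy: pass one tensor for an axis reduction, or two tensors for an elementwise maximum (a sketch of the semantics, not the MAX API itself):

```python
import numpy as np

x = np.array([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])

# Reduction form: maximum along an axis (axis=-1 is the documented default).
row_max = x.max(axis=-1)        # one maximum per row

# Elementwise form: maximum of two tensors, with broadcasting.
y = np.full_like(x, 3.5)
elementwise = np.maximum(x, y)
```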
## `max_pool2d()` {#max.functional.max_pool2d} > max.functional.max\_pool2d(input, kernel\_size, stride=1, dilation=1, padding=0, ceil\_mode=False) Applies 2D max pooling. See [`max.graph.ops.max_pool2d()`](graph/ops.md#max.graph.ops.max_pool2d) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * ceil\_mode ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
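A minimal NumPy sketch of what 2D max pooling computes: slide a window over the input and keep the maximum in each window. This simplified version handles a single 2D plane with no padding, dilation, or ceil mode; the real op works on batched image tensors with all of those options (see the linked op docs):

```python
import numpy as np

def max_pool2d(x, kernel_size, stride):
    # x: (H, W). Take the max over each kernel_size window, stepped by stride.
    kh, kw = kernel_size
    sh, sw = stride
    h = (x.shape[0] - kh) // sh + 1
    w = (x.shape[1] - kw) // sw + 1
    out = np.empty((h, w), dtype=x.dtype)
    for i in range(h):
        for j in range(w):
            out[i, j] = x[i * sh : i * sh + kh, j * sw : j * sw + kw].max()
    return out

x = np.arange(16, dtype=np.float64).reshape(4, 4)
pooled = max_pool2d(x, (2, 2), (2, 2))
# → array([[ 5.,  7.], [13., 15.]])
```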
## `mean()` {#max.functional.mean} > max.functional.mean(x, axis=-1) Computes the mean along a specified axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the mean. If None, computes the mean across all elements (flattened).
**Returns:**
A tensor containing the mean values.
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
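A NumPy illustration of the two documented cases, reduction along an axis versus `axis=None` over all elements (semantics sketch only; the MAX call returns a `TensorValue`):

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
row_mean = x.mean(axis=-1)  # one mean per row
total = x.mean()            # axis=None: mean over all (flattened) elements
```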
## `min()` {#max.functional.min} > max.functional.min(x, y=None, /, axis=-1) Returns the minimum values along an axis, or the elementwise minimum of two tensors.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor. * y (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – Optional second tensor for elementwise minimum. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the minimum (only for reduction). If None, computes the minimum across all elements (flattened).
**Returns:**
A tensor containing the minimum values.
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `mod()` {#max.functional.mod} > max.functional.mod(lhs, rhs) Computes the modulo operation element-wise. See [`max.graph.ops.mod()`](graph/ops.md#max.graph.ops.mod) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `mul()` {#max.functional.mul} > max.functional.mul(lhs, rhs) Multiplies two tensors element-wise. See [`max.graph.ops.mul()`](graph/ops.md#max.graph.ops.mul) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `negate()` {#max.functional.negate} > max.functional.negate(x) Negates a tensor element-wise. See [`max.graph.ops.negate()`](graph/ops.md#max.graph.ops.negate) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `nonzero()` {#max.functional.nonzero} > max.functional.nonzero(x, out\_dim) Returns the indices of non-zero elements. See [`max.graph.ops.nonzero()`](graph/ops.md#max.graph.ops.nonzero) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)])
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
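NumPy's `argwhere` computes the same result: one row of coordinates per non-zero element. The `out_dim` parameter presumably names the data-dependent output dimension, since the number of non-zeros is unknown at graph-construction time (see the linked op docs):

```python
import numpy as np

x = np.array([0.0, 2.0, 0.0, 5.0])
idx = np.argwhere(x != 0)  # one row of indices per non-zero element
# → array([[1], [3]])
```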
## `not_equal()` {#max.functional.not_equal} > max.functional.not\_equal(lhs, rhs) Computes element-wise inequality comparison. See [`max.graph.ops.not_equal()`](graph/ops.md#max.graph.ops.not_equal) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `outer()` {#max.functional.outer} > max.functional.outer(lhs, rhs) Computes the outer product of two vectors. See [`max.graph.ops.outer()`](graph/ops.md#max.graph.ops.outer) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
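The outer product of a length-`m` vector and a length-`n` vector is an `(m, n)` matrix whose `(i, j)` entry is `lhs[i] * rhs[j]`. In NumPy terms (a semantics sketch, not the MAX API):

```python
import numpy as np

u = np.array([1.0, 2.0, 3.0])
v = np.array([10.0, 20.0])
product = np.outer(u, v)  # shape (3, 2): product[i, j] == u[i] * v[j]
# → array([[10., 20.], [20., 40.], [30., 60.]])
```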
## `pad()` {#max.functional.pad} > max.functional.pad(input, paddings, mode='constant', value=0) Pads a tensor. See [`max.graph.ops.pad()`](graph/ops.md#max.graph.ops.pad) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * paddings ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int)]) * mode ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['constant']) * value (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
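The effect of constant-mode padding can be shown with `np.pad`. Note that NumPy takes (before, after) pairs per axis, while the MAX op takes a flat `Iterable[int]` of padding amounts whose exact ordering is specified in the linked op docs:

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])
# Add one zero column on each side of the last dimension.
padded = np.pad(x, ((0, 0), (1, 1)), mode="constant", constant_values=0)
# → array([[0, 1, 2, 0], [0, 3, 4, 0]])
```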
## `permute()` {#max.functional.permute} > max.functional.permute(x, dims) Permutes the dimensions of a tensor. See [`max.graph.ops.permute()`](graph/ops.md#max.graph.ops.permute) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)])
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
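Permutation reorders axes: `dims[i]` names the input axis that becomes output axis `i`, which is exactly NumPy's `transpose` (illustrative sketch; the MAX call operates on graph values):

```python
import numpy as np

x = np.zeros((2, 3, 4))
y = x.transpose(2, 0, 1)  # dims=[2, 0, 1]: old axis 2 becomes new axis 0
# y.shape → (4, 2, 3)
```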
## `pow()` {#max.functional.pow} > max.functional.pow(lhs, rhs) Raises tensor elements to a power. See [`max.graph.ops.pow()`](graph/ops.md#max.graph.ops.pow) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `relu()` {#max.functional.relu} > max.functional.relu(x) Applies the ReLU activation function. See [`max.graph.ops.relu()`](graph/ops.md#max.graph.ops.relu) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `repeat_interleave()` {#max.functional.repeat_interleave} > max.functional.repeat\_interleave(x, repeats, axis=None, out\_dim=None) Repeats elements of a tensor. See [`max.graph.ops.repeat_interleave()`](graph/ops.md#max.graph.ops.repeat_interleave) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * repeats ([int](https://docs.python.org/3/library/functions.html#int) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)) * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None)
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
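The interleaving behavior follows the usual `repeat_interleave` convention: each element is repeated in place, rather than the whole tensor being tiled. A short sketch of that convention using NumPy's equivalent `np.repeat` (illustrative only; it assumes MAX matches the NumPy/PyTorch semantics):

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])

# Each row is repeated in place (interleaved), not tiled:
out = np.repeat(x, 2, axis=0)
# [[1, 2], [1, 2], [3, 4], [3, 4]]

# With no axis, the input is flattened first:
flat = np.repeat(x, 2)
# [1, 1, 2, 2, 3, 3, 4, 4]
```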
## `reshape()` {#max.functional.reshape} > max.functional.reshape(x, shape) Reshapes a tensor to a new shape. See [`max.graph.ops.reshape()`](graph/ops.md#max.graph.ops.reshape) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]])
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
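Reshaping preserves element order and total element count; only the shape metadata changes. A minimal NumPy sketch of the same idea (illustrative, not the MAX API):

```python
import numpy as np

x = np.arange(6)     # [0, 1, 2, 3, 4, 5]
y = x.reshape(2, 3)  # same six elements, viewed as 2 rows of 3
# [[0, 1, 2], [3, 4, 5]]
```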
## `round()` {#max.functional.round} > max.functional.round(x) Rounds tensor values element-wise. See [`max.graph.ops.round()`](graph/ops.md#max.graph.ops.round) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `rsqrt()` {#max.functional.rsqrt} > max.functional.rsqrt(x) Computes the reciprocal square root element-wise. See [`max.graph.ops.rsqrt()`](graph/ops.md#max.graph.ops.rsqrt) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `scatter()` {#max.functional.scatter} > max.functional.scatter(input, updates, indices, axis=-1) Scatters values along an axis. See [`max.graph.ops.scatter()`](graph/ops.md#max.graph.ops.scatter) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * updates (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * indices (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | 
[DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `scatter_nd()` {#max.functional.scatter_nd} > max.functional.scatter\_nd(input, updates, indices) Scatters values using multi-dimensional indices. See [`max.graph.ops.scatter_nd()`](graph/ops.md#max.graph.ops.scatter_nd) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * updates (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * indices (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | 
[DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
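A NumPy reference sketch of scatter-with-multi-dimensional-indices semantics (an illustration of the common ONNX-style convention, not the MAX kernel): each row of `indices` names a position in the output, and the matching entry of `updates` is written there; all other positions keep the input's values.

```python
import numpy as np

def scatter_nd_ref(inp, updates, indices):
    # Copy the input, then write each update at the position
    # given by the matching row of `indices`.
    out = inp.copy()
    for idx, upd in zip(indices, updates):
        out[tuple(idx)] = upd
    return out

x = np.zeros((4, 4))
y = scatter_nd_ref(x, np.array([9.0, 7.0]), np.array([[0, 1], [2, 3]]))
# y[0, 1] == 9.0, y[2, 3] == 7.0, everything else stays 0.
```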
## `sigmoid()` {#max.functional.sigmoid} > max.functional.sigmoid(x) Applies the sigmoid activation function. See [`max.graph.ops.sigmoid()`](graph/ops.md#max.graph.ops.sigmoid) for details.
**Parameters:**
x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `silu()` {#max.functional.silu} > max.functional.silu(x) Applies the SiLU (Swish) activation function. See [`max.graph.ops.silu()`](graph/ops.md#max.graph.ops.silu) for details.
**Parameters:**
x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `sin()` {#max.functional.sin} > max.functional.sin(x) Computes the sine element-wise. See [`max.graph.ops.sin()`](graph/ops.md#max.graph.ops.sin) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `slice_tensor()` {#max.functional.slice_tensor} > max.functional.slice\_tensor(x, indices) Slices a tensor along specified dimensions. See [`max.graph.ops.slice_tensor()`](graph/ops.md#max.graph.ops.slice_tensor) for details.
**Parameters:**
* x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue)) * indices (SliceIndices)
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `softmax()` {#max.functional.softmax} > max.functional.softmax(value, axis=-1) Applies the softmax function. See [`max.graph.ops.softmax()`](graph/ops.md#max.graph.ops.softmax) for details.
**Parameters:**
* value (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
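Softmax normalizes the values along `axis` into a probability distribution. A numerically stable NumPy reference (illustrative; subtracting the per-axis max avoids overflow in `exp` without changing the result):

```python
import numpy as np

def softmax_ref(v, axis=-1):
    # Subtracting the per-axis max keeps exp() in a safe range.
    shifted = v - v.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

probs = softmax_ref(np.array([[1.0, 2.0, 3.0]]))
# Each row sums to 1; larger logits map to larger probabilities.
```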
## `split()` {#max.functional.split} > max.functional.split(x, split\_size\_or\_sections, axis=0) Splits a tensor into multiple tensors along a given dimension. This function supports two modes, matching PyTorch’s behavior: * If `split_size_or_sections` is an **int**, splits into chunks of that size (the last chunk may be smaller if the dimension is not evenly divisible). * If `split_size_or_sections` is a **list of ints**, splits into chunks with exactly those sizes (must sum to the dimension size).

```python
from max import functional as F, Tensor

x = Tensor.ones([10, 4])

# Split into chunks of size 3 (last chunk is size 1)
chunks = F.split(x, 3, axis=0)  # shapes: [3,4], [3,4], [3,4], [1,4]

# Split into exact sizes
chunks = F.split(x, [2, 3, 5], axis=0)  # shapes: [2,4], [3,4], [5,4]
```
**Parameters:**
* x ([Tensor](tensor.md#max.tensor.Tensor) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to split. * split\_size\_or\_sections ([int](https://docs.python.org/3/library/functions.html#int) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Either an int (chunk size) or a list of ints (exact sizes for each output tensor). * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension along which to split. Defaults to 0.
**Returns:**
A list of tensors resulting from the split.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Tensor](tensor.md#max.tensor.Tensor)] | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](graph/TensorValue.md#max.graph.TensorValue)]
## `sqrt()` {#max.functional.sqrt} > max.functional.sqrt(x) Computes the square root element-wise. See [`max.graph.ops.sqrt()`](graph/ops.md#max.graph.ops.sqrt) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `squeeze()` {#max.functional.squeeze} > max.functional.squeeze(x, axis) Removes dimensions of size 1. See [`max.graph.ops.squeeze()`](graph/ops.md#max.graph.ops.squeeze) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `stack()` {#max.functional.stack} > max.functional.stack(values, axis=0) Stacks tensors along a new dimension. See [`max.graph.ops.stack()`](graph/ops.md#max.graph.ops.stack) for details.
**Parameters:**
* values ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)]) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
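Unlike concatenation, stacking introduces a new dimension at `axis`, so every input must have the same shape. In NumPy terms (illustrative sketch only):

```python
import numpy as np

a = np.array([1, 2])
b = np.array([3, 4])

s0 = np.stack([a, b], axis=0)  # shape (2, 2): [[1, 2], [3, 4]]
s1 = np.stack([a, b], axis=1)  # shape (2, 2): [[1, 3], [2, 4]]
```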
## `sub()` {#max.functional.sub} > max.functional.sub(lhs, rhs) Subtracts two tensors element-wise. See [`max.graph.ops.sub()`](graph/ops.md#max.graph.ops.sub) for details.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `sum()` {#max.functional.sum} > max.functional.sum(x, axis=-1) Computes the sum along a specified axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the sum. If None, computes the sum across all elements (flattened).
**Returns:**
A tensor containing the sum values.
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
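A NumPy sketch of the two reduction modes described above (illustrative; assumes MAX follows the NumPy convention):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

row_sums = x.sum(axis=-1)  # reduce the last axis: [3, 12]
total = x.sum()            # axis=None: sum over all elements -> 15
```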
## `tanh()` {#max.functional.tanh} > max.functional.tanh(x) Computes the hyperbolic tangent element-wise. See [`max.graph.ops.tanh()`](graph/ops.md#max.graph.ops.tanh) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `tile()` {#max.functional.tile} > max.functional.tile(x, repeats) Tiles a tensor by repeating it. See [`max.graph.ops.tile()`](graph/ops.md#max.graph.ops.tile) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * repeats ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]])
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `top_k()` {#max.functional.top_k} > max.functional.top\_k(input, k, axis=-1) Returns the k largest elements along an axis. See [`max.graph.ops.top_k()`](graph/ops.md#max.graph.ops.top_k) for details.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * k ([int](https://docs.python.org/3/library/functions.html#int)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](graph/TensorValue.md#max.graph.TensorValue), [TensorValue](graph/TensorValue.md#max.graph.TensorValue)]
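The returned tuple pairs the k largest values with their positions along `axis`. A NumPy reference of that convention (illustrative; it assumes values come back in descending order, as is typical for top-k ops):

```python
import numpy as np

def top_k_ref(inp, k, axis=-1):
    # Sort descending along `axis`, keep the first k entries,
    # and return both the values and their original indices.
    idx = np.argsort(-inp, axis=axis)
    idx = np.take(idx, range(k), axis=axis)
    vals = np.take_along_axis(inp, idx, axis=axis)
    return vals, idx

vals, idx = top_k_ref(np.array([1.0, 5.0, 3.0, 4.0]), k=2)
# vals == [5.0, 4.0], idx == [1, 3]
```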
## `transfer_to()` {#max.functional.transfer_to} > max.functional.transfer\_to(x, device) Transfers a tensor to a specified device. See [`max.graph.ops.transfer_to()`](graph/ops.md#max.graph.ops.transfer_to) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue) * device ([Device](driver.md#max.driver.Device) | [DeviceRef](graph/type.md#max.graph.type.DeviceRef))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `transpose()` {#max.functional.transpose} > max.functional.transpose(x, axis\_1, axis\_2) Transposes a tensor. See [`max.graph.ops.transpose()`](graph/ops.md#max.graph.ops.transpose) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis\_1 ([int](https://docs.python.org/3/library/functions.html#int)) * axis\_2 ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
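`transpose()` swaps exactly two axes (`axis_1` and `axis_2`), analogous to NumPy's `np.swapaxes` (an illustrative sketch of that analogy, not the MAX API):

```python
import numpy as np

x = np.zeros((2, 3, 4))
y = np.swapaxes(x, 0, 2)  # swap the first and last axes
# x.shape == (2, 3, 4); y.shape == (4, 3, 2)
```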
## `trunc()` {#max.functional.trunc} > max.functional.trunc(x) Truncates tensor values element-wise. See [`max.graph.ops.trunc()`](graph/ops.md#max.graph.ops.trunc) for details.
**Parameters:**
x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `unsqueeze()` {#max.functional.unsqueeze} > max.functional.unsqueeze(x, axis) Adds dimensions of size 1. See [`max.graph.ops.unsqueeze()`](graph/ops.md#max.graph.ops.unsqueeze) for details.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
## `where()` {#max.functional.where} > max.functional.where(condition, x, y) Selects elements from two tensors based on a condition. See [`max.graph.ops.where()`](graph/ops.md#max.graph.ops.where) for details.
**Parameters:**
* condition (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) * y (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | 
[DLPackArray](driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](graph/TensorValue.md#max.graph.TensorValue)
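`where()` is an element-wise select: wherever `condition` is true the result takes the element from `x`, otherwise from `y`. In NumPy terms (illustrative):

```python
import numpy as np

cond = np.array([True, False, True])
x = np.array([1, 2, 3])
y = np.array([10, 20, 30])

out = np.where(cond, x, y)
# [1, 20, 3]
```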
--- ## BufferValue ## `BufferValue` {#max.graph.BufferValue} > class max.graph.BufferValue(value) Bases: [`Value`](Value.md#max.graph.Value)\[`BufferType`] Represents a mutable semantic tensor within a Graph.
**Parameters:**
value ([Value](Value.md#max.graph.Value)\[Any] | \_Value\[mo.BufferType] | HasBufferValue)
### `device` {#max.graph.BufferValue.device} > property device: [DeviceRef](type.md#max.graph.type.DeviceRef) Returns the device of the BufferValue. ### `dtype` {#max.graph.BufferValue.dtype} > property dtype: [DType](../dtype.md#max.dtype.DType) Returns the tensor data type. ### `from_mlir()` {#max.graph.BufferValue.from_mlir} > classmethod from\_mlir(value) Creates a [`BufferValue`](#max.graph.BufferValue) from an MLIR buffer value.
**Parameters:**
value (Value\[BufferType]) – The MLIR buffer value to wrap.
**Return type:**
[BufferValue](#max.graph.BufferValue)
### `print()` {#max.graph.BufferValue.print} > print(label='debug\_buffer') Prints detailed information about the buffer.
**Parameters:**
label ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
None
### `rank` {#max.graph.BufferValue.rank} > property rank: [int](https://docs.python.org/3/library/functions.html#int) Returns the rank (number of dims) of the buffer. ### `shape` {#max.graph.BufferValue.shape} > property shape: [Shape](shape.md#max.graph.shape.Shape) Returns the shape of the BufferValue. ### `type` {#max.graph.BufferValue.type} > property type: [BufferType](type.md#max.graph.type.BufferType) Returns the type of the [`BufferValue`](#max.graph.BufferValue) as a `BufferType`. --- ## Graph ## `Graph` {#max.graph.Graph} > class max.graph.Graph(name, forward=None, input\_types=(), path=None, \*args, custom\_extensions=\[], kernel\_library=None, module=None, \*\*kwargs) Represents a single MAX graph. A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions. When you instantiate a graph, you must specify the input shapes as one or more `TensorType` values. Then, build a sequence of ops and set the graph output with [`output()`](#max.graph.Graph.output). For example:

```python
from dataclasses import dataclass

import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops


@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        return ops.matmul(x, weight_tensor) + bias_tensor


linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))],
)
```

You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine.
For more detail, see the tutorial about how to [build a graph with MAX Graph](/max/tutorials/get-started-with-max-graph-in-python). When creating a graph, a global sequence of chains is initialized and stored in Graph.\_current\_chain. Every side-effecting op (e.g. buffer\_load, store\_buffer, load\_slice\_buffer, store\_slice\_buffer) uses the current chain to perform the op and updates Graph.\_current\_chain with a new chain. Currently, the input/output chains for mutable ops can be used at most once. The goal of this design choice is to prevent data races.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – A name for the graph. * forward ([Callable](ops.md#max.graph.ops.Callable)\[..., None | [Value](Value.md#max.graph.Value)\[Any] | Iterable\[[Value](Value.md#max.graph.Value)\[Any]]] | None) – The sequence of graph ops for the forward pass (inference). * input\_types (Iterable\[[Type](type.md#max.graph.type.Type)\[Any]]) – The data type(s) for the input tensor(s). * path (Path | None) – The path to a saved graph (internal use only). * custom\_extensions (Iterable\[Path]) – The extensions to load for the model. Supports paths to `.mojopkg` or `.mojo` sources with custom ops. * kernel\_library ([KernelLibrary](KernelLibrary.md#max.graph.KernelLibrary) | None) – Optional pre-built kernel library to use. Defaults to `None` (a new library is created from `custom_extensions` if needed). * module (mlir.Module | None) – Optional existing MLIR module (internal use only). Defaults to `None`.
### `add_subgraph()` {#max.graph.Graph.add_subgraph} > add\_subgraph(name, forward=None, input\_types=(), path=None, custom\_extensions=\[], devices=\[]) Creates and adds a subgraph to the current graph. Creates a new [`Graph`](#max.graph.Graph) instance configured as a subgraph of the current graph. The subgraph inherits the parent graph’s module and symbolic parameters. A chain type is automatically appended to the input types to enable proper operation sequencing within the subgraph. The created subgraph is marked with special MLIR attributes to identify it as a subgraph and is registered in the parent graph’s subgraph registry.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name identifier for the subgraph. * forward ([Callable](ops.md#max.graph.ops.Callable)\[\[...], None | [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None) – The optional callable that defines the sequence of operations for the subgraph’s forward pass. If provided, the subgraph will be built immediately using this callable. * input\_types ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The data types for the subgraph’s input tensors. A chain type will be automatically added to these input types. * path (Path | None) – The optional path to a saved subgraph definition to load from disk instead of creating a new one. * custom\_extensions ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Path]) – The list of paths to custom operation libraries to load for the subgraph. Supports `.mojopkg` files and Mojo source directories. * devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](type.md#max.graph.type.DeviceRef)]) – The list of devices this subgraph is meant to use.
**Return type:**
[Graph](#max.graph.Graph)
### `add_weight()` {#max.graph.Graph.add_weight} > add\_weight(weight, force\_initial\_weight\_on\_host=True) Adds a weight to the graph. If the weight is already in the graph, the existing value is returned.
**Parameters:**
* weight ([Weight](Weight.md#max.graph.Weight)) – The weight to add to the graph. * force\_initial\_weight\_on\_host ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, forces weights to initially be allocated on the host before being moved to the indicated device. This is needed as a stopgap until we have a more fleshed-out ownership model of external constants.
**Returns:**
A [`TensorValue`](TensorValue.md#max.graph.TensorValue) that contains this weight.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a weight with the same name already exists in the graph.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `always_ready_chain` {#max.graph.Graph.always_ready_chain} > property always\_ready\_chain: \_ChainValue A graph-global, immutable chain that is always ready. Created once per graph and never advanced/merged by the graph itself. Use it for operations that are safe to schedule without threading per-device ordering (e.g., host→device transfers for staging). ### `current` {#max.graph.Graph.current} > current ### `device_chains` {#max.graph.Graph.device_chains} > device\_chains: \_DeviceChainMap ### `inputs` {#max.graph.Graph.inputs} > property inputs: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] The input values of the graph. ### `kernel_libraries_paths` {#max.graph.Graph.kernel_libraries_paths} > property kernel\_libraries\_paths: [list](https://docs.python.org/3/library/stdtypes.html#list)\[Path] Returns the list of extra kernel libraries paths for the custom ops. ### `output()` {#max.graph.Graph.output} > output(\*outputs) Sets the output nodes of the [`Graph`](#max.graph.Graph).
**Parameters:**
outputs ([Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
None
### `output_types` {#max.graph.Graph.output_types} > property output\_types: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] View of the types of the graph output terminator. --- ## KernelLibrary ## `KernelLibrary` {#max.graph.KernelLibrary} > class max.graph.KernelLibrary(paths=()) Manages custom kernel libraries and operations for a graph. A kernel library provides access to custom operations and kernels that can be loaded from various sources including Mojo binary packages (`.mojopkg`) and Mojo source directories. The library handles verification and registration of custom operations within the MLIR context.
**Parameters:**
paths (Iterable\[Path])
### `add_path()` {#max.graph.KernelLibrary.add_path} > add\_path(path) Adds a kernel library path to the analysis.
**Parameters:**
path (Path) – The `Path` to the kernel library to be added to the current analysis.
**Return type:**
None
### `library_paths()` {#max.graph.KernelLibrary.library_paths} > library\_paths() Returns the list of kernel library paths.
**Returns:**
A list of `Path` objects representing the currently loaded kernel library paths.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]
### `load_paths()` {#max.graph.KernelLibrary.load_paths} > load\_paths(custom\_extensions) Loads custom operations from provided library paths. Performs additional “smart” library loading logic for custom operation libraries in additional formats. The loading logic supports the following formats: * Compiled Mojo binary packages with `.mojopkg` extension * Mojo source directory with custom operations The loaded libraries are added to the current kernel library.
**Parameters:**
custom\_extensions ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Path]) – The file paths to the custom operation libraries.
**Return type:**
None
### `verify_custom_op()` {#max.graph.KernelLibrary.verify_custom_op} > verify\_custom\_op(custom\_op) Verifies that a custom operation is valid within the current context.
**Parameters:**
custom\_op (Operation) – The `mlir.Operation` to be verified against the current kernel library analysis.
**Return type:**
None
--- ## TensorValue ## `TensorValue` {#max.graph.TensorValue} > class max.graph.TensorValue(value) Bases: [`Value`](Value.md#max.graph.Value)\[`TensorType`] Represents a value-semantic tensor within a [`Graph`](Graph.md#max.graph.Graph). It provides various methods and properties to manipulate and query tensor attributes such as [`shape`](shape.md#module-max.graph.shape), data type ([`dtype`](#max.graph.TensorValue.dtype)), device placement ([`device`](#max.graph.TensorValue.device)), and more. The following example demonstrates how to create and manipulate tensor values in a graph:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("tensor_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor properties
    print(f"Shape: {tensor.shape}")      # Output: [2, 2]
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32

    # Perform operations on the tensor
    transposed = tensor.T
    doubled = tensor * 2

    print(f"Original shape: {tensor.shape}")        # Output: [2, 2]
    print(f"Transposed shape: {transposed.shape}")  # Output: [2, 2]
```
**Parameters:**
value (TensorValueLike)
### `T` {#max.graph.TensorValue.T} > property T: [TensorValue](#max.graph.TensorValue) Returns the transposed tensor. [`T`](#max.graph.TensorValue.T) is the shorthand notation for transposing. For more information, see [`transpose()`](#max.graph.TensorValue.transpose).
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with swapped dimensions.
### `argmax()` {#max.graph.TensorValue.argmax} > argmax(axis=-1) Reduces the tensor using an argmax operation along `axis`. When the result is ambiguous (i.e., there are multiple maxima), one index is selected arbitrarily.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())

with Graph("argmax_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Argmax along axis 1 (last dimension of each row)
    indices = x.argmax(axis=1)
    print(f"Input shape: {x.shape}")         # [2, 3]
    print(f"Argmax shape: {indices.shape}")  # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) of dtype `DType.int64` with the same rank as the input, and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `broadcast_to()` {#max.graph.TensorValue.broadcast_to} > broadcast\_to(shape) Broadcasts the tensor to a new shape. The following example demonstrates how to broadcast a tensor to a larger shape:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("broadcast_to_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Broadcast tensor to a 3x2x2 tensor (add a new dimension of size 3)
    broadcasted_tensor = tensor.broadcast_to((3, 2, 2))
    print(f"Original shape: {tensor.shape}")                 # Output: [2, 2]
    print(f"Broadcasted shape: {broadcasted_tensor.shape}")  # Output: [3, 2, 2]
```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – An iterable of integers or symbolic dimensions.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the broadcasted shape.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `cast()` {#max.graph.TensorValue.cast} > cast(dtype) Casts a symbolic tensor to a different data type. The following example demonstrates how to cast a tensor from one data type to another:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("cast_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Cast tensor to integer type
    casted_tensor = tensor.cast(DType.int32)
    print(f"Original dtype: {tensor.dtype}")      # Output: DType.float32
    print(f"Casted dtype: {casted_tensor.dtype}") # Output: DType.int32
```
**Parameters:**
dtype ([DType](../dtype.md#max.dtype.DType)) – The target data type (e.g., `DType.int32`, `DType.float64`).
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the casted data type.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `device` {#max.graph.TensorValue.device} > property device: [DeviceRef](type.md#max.graph.type.DeviceRef) Returns the device of the TensorValue. ### `dtype` {#max.graph.TensorValue.dtype} > property dtype: [DType](../dtype.md#max.dtype.DType) Returns the tensor data type. The following example demonstrates how to access the data type of a tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor data type
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32
```

### `flatten()` {#max.graph.TensorValue.flatten} > flatten(start\_dim=0, end\_dim=-1) Flattens the specified dims of a symbolic tensor. The number and order of the elements in the tensor are unchanged. All dimensions from `start_dim` to `end_dim` (inclusive) are merged into a single output dim. The following example demonstrates how to flatten a multi-dimensional tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("flatten_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Flatten the tensor to a 1D array
    flattened_tensor = tensor.flatten()
    print(f"Original shape: {tensor.shape}")             # Output: [2, 2]
    print(f"Flattened shape: {flattened_tensor.shape}")  # Output: [4]
```
**Parameters:**
* start\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – The starting dimension to flatten. Defaults to `0`. * end\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – The ending dimension to flatten. Defaults to `-1`.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the flattened dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
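The merging of `start_dim` through `end_dim` can be sketched with a NumPy analogy. NumPy is used here purely for illustration of the shape arithmetic; `flatten()` itself operates on symbolic MAX tensors:

```python
import numpy as np

# A 4-D array standing in for a symbolic tensor of shape (2, 3, 4, 5).
a = np.zeros((2, 3, 4, 5), dtype=np.float32)

# flatten(start_dim=1, end_dim=2) merges dims 1..2 (inclusive) into one,
# turning (2, 3, 4, 5) into (2, 12, 5). Element count and order are unchanged.
merged = a.reshape(a.shape[:1] + (-1,) + a.shape[3:])
print(merged.shape)  # (2, 12, 5)
```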
### `from_mlir()` {#max.graph.TensorValue.from_mlir} > classmethod from\_mlir(value) Creates a [`TensorValue`](#max.graph.TensorValue) from an MLIR tensor value.
**Parameters:**
value (Value\[TensorType]) – The MLIR tensor value to wrap.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `max()` {#max.graph.TensorValue.max} > max(axis=-1) Reduces the tensor using a max operation along `axis`.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())

with Graph("max_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Max along axis 1 (last dimension of each row)
    m = x.max(axis=1)
    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Max shape: {m.shape}")    # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `mean()` {#max.graph.TensorValue.mean} > mean(axis=-1) Reduces the tensor using a mean operation along `axis`.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())

with Graph("mean_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Mean along axis 1 (last dimension of each row)
    mu = x.mean(axis=1)
    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Mean shape: {mu.shape}")  # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `min()` {#max.graph.TensorValue.min} > min(axis=-1) Reduces the tensor using a min operation along `axis`.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())

with Graph("min_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Min along axis 1 (last dimension of each row)
    mn = x.min(axis=1)
    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Min shape: {mn.shape}")   # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `permute()` {#max.graph.TensorValue.permute} > permute(dims) Permutes the tensor’s dimensions based on provided indices.
**Parameters:**
dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – A list of integers specifying the new order of dimensions.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with permuted dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
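The index semantics match NumPy's `transpose` with an explicit axis order, which can serve as a quick sanity check on how `dims` reorders the shape (NumPy used only as an analogy):

```python
import numpy as np

a = np.zeros((2, 3, 4), dtype=np.float32)

# permute([2, 0, 1]) moves the old dim 2 to position 0, dim 0 to
# position 1, and dim 1 to position 2 -- the np.transpose convention.
p = np.transpose(a, (2, 0, 1))
print(p.shape)  # (4, 2, 3)
```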
### `print()` {#max.graph.TensorValue.print} > print(label='debug\_tensor') Prints detailed information about the tensor.
**Parameters:**
label ([str](https://docs.python.org/3/library/stdtypes.html#str)) – A string label for the printed output. Defaults to `debug_tensor`.
**Return type:**
None
### `rank` {#max.graph.TensorValue.rank} > property rank: [int](https://docs.python.org/3/library/functions.html#int) Returns the rank (number of dims) of the tensor. The following example demonstrates how to access the rank of a tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix (2-dimensional array)
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("rank_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor rank (number of dimensions)
    print(f"Rank: {tensor.rank}")  # Output: 2
```

### `rebind()` {#max.graph.TensorValue.rebind} > rebind(shape, message='') Rebinds the tensor to a new shape with error handling.
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as an iterable of integers or symbolic dimensions. * message ([str](https://docs.python.org/3/library/stdtypes.html#str)) – (optional) A message for logging or debugging.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the updated shape.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `reshape()` {#max.graph.TensorValue.reshape} > reshape(shape) Creates a new tensor with the same data but reshaped. The following example demonstrates how to reshape a tensor to change its dimensions:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("reshape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Reshape tensor to a 1x4 matrix
    reshaped_tensor = tensor.reshape((1, 4))
    print(f"Original shape: {tensor.shape}")           # Output: [2, 2]
    print(f"Reshaped shape: {reshaped_tensor.shape}")  # Output: [1, 4]
```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as an iterable of integers or symbolic dimensions.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the reshaped dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `shape` {#max.graph.TensorValue.shape} > property shape: [Shape](shape.md#max.graph.shape.Shape) Returns the shape of the [`TensorValue`](#max.graph.TensorValue). The following example demonstrates how to access the shape of a tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor shape
    print(f"Shape: {tensor.shape}")  # Shape: [Dim(2), Dim(2)]
```

### `stdev()` {#max.graph.TensorValue.stdev} > stdev(axis=-1) Reduces the tensor using a standard deviation operation along `axis`. The standard deviation is computed as the square root of the population variance along the specified axis.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())

with Graph("stdev_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Standard deviation along axis 1 (last dimension of each row)
    sd = x.stdev(axis=1)
    print(f"Input shape: {x.shape}")   # [2, 3]
    print(f"Stdev shape: {sd.shape}")  # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `to()` {#max.graph.TensorValue.to} > to(device) Transfers the tensor to a specified device without mutation. The following example demonstrates how to move a tensor from one device to another:

```python
import numpy as np

from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

with Graph("to_device_example") as graph:
    # Create a tensor on the default device
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Move the tensor to a GPU device
    gpu_tensor = tensor.to(DeviceRef.GPU())

    print(f"Original device: {tensor.device}")  # Output depends on default device
    print(f"New device: {gpu_tensor.device}")   # Output: gpu:0
```
**Parameters:**
device ([DeviceRef](type.md#max.graph.type.DeviceRef)) – A `DeviceRef` object specifying the target device.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) on the specified device.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `transpose()` {#max.graph.TensorValue.transpose} > transpose(dim\_1, dim\_2) Swaps two dimensions of the tensor. The following example demonstrates how to transpose a tensor by swapping its dimensions:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x3 matrix
matrix = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)

with Graph("transpose_demo") as graph:
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Transpose the tensor (swap dimensions 0 and 1)
    transposed_tensor = tensor.transpose(dim_1=0, dim_2=1)

    print(f"Original shape: {tensor.shape}")             # Output: [2, 3]
    print(f"Transposed shape: {transposed_tensor.shape}")  # Output: [3, 2]
```
**Parameters:**
* dim\_1 ([int](https://docs.python.org/3/library/functions.html#int)) – The first dimension to swap. * dim\_2 ([int](https://docs.python.org/3/library/functions.html#int)) – The second dimension to swap.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with swapped dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `type` {#max.graph.TensorValue.type} > property type: [TensorType](type.md#max.graph.type.TensorType) Returns the type of the [`TensorValue`](#max.graph.TensorValue) as a `TensorType`. ### `var()` {#max.graph.TensorValue.var} > var(axis=-1) Reduces the tensor using a variance operation along `axis`. The variance is computed as the mean of squared deviations from the mean (population variance, i.e., without Bessel’s correction) along the specified axis.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())

with Graph("var_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Variance along axis 1 (last dimension of each row)
    vr = x.var(axis=1)
    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Var shape: {vr.shape}")   # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
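Because `var()` computes the population variance (no Bessel's correction), a NumPy cross-check of the numerics uses the default `ddof=0` (NumPy serves only as a reference implementation here):

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]], dtype=np.float32)

# Population variance (ddof=0), keeping the reduced axis with size 1,
# mirroring the [2, 1] output shape described above. Each row gives 2/3.
population = np.var(x, axis=1, keepdims=True)

# Sample variance (ddof=1) would instead give 1.0 per row.
sample = np.var(x, axis=1, ddof=1, keepdims=True)
```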
--- ## Value ## `Value` {#max.graph.Value} > class max.graph.Value Represents a symbolic value within a Graph. A Value can represent the output of a node, the arguments of a Graph (as seen from within its body), and more generally any symbolic value available within the Graph. Other nodes receive Values as inputs to form a computation graph. A Value may also refer to an existing input or output of a node, and you can change it, such as by swapping in a new Value. Conceptually, think of a Value as an edge in the dataflow graph, with the other end being the user of that value. The following example shows how to work with Values in a graph to create a simple computation:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, Value, ops

with Graph("value_example") as graph:
    # Create input values
    a = ops.constant(np.array([1, 2, 3]), dtype=DType.float32, device=DeviceRef.CPU())
    b = ops.constant(np.array([4, 5, 6]), dtype=DType.float32, device=DeviceRef.CPU())

    # Use values to perform operations
    c = a + b  # c is a Value representing the addition

    # Demonstrate that the result is a Value
    print(f"Type of c: {type(c)}")
    print(f"Is c a Value? {isinstance(c, Value)}")
```

Similar to a regular variable, a Value has a data type. ### `buffer` {#max.graph.Value.buffer} > property buffer: [BufferValue](BufferValue.md#max.graph.BufferValue) Returns the Value as a [`BufferValue`](BufferValue.md#max.graph.BufferValue). Raises an exception if the Value is not a BufferValue. ### `from_mlir()` {#max.graph.Value.from_mlir} > classmethod from\_mlir(value) Creates a [`Value`](#max.graph.Value) from an MLIR value.
**Parameters:**
value (Value\[MlirType]) – The MLIR value to wrap.
**Return type:**
[Value](#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]
### `opaque` {#max.graph.Value.opaque} > property opaque: \_OpaqueValue Returns the Value as an `_OpaqueValue`. Raises an exception if the Value is not a \_OpaqueValue. ### `tensor` {#max.graph.Value.tensor} > property tensor: [TensorValue](TensorValue.md#max.graph.TensorValue) Returns the Value as a [`TensorValue`](TensorValue.md#max.graph.TensorValue). Raises an exception if the Value is not a TensorValue. ### `to_mlir()` {#max.graph.Value.to_mlir} > to\_mlir() Converts the [`Value`](#max.graph.Value) to an MLIR value.
**Return type:**
Value\[MlirType]
### `type` {#max.graph.Value.type} > property type: [Type](type.md#max.graph.type.Type)\[MlirType] Returns the type of the [`Value`](#max.graph.Value) as a `Type`. --- ## Weight ## `Weight` {#max.graph.Weight} > class max.graph.Weight(\*args, \*\*kwargs) Bases: [`TensorValue`](TensorValue.md#max.graph.TensorValue) Represents a value in a Graph that can be loaded at a later time. Weights can be initialized outside of a Graph and are lazily added to the parent graph when used. If there is no parent graph when a weight is used, an error will be raised. ### `align` {#max.graph.Weight.align} > align: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `device` {#max.graph.Weight.device} > property device: [DeviceRef](type.md#max.graph.type.DeviceRef) The device where the weight resides. ### `dtype` {#max.graph.Weight.dtype} > property dtype: [DType](../dtype.md#max.dtype.DType) The data type of the weight. ### `original_dtype_and_shape` {#max.graph.Weight.original_dtype_and_shape} > property original\_dtype\_and\_shape: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[DType](../dtype.md#max.dtype.DType), [Shape](shape.md#max.graph.shape.Shape)] The original dtype and shape of this weight. This property should be used to store the original weight’s dtype and shape when the quantization encoding forces the weight to be loaded as uint8. ### `quantization_encoding` {#max.graph.Weight.quantization_encoding} > quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) ### `shape` {#max.graph.Weight.shape} > property shape: [Shape](shape.md#max.graph.shape.Shape) The shape of the weight. For sharded weights, returns the shape of the shard. Otherwise, returns the original weight shape.
### `shard()` {#max.graph.Weight.shard} > shard(devices) Creates sharded views of this Weight across multiple devices. This Weight must have sharding\_strategy defined. The shard objects returned are also Weight objects, but cannot be sharded further.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded weights, one for each device.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Weight](#max.graph.Weight)]
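As a conceptual illustration (not the MAX API), sharding splits one weight into per-device views. The sketch below mimics a row-wise sharding strategy with plain Python lists; the device names are hypothetical:

```python
# Conceptual sketch (not the MAX API): row-wise sharding of a 6x4 weight
# across three hypothetical devices, analogous to what Weight.shard()
# yields once a row-wise sharding strategy has been set.
weight = [[r * 4 + c for c in range(4)] for r in range(6)]
devices = ["gpu:0", "gpu:1", "gpu:2"]

rows_per_shard = len(weight) // len(devices)
shards = {
    dev: weight[i * rows_per_shard:(i + 1) * rows_per_shard]
    for i, dev in enumerate(devices)
}
# Each "shard" is a 2x4 view; concatenating all shards recovers the weight.
```

Note how the shard shapes differ from the original weight shape, which is why the `shape` property above returns the shard's shape for sharded weights.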
### `shard_idx` {#max.graph.Weight.shard_idx} > shard\_idx: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `sharding_strategy` {#max.graph.Weight.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Gets the weight sharding strategy. --- ## dim Library for graph dimension types. ## `AlgebraicDim` {#max.graph.dim.AlgebraicDim} > class max.graph.dim.AlgebraicDim(value) An algebraic tensor dimension to enable expressions over symbolic dimensions. That is, any expression over a symbolic dimension returns `AlgebraicDim`. Furthermore, algebraic dimensions automatically simplify into a canonical form. The following example demonstrates how to create and use algebraic dimensions with symbolic values: ```python from max.graph import AlgebraicDim, Dim isinstance(Dim("batch") * 5, AlgebraicDim) # Returns True print(Dim("batch") * 5) # Outputs: batch * 5 -Dim("x") - 4 == -(Dim("x") + 4) # Returns True ```
**Parameters:**
attr (ParamOperatorAttr)
### `apply()` {#max.graph.dim.AlgebraicDim.apply} > classmethod apply(op, \*operands)
**Parameters:**
* op (POC) * operands ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)])
### `attr` {#max.graph.dim.AlgebraicDim.attr} > attr: ParamOperatorAttr ### `from_mlir()` {#max.graph.dim.AlgebraicDim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR Attribute object to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `parameters` {#max.graph.dim.AlgebraicDim.parameters} > property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)] Lists the symbolic dimension names on which this dim depends. ### `to_mlir()` {#max.graph.dim.AlgebraicDim.to_mlir} > to\_mlir() Creates an mlir.Attribute representing this dimension. This is used internally when constructing tensor MLIR types.
**Returns:**
An mlir.Attribute in the context representing the dimension.
**Return type:**
ParamOperatorAttr
## `Dim` {#max.graph.dim.Dim} > class max.graph.dim.Dim(value) A tensor dimension. Tensor dimensions can be one of three types: * **Static**: Known size * **Symbolic**: Unknown size but named * **Algebraic**: Unknown size defined by an algebraic expression In most cases, you don’t need to work with a `Dim` directly. Instead, use conversion constructors: ```python from max.dtype import DType from max.graph import Dim, TensorType, DeviceRef tensor_type = TensorType(DType.int64, ("batch", 10), device=DeviceRef.CPU()) ``` This creates a tensor type with two dimensions: * A symbolic “batch” dimension * A static dimension of size 10 For explicit dimension construction, use the following helpers: ```python from max.graph import AlgebraicDim, Dim, StaticDim, SymbolicDim some_dims = [ SymbolicDim("batch"), StaticDim(5), AlgebraicDim(Dim("batch") + 1), ] ``` Constraining tensor dimensions is one important way to improve model performance. If tensors have unknown dimensions, we can’t optimize them as aggressively. Symbolic dimensions allow the compiler to learn constraints on a specific dimension (e.g., if two inputs have the same batch dimension), but static dims are the easiest to optimize and therefore the easiest to create and work with.
**Parameters:**
value (DimLike)
### `from_mlir()` {#max.graph.dim.Dim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR Attribute object to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `parameters` {#max.graph.dim.Dim.parameters} > property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)] Lists the symbolic dimension names on which this dim depends. ### `to_mlir()` {#max.graph.dim.Dim.to_mlir} > to\_mlir() Creates an `mlir.Attribute` representing this dimension. This is used internally when constructing tensor MLIR types.
**Returns:**
An `mlir.Attribute` in the context representing the dimension.
**Return type:**
TypedAttr
## `StaticDim` {#max.graph.dim.StaticDim} > class max.graph.dim.StaticDim(value) A static tensor dimension. Static tensor dimensions will always have exactly the same value, and are key to good model performance. The following example shows how static dimensions can be created implicitly: ```python from max.graph import TensorType from max.dtype import DType tensor = TensorType(DType.int64, (4, 5)) ```
**Parameters:**
dim ([int](https://docs.python.org/3/library/functions.html#int))
### `dim` {#max.graph.dim.StaticDim.dim} > dim: [int](https://docs.python.org/3/library/functions.html#int) The size of the static dimension. ### `from_mlir()` {#max.graph.dim.StaticDim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR Attribute object to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `parameters` {#max.graph.dim.StaticDim.parameters} > property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)] Lists the symbolic dimension names on which this dim depends. ### `to_mlir()` {#max.graph.dim.StaticDim.to_mlir} > to\_mlir() Creates an `mlir.Attribute` representing this dimension. This is used internally when constructing tensor MLIR types.
**Returns:**
An `mlir.Attribute` in the context representing the dimension.
**Return type:**
IntegerAttr
## `SymbolicDim` {#max.graph.dim.SymbolicDim} > class max.graph.dim.SymbolicDim(value) A symbolic tensor dimension. Symbolic dimensions represent named dimensions in MO tensor types. Symbolic dimensions don’t have a static value, but they give dimensions readable names that make the model IR easier to understand, and they also allow users to hint to the compiler that two dimensions will have the same value, which can often allow important speedups. In tensor type notation: ```default !mo.tensor<[batch, x, 10], si32> ``` The first and second dimensions are named `batch` and `x` respectively. Creating a `SymbolicDim`: ```python from max.graph import SymbolicDim dim = SymbolicDim("name") ``` Using `SymbolicDim` in a `TensorType`: ```python from max.dtype import DType from max.graph import SymbolicDim, TensorType tensor_type = TensorType(DType.bool, (SymbolicDim("batch"), SymbolicDim("x"), 10)) ```
**Parameters:**
name ([str](https://docs.python.org/3/library/stdtypes.html#str))
### `from_mlir()` {#max.graph.dim.SymbolicDim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR Attribute object to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `name` {#max.graph.dim.SymbolicDim.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the dimension. ### `parameters` {#max.graph.dim.SymbolicDim.parameters} > property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)] Lists the symbolic dimension names on which this dim depends. ### `to_mlir()` {#max.graph.dim.SymbolicDim.to_mlir} > to\_mlir() Creates an `mlir.Attribute` representing this dimension. This is used internally when constructing tensor MLIR types.
**Returns:**
An `mlir.Attribute` in the context representing the dimension.
**Return type:**
ParamDeclRefAttr
--- ## graph (Graph) APIs to build inference graphs for MAX Engine with Python. ## Classes * [`BufferValue`](/max/api/python/graph/BufferValue): Represents a mutable semantic tensor within a Graph. * [`Graph`](/max/api/python/graph/Graph): Represents a graph for MAX Engine. * [`KernelLibrary`](/max/api/python/graph/KernelLibrary): Represents a library with custom ops. * [`TensorValue`](/max/api/python/graph/TensorValue): Represents a value semantic tensor within a Graph. * [`Value`](/max/api/python/graph/Value): Represents a symbolic value within a Graph. * [`Weight`](/max/api/python/graph/Weight): Represents a weight value in a graph. ## Modules * [`dim`](/max/api/python/graph/dim): APIs for graph value tensor dimensions. * [`ops`](/max/api/python/graph/ops): Ops you can add when staging a graph. * [`quantization`](/max/api/python/graph/quantization): APIs to quantize graph tensors. * [`shape`](/max/api/python/graph/shape): APIs for graph value tensor shapes. * [`type`](/max/api/python/graph/type): APIs for graph value types. * [`weights`](/max/api/python/graph/weights): APIs for loading weights into a graph. --- ## ops Implements operations used when staging a graph. This module provides operations for building computational graphs in MAX. These operations create, transform, and manipulate tensor values within the graph. You can also use functions in [Graph](/max/api/python/graph/Graph) to add constant values to your graph with operations like [constant()](/max/api/python/graph/ops#max.graph.ops.constant). The [TensorValue](/max/api/python/graph/TensorValue/) type (returned by most operations) implements various dunder methods to support operations between TensorValues, such as + for addition, \* for multiplication, and @ for matrix multiplication. It also provides convenience methods like [reshape()](/max/api/python/graph/TensorValue/#max.graph.TensorValue.reshape) and [flatten()](/max/api/python/graph/TensorValue/#max.graph.TensorValue.flatten). 
### `Callable` {#max.graph.ops.Callable} > class max.graph.ops.Callable ### `DeviceRef` {#max.graph.ops.DeviceRef} > class max.graph.ops.DeviceRef(device\_type, id=0) A symbolic device representation. A DeviceRef consists of a DeviceKind and an id, and is a direct representation of the device attribute in MLIR. The following example demonstrates how to create and use device references: ```python from max.graph import DeviceRef gpu_device = DeviceRef.GPU() print(gpu_device) # Outputs: gpu:0 # Create a CPU device with a specific id cpu_device = DeviceRef.CPU(id=1) print(cpu_device) # Outputs: cpu:1 ```
**Parameters:**
* device\_type ([DeviceKind](type.md#max.graph.type.DeviceKind)) * id ([int](https://docs.python.org/3/library/functions.html#int))
#### `CPU()` {#max.graph.ops.DeviceRef.CPU} > static CPU(id=0) Static method that creates a CPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[DeviceRef](type.md#max.graph.type.DeviceRef)
#### `GPU()` {#max.graph.ops.DeviceRef.GPU} > static GPU(id=0) Static method that creates a GPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[DeviceRef](type.md#max.graph.type.DeviceRef)
#### `device_type` {#max.graph.ops.DeviceRef.device_type} > device\_type: [DeviceKind](type.md#max.graph.type.DeviceKind) #### `from_device()` {#max.graph.ops.DeviceRef.from_device} > static from\_device(device)
**Parameters:**
device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef))
**Return type:**
[DeviceRef](type.md#max.graph.type.DeviceRef)
#### `from_mlir()` {#max.graph.ops.DeviceRef.from_mlir} > static from\_mlir(attr) Returns a device from an MLIR attribute.
**Parameters:**
attr (DeviceRefAttr)
**Return type:**
[DeviceRef](type.md#max.graph.type.DeviceRef)
#### `id` {#max.graph.ops.DeviceRef.id} > id: [int](https://docs.python.org/3/library/functions.html#int) #### `is_cpu()` {#max.graph.ops.DeviceRef.is_cpu} > is\_cpu() Returns true if the device is a CPU device.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
#### `is_gpu()` {#max.graph.ops.DeviceRef.is_gpu} > is\_gpu() Returns true if the device is a GPU device.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
#### `to_device()` {#max.graph.ops.DeviceRef.to_device} > to\_device() Converts the device reference to a concrete driver Device.
**Return type:**
[Device](../driver.md#max.driver.Device)
#### `to_mlir()` {#max.graph.ops.DeviceRef.to_mlir} > to\_mlir() Returns an MLIR attribute representing the device.
**Return type:**
DeviceRefAttr
### `InterpolationMode` {#max.graph.ops.InterpolationMode} > class max.graph.ops.InterpolationMode(value, names=None, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Interpolation modes for image resize operations. This enum defines the available interpolation methods that can be used when resizing tensors. Currently only BICUBIC is implemented, with BILINEAR and NEAREST planned for future support. #### `BICUBIC` {#max.graph.ops.InterpolationMode.BICUBIC} > BICUBIC = 'bicubic' #### `BILINEAR` {#max.graph.ops.InterpolationMode.BILINEAR} > BILINEAR = 'bilinear' #### `NEAREST` {#max.graph.ops.InterpolationMode.NEAREST} > NEAREST = 'nearest' ### `TensorType` {#max.graph.ops.TensorType} > class max.graph.ops.TensorType(dtype, shape, device, \_layout=None) A symbolic [`TensorType`](#max.graph.ops.TensorType). This is not an eager tensor type! This contains no actual data, but instead represents the type of a value at some point in time during model execution. Most internal values in a model will be tensors. This type represents their element type (`dtype`) and dimensions (`dims`) at a specific point during model computation. It allows us to do some optimistic optimizations and shape inference during graph construction, and to provide more detailed shape information to the compiler for further optimization passes. The following example shows how to create a tensor type with static dimensions and access its properties: ```python from max.graph import TensorType from max.dtype import DType # Create a tensor type with float32 elements and static dimensions 2x3 tensor_type = TensorType(DType.float32, (2, 3)) print(tensor_type.dtype) # Outputs: DType.float32 print(tensor_type.shape) # Outputs: [2, 3] ``` It can also represent a fully dynamic rank tensor. The presence of dynamic rank tensors in a graph will often degrade performance dramatically and prevent many classes of optimizations.
An optional device (`device`) can also be provided to indicate the explicit device the tensor is associated with.
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType)) * shape ([Shape](shape.md#max.graph.shape.Shape)) * device ([DeviceRef](type.md#max.graph.type.DeviceRef)) * \_layout ([FilterLayout](type.md#max.graph.type.FilterLayout) | None)
#### `as_buffer()` {#max.graph.ops.TensorType.as_buffer} > as\_buffer() Returns the analogous buffer type.
**Return type:**
[BufferType](type.md#max.graph.type.BufferType)
#### `from_mlir()` {#max.graph.ops.TensorType.from_mlir} > classmethod from\_mlir(type) Constructs a tensor type from an MLIR type.
**Parameters:**
type (TensorType) – The MLIR Type object to parse into a tensor type.
**Returns:**
The tensor type represented by the MLIR Type value.
**Return type:**
[TensorType](type.md#max.graph.type.TensorType)
#### `to_mlir()` {#max.graph.ops.TensorType.to_mlir} > to\_mlir() Converts to an `mlir.Type` instance.
**Returns:**
An `mlir.Type` in the specified Context.
**Return type:**
TensorType
### `abs()` {#max.graph.ops.abs} > max.graph.ops.abs(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `acos()` {#max.graph.ops.acos} > max.graph.ops.acos(x) Computes the arccosine (inverse cosine) of the input tensor. Returns values in the range \[0, π] for inputs in \[-1, 1]. Creates a new op node to compute the elementwise arccosine of a symbolic tensor and adds it to the graph, returning the symbolic result. ```python from max.dtype import DType from max.graph import DeviceRef, Graph, TensorType, ops def acos_graph(): input_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU()) with Graph("acos_graph", input_types=(input_type,)) as graph: x = graph.inputs[0] out = ops.acos(x) graph.output(out) ```
**Parameters:**
x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – Input tensor with values in \[-1, 1]. If values are outside this domain, they will be clamped to the valid range.
**Returns:**
Arccosine of the input in radians \[0, π]. The result will have the same dtype and shape as the input.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
**Raises:**
* Error – If the symbol doesn’t represent a tensor value. * Error – If the input is not a floating-point dtype.
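As a pure-Python sketch of the clamping behavior documented above (not the MAX API), out-of-domain inputs are clamped to \[-1, 1] before the arccosine is taken:

```python
import math

# Sketch of the documented acos semantics: inputs outside [-1, 1] are
# clamped to the valid domain, so they don't produce a math error.
def acos_clamped(v):
    return math.acos(max(-1.0, min(1.0, v)))

values = [-2.0, -1.0, 0.0, 1.0, 1.5]
results = [acos_clamped(v) for v in values]
# acos(-1) = pi, acos(0) = pi/2, acos(1) = 0; -2.0 and 1.5 clamp to the ends.
```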
### `add()` {#max.graph.ops.add} > max.graph.ops.add(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `allgather()` {#max.graph.ops.allgather} > max.graph.ops.allgather(inputs, signal\_buffers, axis=0) Collective allgather operation. This is a collective op that takes input tensors from different devices and produces an output tensor on each of those devices. In particular, it gathers the inputs across the devices and concatenates them along the specified dimension. The result is then broadcast back to the same devices that the inputs came from.
**Parameters:**
* inputs ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – The input tensors to gather. * signal\_buffers ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue]) – Device buffer values used for synchronization. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension to concatenate the input tensors. Defaults to 0.
**Returns:**
A list of outputs, one per device, each holding the gathered result. Each output tensor contains the concatenation of all inputs along the specified dimension.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
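The collective semantics can be sketched with plain Python containers (not the MAX API; the device names are hypothetical): each "device" contributes a slice, and after the gather every device holds the same concatenation:

```python
# Sketch of allgather semantics: each device holds one chunk, and after
# the collective every device holds the concatenation of all chunks
# along the gather axis (axis 0 here).
per_device_inputs = {
    "gpu:0": [[0, 1]],
    "gpu:1": [[2, 3]],
    "gpu:2": [[4, 5]],
}

# Concatenate all chunks along axis 0, in device order.
gathered = [row for chunk in per_device_inputs.values() for row in chunk]

# Broadcast the result back: every device gets the full gathered tensor.
outputs = {dev: gathered for dev in per_device_inputs}
```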
### `argmax()` {#max.graph.ops.argmax} > max.graph.ops.argmax(x, axis=-1) Reduces a symbolic tensor using an argmax operation. If all elements of the tensor are identical, this returns the index of the first element on CPU, but may return an arbitrary index on GPU.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the argmax operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
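The shape rule above can be illustrated in pure Python (not the MAX API): reducing a (2, 3) input along the last axis keeps the rank and collapses that axis to size 1, and ties resolve to the first index, matching the documented CPU behavior:

```python
# Sketch of the documented argmax behavior: the result has the same rank
# as the input, with the reduced axis collapsed to size 1.
x = [[3, 9, 1],
     [7, 2, 8]]

# argmax along axis=-1: index of the max in each row, kept as a length-1 list.
result = [[max(range(len(row)), key=row.__getitem__)] for row in x]
# Shape goes from (2, 3) to (2, 1).

# With all-identical elements, the first index wins (documented CPU behavior).
tied = [5, 5, 5]
first_idx = max(range(len(tied)), key=tied.__getitem__)
```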
### `argmin()` {#max.graph.ops.argmin} > max.graph.ops.argmin(x, axis=-1) Reduces a symbolic tensor using an argmin operation. If all elements of the tensor are identical, this returns the index of the first element on CPU, but may return an arbitrary index on GPU.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the argmin operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `argsort()` {#max.graph.ops.argsort} > max.graph.ops.argsort(x, ascending=True) Returns the indices that would sort a tensor. This function returns the indices that would sort the input tensor along its first dimension. The returned indices are of type int64.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – Input tensor to be sorted. * ascending ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True (default), sort in ascending order. If False, sort in descending order.
**Returns:**
A tensor of indices of the same shape as the input tensor.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
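The index semantics can be sketched in pure Python (not the MAX API): the returned indices reorder the values along the first dimension into sorted order:

```python
# Sketch of argsort semantics: indices that would sort the values along
# the first dimension, ascending or descending.
x = [3.0, 1.0, 2.0]

ascending = sorted(range(len(x)), key=x.__getitem__)
descending = sorted(range(len(x)), key=x.__getitem__, reverse=True)

# Applying the ascending indices reorders x into sorted order.
sorted_x = [x[i] for i in ascending]
```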
### `as_interleaved_complex()` {#max.graph.ops.as_interleaved_complex} > max.graph.ops.as\_interleaved\_complex(x) Reshapes the input symbolic tensor as complex from alternating (real, imag).
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor representing complex numbers as alternating pairs of (real, imag) real-valued numbers. Its last dimension must have an even size.
**Returns:**
A symbolic tensor representing the complex-valued tensor, but with the values pulled out as complex numbers. The result has the same dimensions for all dimensions except the last dimension, which is halved, and then a final dimension of size 2 representing the complex value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
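The reshape described above can be sketched with a plain Python list (not the MAX API): an interleaved last dimension of even size 2k becomes a pair of dimensions (k, 2):

```python
# Sketch of the documented reshape: an interleaved (real, imag) last
# dimension of even size 2k becomes (..., k, 2).
interleaved = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # three complex numbers

assert len(interleaved) % 2 == 0  # the last dimension must be even
paired = [interleaved[i:i + 2] for i in range(0, len(interleaved), 2)]
# paired has half as many entries, each a (real, imag) pair of size 2.
```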
### `assert_same_device()` {#max.graph.ops.assert_same_device} > max.graph.ops.assert\_same\_device(\*values, \*\*named\_values)
**Parameters:**
* values ([TensorValue](TensorValue.md#max.graph.TensorValue) | [BufferValue](BufferValue.md#max.graph.BufferValue)) * named\_values ([TensorValue](TensorValue.md#max.graph.TensorValue) | [BufferValue](BufferValue.md#max.graph.BufferValue))
**Return type:**
None
### `atanh()` {#max.graph.ops.atanh} > max.graph.ops.atanh(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `avg_pool2d()` {#max.graph.ops.avg_pool2d} > max.graph.ops.avg\_pool2d(input, kernel\_size, stride=1, dilation=1, padding=0, ceil\_mode=False, count\_boundary=True) Performs a 2D average pooling operation on the input tensor. This function applies a 2D average pooling operation to an input tensor of shape \[N, H, W, C]. The pooling operation slides a window of size kernel\_size over the input tensor and computes the average value within each window.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to perform the pooling operation on. * kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The size of the sliding blocks. * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding blocks in the input dimension. 
* dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel elements. * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – Zero-padding added on both sides of the input. * ceil\_mode ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, use ceil instead of floor to compute the output shape. * count\_boundary ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, count the padding elements when computing the average.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
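The pooling arithmetic can be sketched in pure Python on a single H x W channel (not the MAX API; this assumes stride 1, no padding, no dilation):

```python
# Minimal sketch of average pooling on one 3x3 channel: a 2x2 window
# slides with stride 1 and no padding, and each output element is the
# mean of the values inside the window.
x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]
kh = kw = 2

out = [
    [
        sum(x[i + di][j + dj] for di in range(kh) for dj in range(kw)) / (kh * kw)
        for j in range(len(x[0]) - kw + 1)
    ]
    for i in range(len(x) - kh + 1)
]
# A 3x3 input with a 2x2 window yields a 2x2 output.
```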
### `band_part()` {#max.graph.ops.band_part} > max.graph.ops.band\_part(x, num\_lower=None, num\_upper=None, exclude=False) Masks out everything except a diagonal band of an input matrix. Copies a tensor, setting everything outside the central diagonal band of the matrices to zero, where all but the last two axes are effectively batches, and the last two axes define sub matrices. Assuming the input has dimensions \[I, J, …, M, N], the output tensor has the same shape as the input, and the values are given by ```python out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n] ``` with the indicator function: ```python in_band(m, n) = (num_lower is None || (m - n) <= num_lower) && (num_upper is None || (n - m) <= num_upper) ```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input to mask out. * num\_lower ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of diagonal bands to include below the central diagonal. If None, include the entire lower triangle. * num\_upper ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of diagonal bands to include above the central diagonal. If None, include the entire upper triangle. * exclude ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, invert the selection of elements to mask. Elements in the band are set to zero.
**Returns:**
A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input tensor rank is less than 2, or if num\_lower/num\_upper are out of bounds for statically known dimensions.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
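The band predicate above can be sketched in plain NumPy (an illustration of the documented semantics, not the MAX implementation). Note that `band_part(x, num_lower=None, num_upper=0)` keeps only the lower triangle, which is the typical way to build a causal attention mask:

```python
import numpy as np

def band_part_ref(x, num_lower=None, num_upper=None, exclude=False):
    """NumPy sketch of the documented band_part() semantics."""
    m, n = x.shape[-2:]
    rows = np.arange(m)[:, None]
    cols = np.arange(n)[None, :]
    in_band = np.ones((m, n), dtype=bool)
    if num_lower is not None:
        in_band &= (rows - cols) <= num_lower
    if num_upper is not None:
        in_band &= (cols - rows) <= num_upper
    if exclude:
        in_band = ~in_band
    return np.where(in_band, x, 0)

x = np.ones((3, 3))
# Lower triangle only: a causal-style mask.
print(band_part_ref(x, num_lower=None, num_upper=0))
```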
### `broadcast_to()` {#max.graph.ops.broadcast_to} > max.graph.ops.broadcast\_to(x, shape, out\_dims=None) Broadcasts a symbolic tensor. Broadcasts the input tensor to the specified shape. Dimensions in the input must be one or match the target dimension.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input symbolic tensor to broadcast. This tensor may not contain any dynamic dimensions. * shape ([TensorValue](TensorValue.md#max.graph.TensorValue) | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as a list of dimensions. Dynamic dimensions are not allowed. * out\_dims ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) – Output dims used only for tensor-valued shape.
**Returns:**
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as `shape`.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – if a tensor-valued shape is passed without out\_dims.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
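The broadcasting rule stated above (each input dimension must be 1 or match the target) follows the same convention as NumPy's `broadcast_to`, which can be used to preview the resulting shape. This is a NumPy illustration of the rule, not the MAX op itself:

```python
import numpy as np

a = np.array([[1.0], [2.0]])          # shape (2, 1)
b = np.broadcast_to(a, (2, 3))        # the size-1 dimension stretches to 3
print(b.shape)                        # (2, 3)

# A dimension that is neither 1 nor equal to the target is an error:
try:
    np.broadcast_to(a, (3, 3))
except ValueError as e:
    print("broadcast failed:", e)
```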
### `buffer_create()` {#max.graph.ops.buffer_create} > max.graph.ops.buffer\_create(type) Creates a buffer of the given type.
**Parameters:**
type ([BufferType](type.md#max.graph.type.BufferType)) – The type of the resulting BufferValue
**Returns:**
A new BufferValue of the requested type.
**Return type:**
[BufferValue](BufferValue.md#max.graph.BufferValue)
### `buffer_load()` {#max.graph.ops.buffer_load} > max.graph.ops.buffer\_load(x) Loads the input buffer into a tensor. It copies the contents of the in-place mutable buffer into an immutable tensor graph value. This is semantically equivalent to a copy from the mutable buffer x to the immutable value-semantic tensor output.
**Parameters:**
x ([BufferValue](BufferValue.md#max.graph.BufferValue)) – The buffer to be loaded to a tensor.
**Returns:**
A tensor graph value representing a copy of the buffer loaded.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `buffer_store()` {#max.graph.ops.buffer_store} > max.graph.ops.buffer\_store(destination, source) Stores the input tensor into the in-out buffer. It stores the immutable input tensor source in the mutable buffer destination. This is semantically equivalent to a copy from the source tensor to the destination buffer.
**Parameters:**
* destination ([BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue) – The buffer to store the tensor in.
* source (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The tensor to be stored in the buffer.
**Return type:**
None
### `buffer_store_slice()` {#max.graph.ops.buffer_store_slice} > max.graph.ops.buffer\_store\_slice(destination, source, indices) Stores the input tensor into a slice of the input buffer. It stores the immutable input tensor source in the mutable tensor destination. This is semantically equivalent to a copy from the source tensor to a slice of the destination buffer at the index specified by indices.
**Parameters:**
* destination ([BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue) – The buffer to store the tensor in. * source (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The tensor to be stored in the buffer. * indices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorValue](TensorValue.md#max.graph.TensorValue) | [int](https://docs.python.org/3/library/functions.html#int) | [slice](https://docs.python.org/3/library/functions.html#slice) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[slice](https://docs.python.org/3/library/functions.html#slice), [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | builtins.ellipsis]) – The index in the buffer where the tensor should be stored
**Return type:**
None
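The copy semantics of `buffer_store_slice` correspond to an in-place slice assignment on the mutable buffer. A NumPy analogue (illustrative only; the real op operates on graph `BufferValue`s, and the slice shown is a hypothetical choice of `indices`):

```python
import numpy as np

destination = np.zeros((4, 4))
source = np.array([[1.0, 2.0],
                   [3.0, 4.0]])

# Analogue of buffer_store_slice(destination, source, indices), with
# indices selecting rows 1:3 and columns 0:2 of the buffer:
destination[1:3, 0:2] = source
print(destination)
```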
### `call()` {#max.graph.ops.call} > max.graph.ops.call(graph, \*args, prefix='') Call a graph with the provided arguments and return its results. This function invokes a previously defined graph, passing in the provided arguments and the current chain value, and returns the results. The body of the graph is ultimately inlined into the caller, so the chain value is only used for serialization if the subgraph’s body contains an operation that makes use of it in the first place. The current advantage of using subgraphs is that it offers a way to improve compile times for operations that are used repeatedly in a model. As a secondary benefit, it also makes the IR more readable by allowing control flow to be expressed in a more natural way.
**Parameters:**
* graph ([Graph](Graph.md#max.graph.Graph)) – The graph to call * \*args ([Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Arguments to pass to the called graph * prefix ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Prefix to add to the names of any weights in the subgraph
**Returns:**
Either a single Value or a list of Values representing the graph outputs (excluding the chain value which is handled internally)
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
### `cast()` {#max.graph.ops.cast} > max.graph.ops.cast(x, dtype) Casts a symbolic tensor to a different data type.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input tensor to cast. * dtype ([DType](../dtype.md#max.dtype.DType)) – The target dtype to which the tensor is cast.
**Returns:**
A new symbolic tensor with the same shape as the input and the specified dtype.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `chunk()` {#max.graph.ops.chunk}

> max.graph.ops.chunk(x, chunks, axis=0)

Chunk the tensor into an exact number of chunks along the specified axis.

**Example:**

```pycon
>>> a = TensorValue([1, 2, 3, 4])
>>> chunk(a, 2, 0)
[TensorValue([1, 2]), TensorValue([3, 4])]
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The tensor to chunk. * chunks ([int](https://docs.python.org/3/library/functions.html#int)) – The number of chunks to split the tensor into. chunks must statically evenly divide x.shape\[axis]. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to split the tensor along.
**Returns:**
A list of chunk tensors.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
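NumPy's `split` has the same even-division requirement and can illustrate the behavior (a NumPy analogue, not the MAX op):

```python
import numpy as np

x = np.array([1, 2, 3, 4])
parts = np.split(x, 2, axis=0)        # analogous to ops.chunk(x, 2, axis=0)
print([p.tolist() for p in parts])    # [[1, 2], [3, 4]]

# As with chunk(), the chunk count must evenly divide the axis size:
try:
    np.split(np.array([1, 2, 3, 4, 5]), 2)
except ValueError as e:
    print("uneven split rejected:", e)
```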
### `concat()` {#max.graph.ops.concat}

> max.graph.ops.concat(original\_vals, axis=0)

Concatenates a list of symbolic tensors along an axis. Joins multiple tensors along a specified dimension. This operation requires the functional API since it operates on multiple tensors. All input tensors must have the same rank and the same size in all dimensions except the concatenation axis.

```python
import max.functional as F
from max.tensor import Tensor

# Create two 2x2 matrices
a = Tensor.constant([[1, 2], [3, 4]])
b = Tensor.constant([[5, 6], [7, 8]])

# Concatenate along axis 0 (rows) - stacks vertically
vertical = F.concat([a, b], axis=0)
print(f"Concatenated along axis 0: {vertical.shape}")
# Output: Concatenated along axis 0: [Dim(4), Dim(2)]
print(vertical)
# [[1, 2],
#  [3, 4],
#  [5, 6],
#  [7, 8]]

# Concatenate along axis 1 (columns) - joins horizontally
horizontal = F.concat([a, b], axis=1)
print(f"Concatenated along axis 1: {horizontal.shape}")
# Output: Concatenated along axis 1: [Dim(2), Dim(4)]
print(horizontal)
# [[1, 2, 5, 6],
#  [3, 4, 7, 8]]
```
**Parameters:**
* original\_vals ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – The list of symbolic tensor values to concatenate. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension other than `axis`. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape. For instance, `concat(vs, -1)` will concatenate along the last dimension.
**Returns:**
A new symbolic tensor representing the concatenation result. It will have the same rank as each input tensor, and its dimensions will be the same as each input tensor’s for each dimension other than axis, which will have size equal to the sum of all tensor’s size for that dimension.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `cond()` {#max.graph.ops.cond}

> max.graph.ops.cond(pred, out\_types, then\_fn, else\_fn)

Conditionally execute one of two branches based on a boolean predicate. Both branches must return the same number and types of values as specified in `out_types`. Buffer mutations in branches are tracked automatically through the chain mechanism.

Examples:

1. Basic conditional with return values:

   ```python
   def then_fn():
       return ops.constant(1, DType.int32, device=DeviceRef.CPU())

   def else_fn():
       return ops.constant(0, DType.int32, device=DeviceRef.CPU())

   result = ops.cond(
       pred,
       [TensorType(DType.int32, [], device=device)],
       then_fn,
       else_fn
   )
   ```

2. Conditional with buffer mutations:

   ```python
   def then_fn():
       ops.inplace_custom("increment", device=buffer.device, values=[buffer])

   def else_fn():
       ops.inplace_custom("decrement", device=buffer.device, values=[buffer])

   ops.cond(pred, None, then_fn, else_fn)
   ```

* pred – Boolean scalar tensor of type `DType.bool` determining branch execution.
* out\_types – Expected output types for both branches. Use [`None`](https://docs.python.org/3/library/constants.html#None) for branches that don't return values.
* then\_fn – Callable executed when `pred` is True. Must return values matching `out_types` if `out_types` is not `None`.
* else\_fn – Callable executed when `pred` is False. Must return values matching `out_types` if `out_types` is not `None`.
**Returns:**
List of output values from executed branch. Returns empty list when `out_types` is [`None`](https://docs.python.org/3/library/constants.html#None)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If branches return different numbers of results or result types don’t match `out_types`
**Parameters:**
* pred (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * out\_types ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) * then\_fn ([Callable](#max.graph.ops.Callable)\[\[], [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None]) * else\_fn ([Callable](#max.graph.ops.Callable)\[\[], [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
:::note Note Buffer operations in branches automatically update the global chain state to maintain mutation ordering constraints ::: ### `constant()` {#max.graph.ops.constant} > max.graph.ops.constant(value, dtype=None, device=None) Adds a node representing a constant operation. The value of this constant will have the type TensorType with the same shape as value. If value is a scalar type, it will create a TensorType with 0 dimensions. The constant will be loaded with the specified dtype. If the constant does not fit within the specified dtype, an error is raised. Warning: Loading the constant could result in precision loss. For example, loading 16777217 as a float32 will result in 16777216.0.
**Parameters:**
* value ([DLPackArray](../driver.md#max.driver.DLPackArray) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Number | NestedArray]] | [float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The constant’s value. * dtype ([DType](../dtype.md#max.dtype.DType) | None) – The constant tensor’s element type. * device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef) | None) – The device the constant lives on.
**Returns:**
A graph value containing the constant data as an attribute.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
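The precision-loss warning can be checked directly: 16777217 (2**24 + 1) is the smallest positive integer that float32 cannot represent exactly, because its significand holds only 24 bits:

```python
import numpy as np

v = np.float32(16_777_217)           # 2**24 + 1 rounds to the nearest representable value
print(v)                             # 16777216.0
print(np.float32(16_777_216) == v)   # True: the two distinct integers collapse
```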
### `constant_external()` {#max.graph.ops.constant_external} > max.graph.ops.constant\_external(name, type) Registers an external constant (weight) in the graph of a given type. Two external constants with the same name and type refer to the same weight. Two external constants with the same name and different types are incompatible and will fail compilation.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name of the external constant. This should be the fully-qualified weight name and must be unique. * type ([TensorType](type.md#max.graph.type.TensorType)) – The type of the constant value.
**Returns:**
A tensor value of the specified type, representing the weight value associated with the name at compile time.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `conv2d()` {#max.graph.ops.conv2d}

> max.graph.ops.conv2d(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF)

Computes the 2-D convolution product of the input with the given filter, bias, strides, dilations, paddings, and groups. The op supports 2-D convolution, with the following layout assumptions:

* input x has NHWC layout, i.e., (batch\_size, height, width, in\_channels)
* filter has layout RSCF, i.e., (height, width, in\_channels / num\_groups, out\_channels)
* bias has shape (out\_channels,)

The padding values are expected to take the form (pad\_dim1\_before, pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after, …) and represent padding 0's before and after the indicated spatial dimensions in input. In 2-D convolution, dim1 here represents H and dim2 represents W. In Python-like syntax, padding a 2x3 spatial input with \[0, 1, 2, 1] would yield:

```python
input = [
    [1, 2, 3],
    [4, 5, 6]
]  # Shape is 2x3

padded_input = [
    [0, 0, 1, 2, 3, 0],
    [0, 0, 4, 5, 6, 0],
    [0, 0, 0, 0, 0, 0]
]  # Shape is 3x6
```

This op currently only supports strides and padding on the input.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – An NHWC input tensor to perform the convolution upon.
* filter (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The convolution filter in RSCF layout: (height, width, in\_channels / num\_groups, out\_channels).
* stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the convolution operation.
* dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel points.
* padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The amount of padding applied to the input.
* groups ([int](https://docs.python.org/3/library/functions.html#int)) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
* bias (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None) – tensor of shape (out\_channels,)
* input\_layout ([ConvInputLayout](type.md#max.graph.type.ConvInputLayout))
* filter\_layout ([FilterLayout](type.md#max.graph.type.FilterLayout))
**Returns:**
A symbolic tensor value with the convolution applied.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
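The output spatial size follows standard convolution arithmetic. The helper below is a sketch of that formula under the assumption that conv2d uses the conventional stride/dilation/padding arithmetic; it is not taken from the MAX source:

```python
def conv2d_out_dim(size, kernel, stride=1, pad_before=0, pad_after=0, dilation=1):
    """Standard convolution output-size arithmetic for one spatial dimension."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + pad_before + pad_after - effective_kernel) // stride + 1

# Height of a 5x5 input convolved with a 3x3 filter, stride 1, no padding:
print(conv2d_out_dim(5, 3))                                        # 3
# With stride 2 and one pixel of padding on each side:
print(conv2d_out_dim(5, 3, stride=2, pad_before=1, pad_after=1))   # 3
```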
### `conv2d_transpose()` {#max.graph.ops.conv2d_transpose}

> max.graph.ops.conv2d\_transpose(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), output\_paddings=(0, 0), bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF)

Computes the 2-D deconvolution of the input with the given filter, strides, dilations, paddings, and groups. The op supports the transpose (gradient) of convolution, with the following layout assumptions (note the out\_channels is w\.r.t. the original convolution):

* input x has NHWC layout, i.e., (batch\_size, height, width, in\_channels)
* filter has layout RSCF, i.e., (kernel\_height, kernel\_width, out\_channels, in\_channels)
* bias has shape (out\_channels,)

This op effectively computes the gradient of a convolution with respect to its input (as if the original convolution operation had the same filter and hyperparameters as this op).

The padding values are expected to take the form (pad\_dim1\_before, pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after, …). For the transposed convolution, these values are cropped from (rather than added to) the indicated spatial dimensions of the output. In 2-D ConvTranspose, dim1 here represents H\_out and dim2 represents W\_out. In Python-like syntax, applying padding \[0, 1, 2, 1] to a 2x4 spatial output would yield:

```python
output = [
    [1, 2, 3, 4],
    [5, 6, 7, 8]
]  # Shape is 2x4

padded_output = [
    [3],
]  # Shape is 1x1
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – An NHWC input tensor to perform the convolution upon.
* filter (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The convolution filter in RSCF layout: (height, width, out\_channels, in\_channels).
* stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension.
* dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel points.
* padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The amount of padding applied to the input.
* output\_paddings ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – Resolves the ambiguity between multiple potential output shapes when any stride is greater than 1, by appending output\_paddings\[i] zeros at the end of the output's ith spatial axis. Only output\_paddings = 0 is currently supported.
* bias (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None) – tensor of shape (out\_channels,)
* input\_layout ([ConvInputLayout](type.md#max.graph.type.ConvInputLayout))
* filter\_layout ([FilterLayout](type.md#max.graph.type.FilterLayout))
**Returns:**
A symbolic tensor value with the convolution applied.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
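The transposed convolution inverts the usual convolution size arithmetic. The helper below sketches the conventional formula, including the output\_paddings term; it is an assumption consistent with the parameters above, not taken from the MAX source:

```python
def conv2d_transpose_out_dim(size, kernel, stride=1, pad_before=0, pad_after=0,
                             dilation=1, output_padding=0):
    """Standard transposed-convolution output-size arithmetic for one dimension."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size - 1) * stride - pad_before - pad_after + effective_kernel + output_padding

# Inverse of a stride-2, 3x3, pad-1 convolution that mapped a 5-wide input to 3:
print(conv2d_transpose_out_dim(3, 3, stride=2, pad_before=1, pad_after=1))  # 5
```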
### `conv3d()` {#max.graph.ops.conv3d}

> max.graph.ops.conv3d(x, filter, stride=(1, 1, 1), dilation=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.QRSCF)

Computes the 3-D convolution product of the input with the given filter, strides, dilations, paddings, and groups. The op supports 3-D convolution, with the following layout assumptions:

* input has NDHWC layout, i.e., (batch\_size, depth, height, width, in\_channels)
* filter has layout QRSCF, i.e., (depth, height, width, in\_channels / num\_groups, out\_channels)

The padding values are expected to take the form (pad\_dim1\_before, pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after, …) and represent padding 0's before and after the indicated spatial dimensions in input. In 3-D convolution, dim1 here represents D, dim2 represents H, and dim3 represents W. In Python-like syntax, padding a 2x3 spatial input with \[0, 1, 2, 1] would yield:

```python
input = [
    [1, 2, 3],
    [4, 5, 6]
]  # Shape is 2x3

padded_input = [
    [0, 0, 1, 2, 3, 0],
    [0, 0, 4, 5, 6, 0],
    [0, 0, 0, 0, 0, 0]
]  # Shape is 3x6
```

This op currently only supports strides and padding on the input.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – An NDHWC input tensor to perform the convolution upon.
* filter (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The convolution filter in QRSCF layout: (depth, height, width, in\_channels / num\_groups, out\_channels).
* stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the convolution operation.
* dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel points. * padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The amount of padding applied to the input. * groups ([int](https://docs.python.org/3/library/functions.html#int)) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups. * bias (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None) * input\_layout ([ConvInputLayout](type.md#max.graph.type.ConvInputLayout)) * filter\_layout ([FilterLayout](type.md#max.graph.type.FilterLayout))
**Returns:**
A symbolic tensor value with the convolution applied. Output shape = (batch\_size, depth, height, width, out\_channels).
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
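The output spatial dimensions follow the standard convolution output-size arithmetic. A minimal sketch in plain Python (illustrative only, independent of the MAX API):

```python
def conv_out_dim(in_dim, kernel, stride=1, dilation=1, pad_before=0, pad_after=0):
    """Standard convolution output-size arithmetic for one spatial dimension."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_dim + pad_before + pad_after - effective_kernel) // stride + 1

# Example: an 8x16x16 (D, H, W) input with a 3x3x3 filter and stride 2, no padding.
shape = [conv_out_dim(d, 3, stride=2) for d in (8, 16, 16)]
print(shape)  # [3, 7, 7]
```

The same arithmetic applies per spatial dimension (D, H, W), with the batch and channel dimensions passed through unchanged.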
### `cos()` {#max.graph.ops.cos} > max.graph.ops.cos(x) Computes the elementwise cosine of a symbolic tensor.
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `cumsum()` {#max.graph.ops.cumsum} > max.graph.ops.cumsum(x, axis=-1, exclusive=False, reverse=False) Computes the cumulative sum of the input tensor along the given axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to sum over. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the sum. If negative, indexes from the last dimension. For example, a value of -1 will compute the sum along the last dimension. * exclusive ([bool](https://docs.python.org/3/library/functions.html#bool)) – If set, start at 0 and exclude the final element. Otherwise, start with the first element. Said another way, cumsum computes \[sum(x\[…, :i, …]) for i in range(1, x.shape\[axis] + 1)]. If exclusive is set, the bounds are instead range(x.shape\[axis]). * reverse ([bool](https://docs.python.org/3/library/functions.html#bool)) – If set, start from the end. In other words, the first element will be the total sum, with each element following counting downwards; or \[sum(x\[…, i:, …]) for i in range(x.shape\[axis])].
**Returns:**
A symbolic tensor representing the result of the cumsum operation. The tensor will have the same type as the input tensor. The computed values will be the cumulative sum of the values along the given axis, according to the specified parameters: * if exclusive is set, the first value will be 0, and the last value will be excluded from the sum * if reverse is set, the sum will be computed from the back of the axis to the front, rather than front-to-back
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
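The exclusive and reverse variants can be illustrated with NumPy (a sketch of the semantics only, not the MAX API):

```python
import numpy as np

x = np.array([1, 2, 3, 4])

# Default: inclusive, front-to-back cumulative sum.
inclusive = np.cumsum(x)                              # [1, 3, 6, 10]

# exclusive=True: start at 0, exclude the final element.
exclusive = np.concatenate(([0], np.cumsum(x)[:-1]))  # [0, 1, 3, 6]

# reverse=True: first element is the total sum, counting down toward the end.
reverse = np.cumsum(x[::-1])[::-1]                    # [10, 9, 7, 4]
```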
### `custom()` {#max.graph.ops.custom} > max.graph.ops.custom(name, device, values, out\_types, parameters=None) Creates a node to execute a custom graph operation in the graph. The custom op should be registered by annotating a function with the [@compiler.register](/mojo/manual/decorators/compiler-register/) decorator.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`. * values ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The op function’s arguments. * out\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The list of the op function’s return types. * parameters ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](../dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel. * device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – Device that the op is assigned to. This becomes a target parameter to the kernel.
**Returns:**
Symbolic values representing the outputs of the op in the graph. These correspond 1:1 with the types passed as `out_types`.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
### `dequantize()` {#max.graph.ops.dequantize} > max.graph.ops.dequantize(encoding, quantized) Dequantizes a quantized tensor to floating point. NOTE: Currently this supports Q4\_0, Q4\_K, and Q6\_K encodings only.
**Parameters:**
* encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding)) – The quantization encoding to use. * quantized ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The quantized tensor to dequantize.
**Returns:**
The dequantized result (a floating point tensor).
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `distributed_broadcast()` {#max.graph.ops.distributed_broadcast} > max.graph.ops.distributed\_broadcast(input, signal\_buffers) Broadcast tensor from source GPU to all GPUs. This op is a collective operation which broadcasts a tensor from the source GPU (where the input tensor resides) to all participating GPUs. Each GPU receives a copy of the input tensor.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – Input tensor to broadcast. The device where this tensor resides becomes the root/source of the broadcast. * signal\_buffers ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue]) – Device buffer values used for synchronization. The number of signal buffers determines the number of participating GPUs.
**Returns:**
List of output tensors, one per device. Each output tensor has the same shape and dtype as the input tensor.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input tensor device is not found in signal buffer devices, if devices are not unique, or if there are fewer than 2 signal buffers.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
### `div()` {#max.graph.ops.div} > max.graph.ops.div(lhs, rhs) Divides two symbolic tensors using true division (Python operator /). For integer operands, this performs true division by promoting to float, matching Python’s / operator behavior. For floating-point operands, this performs standard floating-point division. Creates a new op node to compute the division of two symbolic tensor values and adds it to the graph, returning the symbolic result.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbol to use as left side of the division. * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbol to use as right side of the division.
**Returns:**
A symbolic tensor value representing the output of the division. The result will have: * a floating-point dtype for integer operands, or the promoted dtype for mixed-type operands * the same shape as the broadcast of the two input shapes.
**Raises:**
* Error – If the input values’ shapes are not compatible for broadcasting. * Error – If one of the input values has an unsupported dtype. * Error – If the two symbols are parts of different graphs.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
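NumPy's `true_divide` exhibits the same integer-to-float promotion the docstring describes (illustrative only, not the MAX API):

```python
import numpy as np

a = np.array([1, 2, 3], dtype=np.int32)
b = np.array([2, 2, 2], dtype=np.int32)

# True division promotes integer operands to floating point, like Python's `/`.
q = np.true_divide(a, b)
print(q.dtype)  # float64
print(q)        # [0.5 1.  1.5]
```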
### `equal()` {#max.graph.ops.equal} > max.graph.ops.equal(lhs, rhs) Computes the elementwise equality comparison of two symbolic tensors.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `erf()` {#max.graph.ops.erf} > max.graph.ops.erf(x) Computes the elementwise error function of a symbolic tensor.
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `exp()` {#max.graph.ops.exp} > max.graph.ops.exp(x) Computes the elementwise exp (exponential) function of a symbolic tensor. Creates a new op node to compute the elementwise exponential function of a symbolic tensor and adds it to the graph, returning the symbolic result. The exp function is fundamental in neural networks, used in attention mechanisms, activation functions, and probability distributions.

```python
import max.functional as F
from max.tensor import Tensor

# Create input tensor
x = Tensor.constant([0.0, 1.0, 2.0])

# Compute exponential
result = F.exp(x)
print(result)
# Output: [1.0, 2.718..., 7.389...]
# (e^0 = 1, e^1 ≈ 2.718, e^2 ≈ 7.389)
```

`exp` is defined as `exp(x) = e^x`, where `e` is Euler’s number.
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the exp function computation.
**Returns:**
A new symbolic tensor value representing the output of the exp value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `flatten()` {#max.graph.ops.flatten} > max.graph.ops.flatten(x, start\_dim=0, end\_dim=-1) Flattens the specified dims of a symbolic tensor. The number and order of the elements in the tensor is unchanged. All dimensions from start\_dim to end\_dim (inclusive) are merged into a single output dim.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * start\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * end\_dim ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
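Merging dimensions start\_dim through end\_dim (inclusive) can be sketched with a NumPy reshape (illustrative only, not the MAX API):

```python
import numpy as np

def flatten_dims(a, start_dim=0, end_dim=-1):
    """Merge dims [start_dim, end_dim] (inclusive) into one, mirroring flatten()."""
    end = end_dim % a.ndim
    merged = int(np.prod(a.shape[start_dim:end + 1]))
    new_shape = a.shape[:start_dim] + (merged,) + a.shape[end + 1:]
    return a.reshape(new_shape)

x = np.zeros((2, 3, 4, 5))
print(flatten_dims(x, 1, 2).shape)  # (2, 12, 5): dims 1 and 2 merged into 3 * 4
print(flatten_dims(x).shape)        # (120,): all dims merged by default
```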
### `floor()` {#max.graph.ops.floor} > max.graph.ops.floor(x) Computes the elementwise floor of a symbolic tensor.
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `fold()` {#max.graph.ops.fold} > max.graph.ops.fold(input, output\_size, kernel\_size, stride=1, dilation=1, padding=0) Combines an array of sliding blocks into a larger containing tensor. The input tensor must have shape `(N, C * kernel_sizes, L)` where `N` is the batch dimension, `C` is the number of channels, `kernel_sizes` is the product of the kernel sizes, and `L` is the number of local blocks. The resulting output tensor will have shape `(N, C, output_shape[0], output_shape[1])`. The number and order of the elements in the tensor are unchanged. All dimensions from start\_dim to end\_dim (inclusive) are merged into a single output dim. `L`, the number of blocks, must be equivalent to: `prod((output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)` where `d` is over all spatial dimensions.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The 3D tensor to fold with shape `(N, C * kernel sizes, L)`. * output\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Spatial dimensions of the output tensor. Must be a tuple of two ints. 
* kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The size of the sliding blocks. Must be a tuple of two ints. * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding blocks in the input dimension (can be an int or a tuple of two ints). * dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel elements. (can be an int or a tuple of two ints). * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – 0-paddings to be added on both sides of the inputs. (can be an int or a tuple of two ints).
**Returns:**
The folded 4D tensor with shape `(N, C, output_shape[0], output_shape[1])`.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
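The constraint on `L` can be checked with a small helper that evaluates the formula above (plain Python, illustrative only):

```python
def num_blocks(output_size, kernel_size, stride=(1, 1), dilation=(1, 1), padding=(0, 0)):
    """Number of sliding blocks L implied by the fold/unfold arithmetic."""
    L = 1
    for d in range(2):
        L *= (output_size[d] + 2 * padding[d]
              - dilation[d] * (kernel_size[d] - 1) - 1) // stride[d] + 1
    return L

# A 4x5 output covered by 2x2 blocks at stride 1: 3 positions x 4 positions = 12 blocks,
# so the input to fold() must have L == 12.
print(num_blocks((4, 5), (2, 2)))  # 12
```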
### `gather()` {#max.graph.ops.gather} > max.graph.ops.gather(input, indices, axis) Selects elements out of an input tensor by index.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to select elements from. * indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of index values to use for selection. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension which `indices` indexes from `input`. If negative, indexes relative to the end of the input tensor. For instance, `gather(input, indices, axis=-1)` will index against the last dimension of `input`.
**Returns:**
A new symbolic tensor representing the result of the gather operation.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
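The single-axis selection semantics match NumPy's `take` (illustrative only, not the MAX API):

```python
import numpy as np

data = np.array([[10, 20, 30],
                 [40, 50, 60]])
indices = np.array([2, 0])

# Select index 2 then index 0 along the last axis of every row.
result = np.take(data, indices, axis=-1)
print(result)  # rows become [30, 10] and [60, 40]
```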
### `gather_nd()` {#max.graph.ops.gather_nd} > max.graph.ops.gather\_nd(input, indices, batch\_dims=0) Selects elements out of an input tensor by N-dimensional index. This operation performs N-dimensional indexing into `input` using `indices`. Unlike [`gather()`](#max.graph.ops.gather), which indexes along a single axis, `gather_nd()` allows indexing along multiple dimensions simultaneously.

```python
input_shape = ["a", "b", "c", "d", "e"]
indices_shape = ["a", "f", 3]
input_type = TensorType(DType.bfloat16, input_shape)
indices_type = TensorType(DType.int32, indices_shape)

with Graph("gather_nd", input_types=[input_type, indices_type]) as graph:
    input, indices = graph.inputs
    gathered = ops.gather_nd(input, indices, batch_dims=1)
    print(gathered.type)
    # Output: TensorType(dtype=DType.bfloat16, shape=["a", "f", "e"])
```

In this example: * `batch_dims` is 1, so there’s 1 shared dimension at the beginning. * `indices` has an additional dimension “f” which becomes part of the output. * The last dimension of `indices` is the index vector; values in this vector are interpreted as indices into “b”, “c”, and “d”. * Since `batch_dims (1) + index size (3) < input.rank (5)`, the remaining dimensions (in this case “e”) are sliced into the output as features.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to select elements from. * indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of index values to use for selection. The last dimension of this tensor must be static. This dimension will be used to index or slice into `input` immediately following `batch_dims` initial dimensions. The size of this index dimension is the number of dimensions it specifies. * batch\_dims ([int](https://docs.python.org/3/library/functions.html#int)) – The number of leading batch dimensions shared by `input` and `indices`; 0 by default. `input` and `indices` must exactly match up to their first `batch_dims` dimensions. This function does not broadcast.
**Returns:**
A new symbolic tensor representing the result of the gather operation. The output will have the same dtype as `input`, and will have shape depending on the inputs, in this order: * `input.shape[:batch_dims]` – The “broadcast” dimensions (though note that this function does not broadcast). These dimensions must be identical between `input` and `indices`. * `indices.shape[batch_dims:-1]` – The “gather” dimensions; this allows multi-dimensional tensors of indices. The last dimension is the index vector. * `input.shape[batch_dims + indices.shape[-1]:]` – The “slice” dimensions. If `batch_dims` < `input.rank - indices.shape[-1]` (again, this last is the index vector), then any following dimensions of the inputs are taken entirely as though slicing.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
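The three-part output shape rule described above can be computed directly (plain Python, illustrative only):

```python
def gather_nd_shape(input_shape, indices_shape, batch_dims=0):
    """Output shape rule for gather_nd: batch dims + gather dims + slice dims."""
    index_size = indices_shape[-1]  # size of the index vector (must be static)
    return (tuple(input_shape[:batch_dims])            # shared batch dims
            + tuple(indices_shape[batch_dims:-1])      # gather dims
            + tuple(input_shape[batch_dims + index_size:]))  # sliced feature dims

# The docstring's example: 5-D input, batch_dims=1, index vector of size 3.
print(gather_nd_shape(("a", "b", "c", "d", "e"), ("a", "f", 3), batch_dims=1))
# ('a', 'f', 'e')
```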
### `gelu()` {#max.graph.ops.gelu} > max.graph.ops.gelu(x, approximate='none') Computes the elementwise gelu of a symbolic tensor. Creates a new op node to compute the elementwise gelu of a symbolic tensor and adds it to the graph, returning the symbolic result. For `approximate == "none"`, the exact gelu function is computed. For `approximate == "tanh"`, the approximation: $$ \text{gelu}(x) = 0.5 \cdot x \cdot \left(1 + \tanh\left(0.7978845608028654 \cdot (x + 0.044715 \cdot x^3)\right)\right) $$ is used. For `approximate == "quick"`, the approximation: $$ \text{gelu}(x) = \text{sigmoid}(1.702 \cdot x) \cdot x $$ is used.
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The symbolic tensor to use as the input to the gelu computation. * approximate ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The approximation method: `"none"` (exact), `"tanh"`, or `"quick"`.
**Returns:**
A new symbolic tensor value representing the output of the gelu value computation.
**Raises:**
* Error – If the symbol doesn’t represent a tensor value. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the approximation method is invalid.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
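The exact and tanh-approximate forms can be compared numerically with NumPy (a sketch of the math only, not the MAX kernel):

```python
import math
import numpy as np

def gelu_exact(xs):
    # gelu(x) = x * Phi(x), where Phi is the standard normal CDF.
    return np.array([v * 0.5 * (1.0 + math.erf(v / math.sqrt(2.0))) for v in xs])

def gelu_tanh(xs):
    # 0.7978845608028654 is sqrt(2 / pi).
    return 0.5 * xs * (1.0 + np.tanh(0.7978845608028654 * (xs + 0.044715 * xs**3)))

x = np.linspace(-3.0, 3.0, 7)
err = np.max(np.abs(gelu_exact(x) - gelu_tanh(x)))
print(err)  # the tanh approximation stays small (well under 1e-2) on this range
```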
### `greater()` {#max.graph.ops.greater} > max.graph.ops.greater(lhs, rhs) Computes the elementwise greater-than comparison of two symbolic tensors.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `greater_equal()` {#max.graph.ops.greater_equal} > max.graph.ops.greater\_equal(lhs, rhs) Computes the elementwise greater-than-or-equal comparison of two symbolic tensors.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `hann_window()` {#max.graph.ops.hann_window} > max.graph.ops.hann\_window(window\_length, device, periodic=True, dtype=float32) Calculate a Hann window for a given length. Hann window function: $$ H[n] = \frac{1}{2}\left[1 - \cos\left(\frac{2\pi n}{N - 1}\right)\right] $$ where N is window\_length.
**Parameters:**
* window\_length ([int](https://docs.python.org/3/library/functions.html#int)) – The length of the window. * device ([DeviceRef](type.md#max.graph.type.DeviceRef)) – The device to run the operation on. * periodic ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, trims the last (duplicate) value from the symmetric window so the result can be used as a periodic window with functions like stft(). Equivalently: hann\_window(L, periodic=True) == hann\_window(L + 1, periodic=False)\[:-1]. * dtype ([DType](../dtype.md#max.dtype.DType)) – The desired data type of the output tensor.
**Returns:**
A 1-D tensor of size (window\_length,) containing the window.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If window\_length is negative. * [TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If window\_length is not an integer.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
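The formula and the periodic/symmetric relationship above can be sketched in plain NumPy. This is a reference implementation of the semantics only, not the MAX kernel, and `hann` is a hypothetical helper name:

```python
import numpy as np

def hann(window_length: int, periodic: bool = True) -> np.ndarray:
    """Reference Hann window: H[n] = 0.5 * (1 - cos(2*pi*n / (N - 1)))."""
    if window_length < 0:
        raise ValueError("window_length must be non-negative")
    # A periodic window of length L is the symmetric window of length
    # L + 1 with the final (duplicate) sample trimmed off.
    n_points = window_length + 1 if periodic else window_length
    if n_points <= 1:
        return np.ones(window_length)
    n = np.arange(n_points)
    window = 0.5 * (1.0 - np.cos(2.0 * np.pi * n / (n_points - 1)))
    return window[:-1] if periodic else window
```

This sketch satisfies the identity from the parameter docs: `hann(L, periodic=True)` equals `hann(L + 1, periodic=False)[:-1]`.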
### `inplace_custom()` {#max.graph.ops.inplace_custom} > max.graph.ops.inplace\_custom(name, device, values, out\_types=None, parameters=None) Creates a node to execute an in-place custom graph operation in the graph. The custom op should be registered by annotating a function with the [@compiler.register](/mojo/manual/decorators/compiler-register/) decorator.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`. * device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – Device that the op is assigned to. This becomes a target parameter to the kernel. * values ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The op function’s arguments. * parameters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](../dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel. * out\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None)
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
### `irfft()` {#max.graph.ops.irfft} > max.graph.ops.irfft(input\_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input\_is\_complex=False, buffer\_size\_mb=512) Compute the inverse real FFT of the input tensor.
**Parameters:**
* input\_tensor (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input tensor to compute the inverse real FFT of. * n ([int](https://docs.python.org/3/library/functions.html#int) | None) – The size of the output tensor. Must be an int, and cannot be a symbolic Buffer. The input tensor will be padded or truncated to n // 2 + 1 along the specified axis. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to compute the inverse real FFT of. * normalization (Normalization | [str](https://docs.python.org/3/library/stdtypes.html#str)) – The normalization to apply to the output tensor. Can be “backward”, “ortho”, or “forward”. When “backward”, the output is divided by n. When “ortho”, the output is divided by sqrt(n). When “forward”, no normalization is applied. * input\_is\_complex ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether the input tensor is already interleaved complex. The last dimension of the input tensor must be 2, and is excluded from the dimension referred to by axis. * buffer\_size\_mb ([int](https://docs.python.org/3/library/functions.html#int)) – The estimated size of a persistent buffer to use for storage of intermediate results. Needs to be the same across multiple calls to irfft within the same graph. Otherwise, multiple buffers will be allocated.
**Returns:**
The inverse real FFT of the input tensor. The shape of the output tensor is the same as the shape of the input tensor, except for the axis that the inverse real FFT is computed over, which is replaced by n.
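NumPy's `np.fft.irfft` uses the same "backward"/"ortho"/"forward" normalization naming convention, so it can illustrate how the three modes scale the output (this demonstrates the conventions only, not the MAX kernel):

```python
import numpy as np

# A real signal and its real FFT (n // 2 + 1 complex coefficients).
signal = np.array([1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0])
spectrum = np.fft.rfft(signal)

# "backward": the inverse is divided by n (round-trips the default forward FFT).
back = np.fft.irfft(spectrum, n=8, norm="backward")

# "ortho": the inverse is divided by sqrt(n).
ortho = np.fft.irfft(spectrum, n=8, norm="ortho")

# "forward": no normalization is applied on the inverse.
fwd = np.fft.irfft(spectrum, n=8, norm="forward")
```

Relative to "backward", the "ortho" result is larger by a factor of `sqrt(n)` and the "forward" result by a factor of `n`.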
### `is_inf()` {#max.graph.ops.is_inf} > max.graph.ops.is\_inf(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `is_nan()` {#max.graph.ops.is_nan} > max.graph.ops.is\_nan(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `layer_norm()` {#max.graph.ops.layer_norm} > max.graph.ops.layer\_norm(input, gamma, beta, epsilon) Performs layer normalization.
**Parameters:**
* input ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The input tensor to normalize. * gamma (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The gamma parameter of the normalization. * beta (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The beta parameter of the normalization. * epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – The epsilon parameter of the normalization.
**Returns:**
A graph tensor value with the normalization applied.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If gamma size doesn’t match the last dimension of input. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If beta size doesn’t match the last dimension of input. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If epsilon is not positive.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
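A minimal NumPy reference of the computation (normalize over the last dimension, then scale by gamma and shift by beta); `layer_norm_ref` is an illustrative helper, not the MAX kernel:

```python
import numpy as np

def layer_norm_ref(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray,
                   epsilon: float = 1e-5) -> np.ndarray:
    """Normalize over the last dimension, then apply scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * gamma + beta

x = np.array([[1.0, 2.0, 3.0, 4.0]])
out = layer_norm_ref(x, gamma=np.ones(4), beta=np.zeros(4))
# With gamma=1 and beta=0, each row has ~zero mean and ~unit variance.
```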
### `log()` {#max.graph.ops.log} > max.graph.ops.log(x) Computes the elementwise natural logarithm of a symbolic tensor. Creates a new op node to compute the elementwise natural logarithm of a symbolic tensor and adds it to the graph, returning the symbolic result. The natural logarithm is used in loss functions, normalization, and probability calculations in machine learning.

```python
import max.functional as F
from max.tensor import Tensor

# Create input tensor (positive values only)
x = Tensor.constant([1.0, 2.718, 7.389, 20.0])

# Compute natural logarithm
result = F.log(x)
print(result)
# Output: [0.0, 1.0, 2.0, 2.996...]
# (log(1) = 0, log(e) = 1, log(e^2) = 2)
```

The natural logarithm function `log` is defined as the inverse of the exponential function `exp()`. In other words, it computes the value `y` in the equation `x = e^y`, where `e` is Euler's number. `log(x)` is undefined for `x <= 0` over the real numbers. Complex numbers are currently unsupported.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the natural logarithm computation.
**Returns:**
A new symbolic tensor value representing the output of the natural logarithm value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `log1p()` {#max.graph.ops.log1p} > max.graph.ops.log1p(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `logical_and()` {#max.graph.ops.logical_and} > max.graph.ops.logical\_and(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `logical_not()` {#max.graph.ops.logical_not} > max.graph.ops.logical\_not(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `logical_or()` {#max.graph.ops.logical_or} > max.graph.ops.logical\_or(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `logical_xor()` {#max.graph.ops.logical_xor} > max.graph.ops.logical\_xor(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `logsoftmax()` {#max.graph.ops.logsoftmax} > max.graph.ops.logsoftmax(value, axis=-1)
**Parameters:**
* value (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
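The standard numerically stable formulation of log-softmax subtracts the maximum before exponentiating. A hedged NumPy sketch of that computation (`logsoftmax_ref` is an illustrative helper, not the MAX kernel):

```python
import numpy as np

def logsoftmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """log(softmax(x)), computed stably by shifting by the max first."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

logits = np.array([[1.0, 2.0, 3.0]])
out = logsoftmax_ref(logits)
# exp(out) sums to 1 along the reduced axis.
```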
### `masked_scatter()` {#max.graph.ops.masked_scatter} > max.graph.ops.masked\_scatter(input, mask, updates, out\_dim) Creates a new symbolic tensor in which the values of updates are written to input at the positions where mask is true.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to. * mask (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A boolean symbolic tensor indicating which positions of input to update.
* updates (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input. * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The new data-dependent dimension.
**Returns:**
A new symbolic tensor representing the result of the masked\_scatter operation.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
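The core write semantics can be sketched with NumPy boolean indexing, which visits True positions in row-major order and consumes updates in that same order. This sketch ignores the data-dependent `out_dim` bookkeeping, and `masked_scatter_ref` is an illustrative helper, not the MAX kernel:

```python
import numpy as np

def masked_scatter_ref(input_: np.ndarray, mask: np.ndarray,
                       updates: np.ndarray) -> np.ndarray:
    """Write updates into a copy of input where mask is True (row-major order)."""
    out = input_.copy()
    out[mask] = updates[:mask.sum()]
    return out

x = np.zeros((2, 3))
mask = np.array([[True, False, True], [False, True, False]])
updates = np.array([1.0, 2.0, 3.0])
out = masked_scatter_ref(x, mask, updates)
# out == [[1., 0., 2.], [0., 3., 0.]]
```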
### `matmul()` {#max.graph.ops.matmul} > max.graph.ops.matmul(lhs, rhs) Computes the matrix multiplication of two tensor graph values. Performs general matrix multiplication with broadcasting. Matrix multiplication is fundamental to neural networks, used for linear transformations, attention mechanisms, and fully connected layers.

```python
from max.tensor import Tensor
import max.functional as F

# Create two 2x2 matrices
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]])  # Shape: (2, 2)
w = Tensor.constant([[5.0, 6.0], [7.0, 8.0]])  # Shape: (2, 2)

# Matrix multiply using the @ operator (uses matmul internally)
result = x @ w
print("Matrix multiplication result:")
print(result)
# Output: [[19.0, 22.0],
#          [43.0, 50.0]]
# Computed as: result[i, j] = sum(x[i, k] * w[k, j])

# Can also call directly via the functional API
result2 = F.matmul(x, w)  # Same result as x @ w
```

If `lhs` is 1D, it will be reshaped to `1xD`; if `rhs` is 1D, it will be reshaped to `Dx1`. In both cases, the additional 1 dimensions are removed from the output shape. For the multiplication, the innermost (rightmost) 2 dimensions are treated as a matrix: the `lhs` matrix has shape `MxK`, the `rhs` matrix has shape `KxN`, and the output has shape `MxN`. The `K` dimensions must be equivalent in both matrices. The remaining outer dimensions are broadcast.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The left-hand side input tensor. * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The right-hand side input tensor.
**Returns:**
A tensor graph value representing the matrix product of `lhs` and `rhs`. For 2D inputs, the output shape is `(M, N)` where `lhs` is `(M, K)` and `rhs` is `(K, N)`. For higher-dimensional inputs, batch dimensions are preserved and the operation is applied to the last two dimensions of each input.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `max()` {#max.graph.ops.max} > max.graph.ops.max(x, y=None, /, axis=None) Overload for ops.elementwise.max and ops.reduction.max. * If two tensors are provided, axis is ignored and returns an elementwise maximum. * If one tensor is provided, compute ops.reduction.max on the tensor and axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * y (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None) * axis ([int](https://docs.python.org/3/library/functions.html#int) | None)
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
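The two forms of the overload correspond to NumPy's elementwise and reduction maximums, shown here as a hedged analogue of the semantics:

```python
import numpy as np

x = np.array([[1.0, 5.0], [3.0, 2.0]])
y = np.array([[4.0, 0.0], [1.0, 6.0]])

# Two tensors: elementwise maximum (any axis argument would be ignored).
elementwise = np.maximum(x, y)   # [[4., 5.], [3., 6.]]

# One tensor: reduction maximum along the given axis.
reduced = x.max(axis=1)          # [5., 3.]
```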
### `max_pool2d()` {#max.graph.ops.max_pool2d} > max.graph.ops.max\_pool2d(input, kernel\_size, stride=1, dilation=1, padding=0, ceil\_mode=False) Performs a 2D max pooling operation on an input tensor of shape \[N, H, W, C]. The pooling operation slides a window of size kernel\_size over the input tensor and selects the maximum value within each window.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to perform the pooling operation on. * kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The size of the sliding blocks. * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding blocks in the input dimension. 
* dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel elements. * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – Zero-padding added to both sides of the input's spatial dimensions. * ceil\_mode ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, use ceil instead of floor to compute the output shape.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
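The sliding-window behavior can be sketched as a naive NHWC reference in NumPy. This sketch assumes no padding, no dilation, and floor mode, and `max_pool2d_ref` is an illustrative helper, not the MAX kernel:

```python
import numpy as np

def max_pool2d_ref(x: np.ndarray, kernel_size: tuple[int, int],
                   stride: tuple[int, int]) -> np.ndarray:
    """Naive NHWC 2D max pooling (no padding, no dilation, floor mode)."""
    n, h, w, c = x.shape
    kh, kw = kernel_size
    sh, sw = stride
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    out = np.empty((n, out_h, out_w, c), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Take the maximum over each spatial window, per batch and channel.
            window = x[:, i * sh:i * sh + kh, j * sw:j * sw + kw, :]
            out[:, i, j, :] = window.max(axis=(1, 2))
    return out

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
out = max_pool2d_ref(x, kernel_size=(2, 2), stride=(2, 2))
# out[0, :, :, 0] == [[5., 7.], [13., 15.]]
```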
### `mean()` {#max.graph.ops.mean} > max.graph.ops.mean(x, axis=-1) Reduces a symbolic tensor using a mean operation.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the mean operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
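The rank-preserving behavior described above (the reduced axis is kept with size 1) corresponds to NumPy's `keepdims` flag, shown here as a hedged analogue:

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (2, 3)

# Reducing along the last axis keeps the rank: the reduced
# dimension is retained with size 1.
m = x.mean(axis=-1, keepdims=True)                 # shape (2, 1)
# m == [[2.], [5.]]
```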
### `min()` {#max.graph.ops.min} > max.graph.ops.min(x, y=None, /, axis=None) Overload for ops.elementwise.min and ops.reduction.min. * If two tensors are provided, axis is ignored and returns an elementwise minimum. * If one tensor is provided, compute ops.reduction.min on the tensor and axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * y (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None) * axis ([int](https://docs.python.org/3/library/functions.html#int) | None)
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `mod()` {#max.graph.ops.mod} > max.graph.ops.mod(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `mul()` {#max.graph.ops.mul} > max.graph.ops.mul(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `negate()` {#max.graph.ops.negate} > max.graph.ops.negate(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `nonzero()` {#max.graph.ops.nonzero} > max.graph.ops.nonzero(x, out\_dim) Returns the indices of all nonzero elements in a tensor. The return value is a 2D tensor of shape `[out_dim x rank_in]`, where out\_dim is the number of nonzero elements in the input tensor, and rank\_in is the rank of the input tensor. Indices are generated in row-major order.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor. * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The newly generated dimension that is sized for the number of nonzero elements.
**Returns:**
A symbolic tensor of indices
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
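The documented output layout matches NumPy's `np.argwhere`: one row of indices per nonzero element, in row-major order. A sketch using NumPy as a stand-in for the MAX op:

```python
import numpy as np

x = np.array([[0, 1],
              [2, 0],
              [3, 4]])

# One row of indices per nonzero element, shape [out_dim x rank_in],
# generated in row-major order — the same layout nonzero() documents.
idx = np.argwhere(x)
print(idx.shape)  # (4, 2): four nonzero elements in a rank-2 tensor
```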
### `not_equal()` {#max.graph.ops.not_equal} > max.graph.ops.not\_equal(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `outer()` {#max.graph.ops.outer} > max.graph.ops.outer(lhs, rhs) Computes the outer product of two symbolic vectors.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The left side of the product. Whatever its shape, it will be flattened to a rank-1 vector. * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The right side of the product. Whatever its shape, it will be flattened to a rank-1 vector. Must have the same number of elements as lhs.
**Returns:**
A symbolic tensor representing the [outer product](https://en.wikipedia.org/wiki/Outer_product) of the two input vectors. It will have rank 2, with the dimension sizes being the number of elements of lhs and rhs respectively.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
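The flattening behavior matches NumPy's `np.outer`, which also flattens its inputs to rank-1 vectors. A sketch with NumPy standing in for the MAX op (respecting the documented same-element-count constraint):

```python
import numpy as np

lhs = np.array([[1.0, 2.0]])   # shape (1, 2) — flattened to a rank-1 vector of 2 elements
rhs = np.array([3.0, 4.0])     # shape (2,) — same number of elements as lhs

# np.outer flattens both operands, matching the documented semantics.
result = np.outer(lhs, rhs)
print(result.shape)  # (2, 2): rank 2, sized by the element counts of lhs and rhs
```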
### `pad()` {#max.graph.ops.pad} > max.graph.ops.pad(input, paddings, mode='constant', value=0) Pads a tensor with constant values. Adds padding to the input tensor using the specified padding values. Currently only constant padding mode is supported.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to pad. * paddings ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Sequence of padding values, applied before and after each dimension. For a tensor with rank N, paddings should contain 2\*N values: [pad\_before\_dim0, pad\_after\_dim0, pad\_before\_dim1, pad\_after\_dim1, …]. * mode ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['constant']) – The padding mode. Currently only “constant” is supported. * value (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The constant value to use for padding.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
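The flat `[pad_before_dim0, pad_after_dim0, …]` layout corresponds to NumPy's per-dimension `(before, after)` pairs. A sketch of the same padding expressed with `np.pad` (a NumPy analogue, not the MAX op):

```python
import numpy as np

x = np.ones((2, 3), dtype=np.float32)

# Flat paddings list in the documented order:
# [pad_before_dim0, pad_after_dim0, pad_before_dim1, pad_after_dim1]
paddings = [1, 0, 0, 2]

# NumPy expects the same values grouped into (before, after) pairs per dimension.
pad_width = list(zip(paddings[::2], paddings[1::2]))  # [(1, 0), (0, 2)]
out = np.pad(x, pad_width, mode="constant", constant_values=0)
print(out.shape)  # (3, 5): one row added before, two columns added after
```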
### `permute()` {#max.graph.ops.permute} > max.graph.ops.permute(x, dims) Permutes all dimensions of a symbolic tensor.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to transpose. * dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – The desired ordering of the dimensions in the output tensor.
**Returns:**
A new symbolic tensor with the dimensions permuted to match the passed in order. It has the same elements and dtype, but the order of the elements is different according to the permutation.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
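The `dims` semantics match NumPy's `np.transpose` with an explicit axes list. A sketch using NumPy in place of the MAX op:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# dims gives the desired ordering of the input dimensions in the output,
# as documented for permute(): output dim i is input dim dims[i].
dims = [2, 0, 1]
out = np.transpose(x, dims)
print(out.shape)  # (4, 2, 3)
```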
### `pow()` {#max.graph.ops.pow} > max.graph.ops.pow(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `print()` {#max.graph.ops.print} > max.graph.ops.print(value, label='debug\_tensor') Prints the value of a tensor or a string during graph execution. This function outputs the current value of a tensor and is primarily used for debugging within the MAX Engine and its graph execution framework. It is particularly useful for verifying that the intermediate results of your computations are as expected: by printing tensor values, you can see the data flowing through the graph and how each operation transforms it. The `label` argument identifies which value is being printed, which is especially helpful when a complex graph contains multiple print statements.

```python
from typing import Any

import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops


def add_tensors(a: np.ndarray, b: np.ndarray) -> dict[str, Any]:
    input_type = TensorType(dtype=DType.float32, shape=(1,), device=DeviceRef.CPU())
    with Graph(
        "simple_add_graph", input_types=(input_type, input_type)
    ) as graph:
        lhs, rhs = graph.inputs
        out = ops.add(lhs, rhs)
        ops.print(out, label="addition_output")  # Pass the output tensor here
        graph.output(out)
        print("final graph:", graph)
```
**Parameters:**
* value ([str](https://docs.python.org/3/library/stdtypes.html#str) | [TensorValue](TensorValue.md#max.graph.TensorValue)) – The value to print. Can be either a string or a TensorValue. * label ([str](https://docs.python.org/3/library/stdtypes.html#str)) – A label to identify the printed value. Defaults to `debug_tensor`.
**Return type:**
None
### `qmatmul()` {#max.graph.ops.qmatmul} > max.graph.ops.qmatmul(encoding, config, lhs, \*rhs) Performs matrix multiplication between floating point and quantized tensors. This quantizes the `lhs` floating point value to match the encoding of the `rhs` quantized value, performs matmul, and then dequantizes the result. Beware that, compared to a regular matmul op, this one expects the `rhs` value to be transposed. For example, if the `lhs` shape is `[32, 64]`, and the quantized `rhs` shape is also `[32, 64]`, then the output shape is `[32, 32]`. That is, this function returns the result from: > dequantize(quantize(lhs) @ transpose(rhs)) The last two dimensions in `lhs` are treated as matrices and multiplied by `rhs` (which must be a 2D tensor). Any remaining dimensions in `lhs` are broadcast dimensions. NOTE: Currently this supports Q4\_0, Q4\_K, and Q6\_K encodings only.
**Parameters:**
* encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding)) – The quantization encoding to use. * lhs ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The non-quantized, left-hand-side of the matmul. * \*rhs ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The transposed and quantized right-hand-side of the matmul and auxiliary tensor (if present). Must be rank 2 and in a supported [quantization encoding](/max/api/mojo/graph/quantization/). * config ([QuantizationConfig](quantization.md#max.graph.quantization.QuantizationConfig) | None)
**Returns:**
The dequantized result (a floating point tensor).
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
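The shape rule in the example above can be checked with a plain float matmul, ignoring quantization entirely. A shape-only NumPy sketch of `dequantize(quantize(lhs) @ transpose(rhs))`:

```python
import numpy as np

# Plain float tensors stand in for the quantized path; only the
# shapes are meaningful here, not the quantization itself.
lhs = np.random.rand(32, 64).astype(np.float32)
rhs = np.random.rand(32, 64).astype(np.float32)  # already transposed, as qmatmul expects

out = lhs @ rhs.T
print(out.shape)  # (32, 32), matching the documented example
```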
### `range()` {#max.graph.ops.range} > max.graph.ops.range(start, stop, step=1, out\_dim=None, \*, dtype, device) Creates a sequence of numbers. The sequence goes from start with increments of size step up to (but not including) stop. All arguments are mandatory and must have the same element type. Note the following restrictions on input values: 1. step must be non-zero 2. stop - start must be zero or have the same sign as step
**Parameters:**
* start (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The start of the range to generate. * stop (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The range will be generated up to, but not including, this value. 
* step (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The step size for the range. * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None) – The expected output dimensions returned by the range op. These will be asserted at graph execution time to be correct. * device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – Device of the result tensor. * dtype ([DType](../dtype.md#max.dtype.DType)) – Data type of the result tensor. If not specified, defaults to float32 for numeric inputs or infers from tensor inputs.
**Returns:**
A symbolic tensor value containing the defined range of values.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
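The documented restrictions mirror the semantics of NumPy's `np.arange`. A sketch using NumPy as a stand-in for the MAX op:

```python
import numpy as np

# step must be non-zero, and (stop - start) must be zero or share step's sign.
a = np.arange(0, 10, 3)   # ascending: stop is excluded
b = np.arange(5, 0, -2)   # descending: negative step matches negative (stop - start)
c = np.arange(3, 3, 1)    # stop - start == 0 yields an empty range
print(a, b, c)
```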
### `rebind()` {#max.graph.ops.rebind} > max.graph.ops.rebind(x, shape, message='', layout=None) Rebinds a symbolic tensor to a specified set of dimensions. This does not mutate the symbolic tensor passed in, but instead adds a runtime assert that the input symbolic shape is equivalent to the given `shape`. For example, if the input tensor shape has dynamic/unknown sizes, this will assert the fixed sizes that may be required for a subsequent operation.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to rebind. * shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The symbolic shape to assert for `x`, as a list of [`Dim`](/max/api/python/graph/type/Dim) values. * message ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The message printed if the rebind fails at runtime. * layout ([FilterLayout](type.md#max.graph.type.FilterLayout) | None) – A layout of the weights used by some operations like conv.
**Returns:**
A symbolic tensor with the same elements and shape as the given tensor, but with the symbolic shape asserted to be `shape`.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `relu()` {#max.graph.ops.relu} > max.graph.ops.relu(x) Computes the elementwise ReLU (Rectified Linear Unit) of a symbolic tensor. Creates a new op node to compute the elementwise ReLU of a symbolic tensor and adds it to the graph, returning the symbolic result. ReLU is defined as `relu(x) = max(0, x)`, setting all negative values to zero while leaving positive values unchanged. ReLU is one of the most common activation functions in neural networks due to its computational efficiency and effectiveness in addressing the vanishing gradient problem.

```python
import max.functional as F
from max.tensor import Tensor

# Create input with negative and positive values
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])

# Apply ReLU activation
result = F.relu(x)
print(result)
# Output: [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
# Negative values become 0, positive values unchanged
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the relu computation.
**Returns:**
A new symbolic tensor value representing the output of the relu value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `repeat_interleave()` {#max.graph.ops.repeat_interleave} > max.graph.ops.repeat\_interleave(x, repeats, axis=None, out\_dim=None) Repeats elements of a tensor along the given dimension. Modeled after `torch.repeat_interleave`. For example, given `repeats=2` and the following input:

```python
# Input tensor with shape (2, 2)
input = TensorValue(x)  # Contains [[1.0, 2.0], [3.0, 4.0]]
```

`repeat_interleave` with `axis=0`:

```python
# Output tensor with shape (4, 2)
output = repeat_interleave(input, repeats=2, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
```

`repeat_interleave` with `axis=1`:

```python
# Output tensor with shape (2, 4)
output = repeat_interleave(input, repeats=2, axis=1)
# Contains [[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]
```

`repeat_interleave` with `axis=None` (the default):

```python
# Output tensor with shape (8,)
output = repeat_interleave(input, repeats=2)  # axis = None
# Contains [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
```

`repeat_interleave` with `repeats=[2, 3]` and `axis=0`:

```python
repeat_value = TensorValue([2, 3])
# Output tensor with shape (5, 2)
output = repeat_interleave(input, repeats=repeat_value, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor. * repeats ([int](https://docs.python.org/3/library/functions.html#int) | [TensorValue](TensorValue.md#max.graph.TensorValue)) – The number of repetitions for each element. * axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The dimension along which to repeat values. If axis is not specified or None (the default), flatten the input array and repeat the flattened values. * out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None)
**Returns:**
A symbolic tensor with the elements interleaved.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `repeats` is non-positive or if `axis` is out of range.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
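These semantics match NumPy's `np.repeat`, which can serve as a quick reference for the expected outputs (a NumPy analogue, not the MAX op):

```python
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])

a = np.repeat(x, 2, axis=0)       # rows duplicated -> shape (4, 2)
b = np.repeat(x, 2, axis=1)       # columns duplicated -> shape (2, 4)
c = np.repeat(x, [2, 3], axis=0)  # per-row repeat counts -> shape (5, 2)
d = np.repeat(x, 2)               # axis=None flattens first -> shape (8,)
print(a.shape, b.shape, c.shape, d.shape)
```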
### `reshape()` {#max.graph.ops.reshape} > max.graph.ops.reshape(x, shape) Reshapes a symbolic tensor. The number and order of the elements in the tensor is unchanged. In other words, if you were to iterate over elements in the tensor by major dimension to minor dimension, the iteration order would stay the same. If a value of -1 is present in the shape, that dimension becomes an automatically calculated dimension collecting all unspecified dimensions. Its size becomes the number of elements in the original tensor divided by the product of the other dimensions in the new shape.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to reshape. * shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as a list of dimensions. A single dimension may be -1.
**Returns:**
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as `shape`.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input and target shapes describe different numbers of elements.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
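The -1 inference rule and the preserved major-to-minor iteration order can both be seen with NumPy's `reshape` (a NumPy analogue of the MAX op):

```python
import numpy as np

x = np.arange(12)

# The -1 dimension is inferred: 12 elements / (3 * 2) = 2.
out = x.reshape(3, -1, 2)
print(out.shape)  # (3, 2, 2)

# Iterating major to minor yields the original element order unchanged.
print(out.ravel().tolist() == x.tolist())  # True
```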
### `resize()` {#max.graph.ops.resize} > max.graph.ops.resize(input, shape, interpolation=InterpolationMode.BILINEAR) Resize the input tensor to the given shape. This function resizes a tensor using the specified interpolation method. The tensor is expected to have NCHW format (batch, channels, height, width).
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to resize. Must have rank 4 in NCHW format. * shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Desired output shape of length 4 corresponding to (N, C, H, W). * interpolation ([InterpolationMode](#max.graph.ops.InterpolationMode)) – Desired interpolation enum defined by InterpolationMode. Default is InterpolationMode.BILINEAR. Currently only BICUBIC is supported.
**Returns:**
A resized tensor with the shape specified by the shape argument.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input doesn’t have rank 4, shape has wrong number of elements, or unsupported interpolation mode is specified. * [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If single integer size or non-BICUBIC interpolation mode is specified.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `round()` {#max.graph.ops.round} > max.graph.ops.round(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `rsqrt()` {#max.graph.ops.rsqrt} > max.graph.ops.rsqrt(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `scatter()` {#max.graph.ops.scatter} > max.graph.ops.scatter(input, updates, indices, axis=-1) Creates a new symbolic tensor where the updates are written to input according to indices.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to. * updates (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input. 
* indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The positions in input to update. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which indices indexes into.
**Returns:**
A new symbolic tensor representing the result of the scatter operation.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
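As a sketch of these axis-based scatter semantics, the behavior mirrors NumPy's `np.put_along_axis` (NumPy appears here purely for illustration; it is not part of the MAX API):

```python
import numpy as np

# Illustrative sketch only: scatter(input, updates, indices, axis=-1)
# writes each updates element into input at the position named by the
# matching indices element along the chosen axis.
x = np.zeros((3, 4), dtype=np.int64)
updates = np.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
indices = np.array([[0, 1, 2], [1, 2, 3], [0, 2, 3]])

out = x.copy()
np.put_along_axis(out, indices, updates, axis=-1)
print(out)
# [[10 20 30  0]
#  [ 0 40 50 60]
#  [70  0 80 90]]
```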
### `scatter_nd()` {#max.graph.ops.scatter_nd} > max.graph.ops.scatter\_nd(input, updates, indices) Creates a new symbolic tensor where the updates are scattered into input at specified indices.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to. * updates (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input. 
* indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A tensor of indices specifying where to write updates. Shape should be \[num\_updates, rank] for full indexing or \[num\_updates, k] for partial indexing where k < rank.
**Returns:**
A new symbolic tensor representing the result of the scatter\_nd operation.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
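A sketch of the full-indexing case (indices of shape \[num\_updates, rank]), written in NumPy for illustration only:

```python
import numpy as np

# Illustration only: each row of indices fully addresses one element
# of input, and the matching updates entry is written there.
x = np.zeros((2, 3))
updates = np.array([1.5, 2.5, 3.5])
indices = np.array([[0, 0], [0, 2], [1, 1]])  # [num_updates, rank]

out = x.copy()
for idx, val in zip(indices, updates):
    out[tuple(idx)] = val
print(out)
# [[1.5 0.  2.5]
#  [0.  3.5 0. ]]
```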
### `shape_to_tensor()` {#max.graph.ops.shape_to_tensor} > max.graph.ops.shape\_to\_tensor(shape) Converts a shape to a tensor. This is useful for using a shape attribute in an op that expects a tensor value.
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – the shape attribute of a tensor value.
**Returns:**
The TensorValue containing the same value as shape.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
**Example:**

```pycon
>>> x = ops.constant(np.zeros((1,)), DType.int64, device=DeviceRef.CPU())
>>> result = ops.stack([
...     x,
...     ops.shape_to_tensor(x.shape),
... ])
TensorValue(dtype=int64, shape=[StaticDim(dim=2), StaticDim(dim=1)])
```

### `shard_and_stack()` {#max.graph.ops.shard_and_stack} > max.graph.ops.shard\_and\_stack(inputs, devices, axis=0) Shards a list of input tensors along a specified axis, producing multiple outputs. This operation takes multiple input tensors, splits each along the specified axis into len(devices) chunks, and returns one output tensor per device. Each output contains the chunks at the corresponding index stacked from all inputs along a new dimension 0. This is useful for distributing model weights across multiple devices in tensor parallel configurations. For example, with 2 inputs A and B, axis=0, and 2 devices:

* Input A shape \[10, D], Input B shape \[10, D]
* Output 0: stack(\[A\[0:5], B\[0:5]]) -> shape \[2, 5, D] on devices\[0]
* Output 1: stack(\[A\[5:10], B\[5:10]]) -> shape \[2, 5, D] on devices\[1]

With axis=1 and 2 devices:

* Input A shape \[D, 10], Input B shape \[D, 10]
* Output 0: stack(\[A\[:, 0:5], B\[:, 0:5]]) -> shape \[2, D, 5] on devices\[0]
* Output 1: stack(\[A\[:, 5:10], B\[:, 5:10]]) -> shape \[2, D, 5] on devices\[1]
**Parameters:**
* inputs ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – A list of symbolic tensors to shard. All tensors must have the same shape, dtype, and device. * devices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)]) – Target devices for each output tensor. The number of devices determines the number of splits. Each output tensor will be placed on the corresponding device. This enables direct host-to-device transfer without intermediate CPU storage. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to split each input tensor. Defaults to 0. Supports negative indexing (e.g., -1 for last axis).
**Returns:**
A list of len(devices) tensors, each with shape \[num\_inputs, D0, …, Daxis//len(devices), …, Dn-1] where the input shape is \[D0, …, Daxis, …, Dn-1]. Output i contains the stacked chunks at position i from all input tensors, placed on devices\[i].
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If inputs list is empty, if devices list is empty, if input tensors don’t have matching shapes, if the dimension size at the axis is not evenly divisible by len(devices), or if axis is out of bounds.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
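The shape bookkeeping in the example above can be checked with a NumPy reference sketch (device placement omitted; `shard_and_stack_ref` is a hypothetical helper for illustration, not part of the MAX API):

```python
import numpy as np

def shard_and_stack_ref(inputs, num_devices, axis=0):
    # Split every input into num_devices equal chunks along `axis`,
    # then stack chunk i from all inputs along a new leading dimension.
    chunks = [np.split(t, num_devices, axis=axis) for t in inputs]
    return [np.stack([c[i] for c in chunks], axis=0)
            for i in range(num_devices)]

A = np.arange(20).reshape(10, 2)     # input A, shape [10, D] with D=2
B = np.arange(20, 40).reshape(10, 2)
outs = shard_and_stack_ref([A, B], num_devices=2, axis=0)
print(outs[0].shape)  # (2, 5, 2) == stack([A[0:5], B[0:5]])
```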
### `sigmoid()` {#max.graph.ops.sigmoid} > max.graph.ops.sigmoid(x) Computes the elementwise sigmoid activation of a symbolic tensor. Creates a new op node to compute the elementwise sigmoid of a symbolic tensor and adds it to the graph, returning the symbolic result. Sigmoid is defined as `sigmoid(x) = 1 / (1 + exp(-x))`, mapping all input values to the range (0, 1). The sigmoid function is commonly used for binary classification tasks and as an activation function in neural networks, particularly in output layers for probability prediction.

```python
import max.functional as F
from max.tensor import Tensor

# Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])

# Apply sigmoid activation
result = F.sigmoid(x)
print(result)
# Output: [[0.119, 0.269, 0.5], [0.731, 0.881, 0.953]]
# All values mapped to range (0, 1)
```
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The symbolic tensor to use as the input to the sigmoid computation.
**Returns:**
A new symbolic tensor value representing the output of the sigmoid computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `silu()` {#max.graph.ops.silu} > max.graph.ops.silu(x) Computes the elementwise silu of a symbolic tensor. Creates a new op node to compute the elementwise silu of a symbolic tensor and adds it to the graph, returning the symbolic result. `silu` is defined as `silu(x) = x * sigmoid(x)`.
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The symbolic tensor to use as the input to the silu computation.
**Returns:**
A new symbolic tensor value representing the output of the silu computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
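The definition `silu(x) = x * sigmoid(x)` can be checked numerically; a NumPy sketch for illustration (not the MAX implementation):

```python
import numpy as np

# silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); illustration only.
def silu_ref(x):
    return x / (1.0 + np.exp(-x))

x = np.array([-2.0, 0.0, 2.0])
print(silu_ref(x))  # approximately [-0.238, 0.0, 1.762]
```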
### `sin()` {#max.graph.ops.sin} > max.graph.ops.sin(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `slice_tensor()` {#max.graph.ops.slice_tensor} > max.graph.ops.slice\_tensor(x, indices) Slices out a subtensor view of the input tensor based on indices. The semantics of [`slice_tensor()`](#max.graph.ops.slice_tensor) follow NumPy slicing semantics with the following restrictions:

* Slice indices must not index out of `[-dim - 1, dim - 1]` for negative step, or `[-dim, dim]` for positive step.

```python
# Reverse a tensor.
slice_tensor(x, [slice(None, None, -1)])

# Unsqueeze the second last dimension of a tensor.
slice_tensor(x, [..., None, slice(None)])
```
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) * indices (SliceIndices)
**Returns:**
The sliced subtensor of x.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
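Since `slice_tensor()` follows NumPy slicing semantics, the two examples above can be checked directly in plain NumPy (illustration only):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)

# Reverse a tensor along its first dimension.
reversed_x = x[(slice(None, None, -1),)]

# Unsqueeze the second-last dimension.
unsqueezed = x[(..., None, slice(None))]

print(reversed_x.shape, unsqueezed.shape)  # (2, 3) (2, 1, 3)
```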
### `softmax()` {#max.graph.ops.softmax} > max.graph.ops.softmax(value, axis=-1)
**Parameters:**
* value (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * axis ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
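`softmax()` carries no prose description above; it computes the standard softmax along `axis`. A numerically stable NumPy sketch of the math (illustration only, not the MAX implementation):

```python
import numpy as np

def softmax_ref(x, axis=-1):
    # Subtract the max for numerical stability; the result is unchanged.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
p = softmax_ref(x, axis=-1)
print(p.sum(axis=-1))  # each row sums to 1.0
```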
### `split()` {#max.graph.ops.split} > max.graph.ops.split(x, split\_sizes, axis=0) Splits the input tensor into multiple tensors along a given dimension.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to split. * split\_sizes ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Sizes of each output tensor. Must add up to the size of the split dimension `axis`. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension along which to split the input tensor. Must have a statically known dimension size.
**Returns:**
A list of tensors with the same length as split\_sizes, where each tensor has the same shape as the input except along the split dimension axis, where the size is given by the corresponding element in split\_sizes.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
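The size constraint can be illustrated with NumPy, whose `np.split` takes offsets rather than sizes (illustration only, not the MAX API):

```python
import numpy as np

x = np.arange(12).reshape(6, 2)
sizes = [1, 2, 3]                    # must sum to x.shape[0] == 6
offsets = np.cumsum(sizes)[:-1]      # sizes -> split offsets [1, 3]
parts = np.split(x, offsets, axis=0)
print([p.shape for p in parts])      # [(1, 2), (2, 2), (3, 2)]
```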
### `sqrt()` {#max.graph.ops.sqrt} > max.graph.ops.sqrt(x) Computes the elementwise square root of a symbolic tensor. Creates a new op node to compute the elementwise square root of a symbolic tensor and adds it to the graph, returning the symbolic result. Square root is commonly used in normalization operations, distance calculations, and implementing mathematical operations like standard deviation.

```python
import max.functional as F
from max.tensor import Tensor

# Create tensor with positive values
x = Tensor.constant([1.0, 4.0, 9.0, 16.0])

# Compute square root
result = F.sqrt(x)
print(result)
# Output: [1.0, 2.0, 3.0, 4.0]

# Note: sqrt requires non-negative values.
# For tensors with negative values, use abs first:
y = Tensor.constant([1.0, -4.0, 9.0, -16.0])
result2 = F.sqrt(F.abs(y))
print(result2)
# Output: [1.0, 2.0, 3.0, 4.0]
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the sqrt computation. If it’s not a floating-point DType, an exception will be raised.
**Returns:**
A new symbolic tensor value representing the output of the sqrt computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `squeeze()` {#max.graph.ops.squeeze} > max.graph.ops.squeeze(x, axis) Removes a size-1 dimension from a symbolic tensor.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to squeeze. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension to remove from the input’s shape. If negative, this indexes from the end of the tensor. For example, `squeeze(v, -1)` squeezes the last dimension.
**Returns:**
A symbolic tensor with the same number of elements as the input tensor, and whose rank is 1 less than the rank of the input tensor.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `stack()` {#max.graph.ops.stack} > max.graph.ops.stack(values, axis=0) Stacks a list of tensors along a new axis.
**Parameters:**
* values ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The new axis along which to stack. If negative, indexes relative to the end of the tensor shape plus 1. For instance, `stack(vs, -1)` will create and stack along a new axis as the last dimension, and `stack(vs, -2)` will create and stack along a new dimension which is inserted immediately before the last dimension.
**Returns:**
A new symbolic tensor representing the result of the stack. It will have rank `n+1` where `n` is the rank of each input tensor. Its size on each dimension other than `axis` will be the same as each input tensors’, with the new axis inserted. Along the new dimension it will have size `len(values)`.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `sub()` {#max.graph.ops.sub} > max.graph.ops.sub(lhs, rhs)
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `sum()` {#max.graph.ops.sum} > max.graph.ops.sum(x, axis=-1) Reduces a symbolic tensor using a sum operation.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the sum operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
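The rank-preserving behavior described above matches NumPy's `keepdims=True`; a quick illustration (NumPy only, not the MAX API):

```python
import numpy as np

x = np.arange(6, dtype=np.float64).reshape(2, 3)
s = x.sum(axis=-1, keepdims=True)  # reduced axis kept with size 1
print(s.shape, s.tolist())         # (2, 1) [[3.0], [12.0]]
```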
### `tanh()` {#max.graph.ops.tanh} > max.graph.ops.tanh(x) Computes the elementwise tanh (hyperbolic tangent) of a symbolic tensor. Creates a new op node to compute the elementwise tanh of a symbolic tensor and adds it to the graph, returning the symbolic result. Tanh is defined as `tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`, mapping all input values to the range (-1, 1). The tanh function is commonly used as an activation function in recurrent neural networks (RNNs) and as a hidden layer activation in feedforward networks. Unlike sigmoid which maps to (0, 1), tanh is zero-centered, which can help with gradient flow during training.

```python
import max.functional as F
from max.tensor import Tensor

# Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])

# Apply tanh activation
result = F.tanh(x)
print(result)
# Output: [[-0.964, -0.762, 0.0], [0.762, 0.964, 0.995]]
# All values mapped to range (-1, 1)
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the tanh computation. If it’s not a floating-point DType, an exception will be raised.
**Returns:**
A new symbolic tensor value representing the output of the tanh computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `tile()` {#max.graph.ops.tile} > max.graph.ops.tile(x, repeats) Returns a new tensor formed by repeating the input tensor repeats\[i] times along dimension i. The i-th dimension of the output shape is the i-th dimension of the input shape multiplied by repeats\[i].
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) * repeats ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]])
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
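The output-shape rule mirrors NumPy's `np.tile` (illustration only, not the MAX API):

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])   # shape (2, 2)
out = np.tile(x, (2, 3))         # repeats = [2, 3]
print(out.shape)                 # (4, 6): each dim multiplied by repeats[i]
```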
### `top_k()` {#max.graph.ops.top_k} > max.graph.ops.top\_k(input, k, axis=-1) Returns tensor with only top K values along given axis.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor from which to select top k. * k ([int](https://docs.python.org/3/library/functions.html#int)) – The number of values to select from input. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis from which to select top k.
**Returns:**
Top K values, Top K indices
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](TensorValue.md#max.graph.TensorValue), [TensorValue](TensorValue.md#max.graph.TensorValue)]
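A NumPy sketch of the semantics, assuming the returned values are sorted in descending order along the axis (an ordering assumption made for illustration; the op's ordering guarantee is not stated above):

```python
import numpy as np

def top_k_ref(x, k, axis=-1):
    # Indices of the k largest values, sorted descending (assumption).
    idx = np.flip(np.argsort(x, axis=axis), axis=axis)
    idx = np.take(idx, np.arange(k), axis=axis)
    return np.take_along_axis(x, idx, axis=axis), idx

vals, idx = top_k_ref(np.array([1.0, 5.0, 3.0, 4.0]), k=2)
print(vals, idx)  # [5. 4.] [1 3]
```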
### `transfer_to()` {#max.graph.ops.transfer_to} > max.graph.ops.transfer\_to(x, device) Device-to-device transfer operation. This op transfers the input tensor from its current device to another. A device represents a computation unit such as a CPU or GPU. This is useful when working with accelerators, for example to move data from one GPU to another, or from a GPU to the CPU.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input tensor to transfer. * device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – The device to transfer to.
**Returns:**
The tensor transferred to the specified device.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
### `transpose()` {#max.graph.ops.transpose} > max.graph.ops.transpose(x, axis\_1, axis\_2) Transposes two axes of a symbolic tensor. For more information, see [`transpose()`](TensorValue.md#max.graph.TensorValue.transpose).
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to transpose. * axis\_1 ([int](https://docs.python.org/3/library/functions.html#int)) – One of the two axes to transpose. If negative, this indexes from the end of the tensor. For example, `transpose(v, -1, -2)` transposes the last two axes. * axis\_2 ([int](https://docs.python.org/3/library/functions.html#int)) – The other axis to transpose. May also be negative to index from the end of the tensor.
**Returns:**
A new symbolic tensor with the two specified axes transposed. It has the same elements and dtype, but the order of the elements is different according to the transposition.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
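The axis semantics mirror NumPy's `swapaxes`; the following eager-mode sketch (NumPy standing in for the symbolic op) illustrates the negative-index behavior described above:

```python
import numpy as np

# Eager illustration of transpose(v, -1, -2): swap the last two axes.
v = np.zeros((2, 3, 4))
swapped = np.swapaxes(v, -1, -2)

print(swapped.shape)  # (2, 4, 3)
```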
### `trunc()` {#max.graph.ops.trunc} > max.graph.ops.trunc(x)
**Parameters:**
x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
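`trunc` carries no description here; assuming it follows the usual elementwise truncation semantics (round toward zero, as in `numpy.trunc`), a NumPy sketch of that behavior:

```python
import numpy as np

# Truncation rounds toward zero, unlike floor: -1.7 -> -1.0, not -2.0.
x = np.array([1.7, -1.7])
print(np.trunc(x))  # [ 1. -1.]
```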
### `unsqueeze()` {#max.graph.ops.unsqueeze} > max.graph.ops.unsqueeze(x, axis) Inserts a size-1 dimension into a symbolic tensor.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to unsqueeze. * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The index at which to insert a new dimension into the input’s shape. Elements at that index or higher are shifted back. If negative, it indexes relative to 1 plus the rank of the tensor. For example, `unsqueeze(v, -1)` adds a new dimension at the end, and `unsqueeze(v, -2)` inserts the dimension immediately before the last dimension.
**Returns:**
A symbolic tensor with the same number of elements as the input tensor, whose rank is 1 larger than the rank of the input tensor. The result’s shape at the `axis` dimension is a static dimension of size 1.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
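The negative-axis rule (indexing relative to 1 plus the rank) matches NumPy's `expand_dims`; a quick eager-mode sketch:

```python
import numpy as np

v = np.zeros((2, 3))

# unsqueeze(v, -1): new trailing dimension.
print(np.expand_dims(v, -1).shape)  # (2, 3, 1)

# unsqueeze(v, -2): new dimension just before the last one.
print(np.expand_dims(v, -2).shape)  # (2, 1, 3)
```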
### `where()` {#max.graph.ops.where} > max.graph.ops.where(condition, x, y) Returns `condition ? x : y` (element-wise), where `condition`, `x` and `y` are input tensors.
**Parameters:**
* condition (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The condition tensor to use for selecting elementwise values. This tensor must have a boolean dtype. * x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – If the condition is true at a position, the value from the same position in this tensor will be selected. 
* y (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – If the condition is false at a position, the value from the same position in this tensor will be selected.
**Returns:**
A new symbolic tensor holding values selected from either `x` or `y`, based on the elements in `condition`.
**Return type:**
[TensorValue](TensorValue.md#max.graph.TensorValue)
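The elementwise selection behaves like `numpy.where`; an eager-mode sketch of the semantics:

```python
import numpy as np

condition = np.array([True, False, True])
x = np.array([1, 2, 3])
y = np.array([10, 20, 30])

# condition ? x : y, evaluated per element.
print(np.where(condition, x, y))  # [ 1 20  3]
```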
### `while_loop()` {#max.graph.ops.while_loop} > max.graph.ops.while\_loop(initial\_values, predicate, body) Execute a loop until the predicate evaluates to false. Both the predicate and body functions must take as arguments the same number and types of values as specified in `initial_values`. The predicate function must return only a boolean scalar tensor of type `DType.bool`. The body function must return a list of values matching the types of `initial_values` (or may return a value directly if there is only one). The following example demonstrates a basic while loop with a single argument:

```python
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x):
        return x < 10

    def body(x):
        return x + 1

    result = ops.while_loop(x, pred, body)
    print(result)
```

The following example shows a while loop with multiple arguments:

```python
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())
    y = ops.constant(5, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x, y):
        return ops.logical_and(x < 10, y < 15)

    def body(x, y):
        return [x + 1, y + 1]

    results = ops.while_loop((x, y), pred, body)
    print(results)
```
**Parameters:**
* initial\_values ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Initial values for loop arguments. Must be non-empty. * predicate ([Callable](#max.graph.ops.Callable)\[\[...], [TensorValue](TensorValue.md#max.graph.TensorValue)]) – Callable that takes loop arguments and returns a boolean scalar tensor of type `DType.bool` determining loop continuation. * body ([Callable](#max.graph.ops.Callable)\[\[...], [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Callable that takes loop arguments and returns updated values matching the types of `initial_values`.
**Returns:**
List of output values from the final loop iteration.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `initial_values` is empty. * [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If any initial value is a `BufferValue`.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](TensorValue.md#max.graph.TensorValue)]
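Setting graph tracing aside, the loop semantics reduce to ordinary iteration. A pure-Python sketch (`while_loop_sketch` is a hypothetical name; the real op builds `predicate` and `body` into subgraphs rather than calling them eagerly):

```python
def while_loop_sketch(initial_values, predicate, body):
    """Eagerly iterate until the predicate is false, mirroring
    the contract on argument and return shapes described above."""
    if not initial_values:
        raise ValueError("initial_values must be non-empty")
    values = list(initial_values)
    while predicate(*values):
        result = body(*values)
        # A bare return value stands for a one-element list.
        values = list(result) if isinstance(result, (list, tuple)) else [result]
    return values

# Mirrors the single-argument example above: count 0 up to 10.
print(while_loop_sketch([0], lambda x: x < 10, lambda x: x + 1))  # [10]
```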
:::note Note Buffer operations are currently not supported. ::: --- ## quantization APIs to quantize graph tensors. This package includes a comprehensive set of tools for working with quantized models in MAX Graph. It defines supported quantization encodings, configuration parameters that control the quantization process, and block parameter specifications for different quantization formats. The module supports various quantization formats including 4-bit, 5-bit, and 6-bit precision with different encoding schemes. It also provides support for GGUF-compatible formats for interoperability with other frameworks. ## `BlockParameters` {#max.graph.quantization.BlockParameters} > class max.graph.quantization.BlockParameters(elements\_per\_block, block\_size) Parameters describing the structure of a quantization block. Block-based quantization stores elements in fixed-size blocks. Each block contains a specific number of elements in a compressed format.
**Parameters:**
* elements\_per\_block ([int](https://docs.python.org/3/library/functions.html#int)) * block\_size ([int](https://docs.python.org/3/library/functions.html#int))
### `block_size` {#max.graph.quantization.BlockParameters.block_size} > block\_size: [int](https://docs.python.org/3/library/functions.html#int) ### `elements_per_block` {#max.graph.quantization.BlockParameters.elements_per_block} > elements\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) ## `QuantizationConfig` {#max.graph.quantization.QuantizationConfig} > class max.graph.quantization.QuantizationConfig(quant\_method, bits, group\_size, desc\_act=False, sym=False) Configuration for specifying quantization parameters that affect inference. These parameters control how tensor values are quantized, including the method, bit precision, grouping, and other characteristics that affect the trade-off between model size, inference speed, and accuracy.
**Parameters:**
* quant\_method ([str](https://docs.python.org/3/library/stdtypes.html#str)) * bits ([int](https://docs.python.org/3/library/functions.html#int)) * group\_size ([int](https://docs.python.org/3/library/functions.html#int)) * desc\_act ([bool](https://docs.python.org/3/library/functions.html#bool)) * sym ([bool](https://docs.python.org/3/library/functions.html#bool))
### `bits` {#max.graph.quantization.QuantizationConfig.bits} > bits: [int](https://docs.python.org/3/library/functions.html#int) ### `desc_act` {#max.graph.quantization.QuantizationConfig.desc_act} > desc\_act: [bool](https://docs.python.org/3/library/functions.html#bool) = False ### `group_size` {#max.graph.quantization.QuantizationConfig.group_size} > group\_size: [int](https://docs.python.org/3/library/functions.html#int) ### `quant_method` {#max.graph.quantization.QuantizationConfig.quant_method} > quant\_method: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `sym` {#max.graph.quantization.QuantizationConfig.sym} > sym: [bool](https://docs.python.org/3/library/functions.html#bool) = False ## `QuantizationEncoding` {#max.graph.quantization.QuantizationEncoding} > class max.graph.quantization.QuantizationEncoding(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Quantization encodings supported by MAX Graph. Quantization reduces the precision of neural network weights to decrease memory usage and potentially improve inference speed. Each encoding represents a different compression method with specific trade-offs between model size, accuracy, and computational efficiency. These encodings are commonly used with pre-quantized model checkpoints (especially GGUF format) or can be applied during weight allocation. The following example shows how to create a quantized weight using the Q4\_K encoding:

```python
from max.dtype import DType
from max.graph import DeviceRef, Weight
from max.graph.quantization import QuantizationEncoding

encoding = QuantizationEncoding.Q4_K
quantized_weight = Weight(
    name="linear.weight",
    dtype=DType.uint8,
    shape=[4096, 4096],
    device=DeviceRef.GPU(0),
    quantization_encoding=encoding,
)
```

MAX supports several quantization formats optimized for different use cases. ### `Q4_0` {#max.graph.quantization.QuantizationEncoding.Q4_0} > Q4\_0 Basic 4-bit quantization with 32 elements per block.
### `Q4_K` {#max.graph.quantization.QuantizationEncoding.Q4_K} > Q4\_K 4-bit K-quantization with 256 elements per block. ### `Q5_K` {#max.graph.quantization.QuantizationEncoding.Q5_K} > Q5\_K 5-bit K-quantization with 256 elements per block. ### `Q6_K` {#max.graph.quantization.QuantizationEncoding.Q6_K} > Q6\_K 6-bit K-quantization with 256 elements per block. ### `GPTQ` {#max.graph.quantization.QuantizationEncoding.GPTQ} > GPTQ Group-wise Post-Training Quantization for large language models. ### `block_parameters` {#max.graph.quantization.QuantizationEncoding.block_parameters} > property block\_parameters: [BlockParameters](#max.graph.quantization.BlockParameters) Gets the block parameters for this quantization encoding.
**Returns:**
The parameters describing how elements are organized and encoded in blocks for this quantization encoding.
**Return type:**
[BlockParameters](#max.graph.quantization.BlockParameters)
### `block_size` {#max.graph.quantization.QuantizationEncoding.block_size} > property block\_size: [int](https://docs.python.org/3/library/functions.html#int) Number of bytes in encoded representation of block. All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of bytes resulting after encoding a single block.
**Returns:**
Size in bytes of each encoded quantization block.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `elements_per_block` {#max.graph.quantization.QuantizationEncoding.elements_per_block} > property elements\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) Number of elements per block. All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of elements gathered into a block.
**Returns:**
Number of original tensor elements in each quantized block.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
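Together, `elements_per_block` and `block_size` determine a tensor's compressed footprint. A back-of-the-envelope sketch, assuming the conventional GGUF Q4\_K layout of 144 bytes per 256-element block (the numbers here are illustrative; query `QuantizationEncoding.Q4_K.block_parameters` for the authoritative values):

```python
# Hypothetical Q4_K block parameters (GGUF convention; verify against
# QuantizationEncoding.Q4_K.block_parameters in your MAX version).
elements_per_block = 256
block_size = 144  # bytes per encoded block

# Compressed size of a [4096, 4096] weight matrix.
num_elements = 4096 * 4096
num_blocks = num_elements // elements_per_block
compressed_bytes = num_blocks * block_size

print(compressed_bytes)                      # 9437184 bytes (9 MiB)
print(block_size * 8 / elements_per_block)   # 4.5 bits per element
```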
### `is_gguf` {#max.graph.quantization.QuantizationEncoding.is_gguf} > property is\_gguf: [bool](https://docs.python.org/3/library/functions.html#bool) Checks if this quantization encoding is compatible with GGUF format. GGUF is a format for storing large language models and compatible quantized weights.
**Returns:**
True if this encoding is compatible with GGUF, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `name` {#max.graph.quantization.QuantizationEncoding.name} > property name: [str](https://docs.python.org/3/library/stdtypes.html#str) Gets the lowercase name of the quantization encoding.
**Returns:**
Lowercase string representation of the quantization encoding.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
--- ## shape Library for graph shape types. ## `Shape` {#max.graph.shape.Shape} > class max.graph.shape.Shape(dims=())
**Parameters:**
dims (ShapeLike)
### `from_mlir()` {#max.graph.shape.Shape.from_mlir} > classmethod from\_mlir(attr)
**Parameters:**
attr (TypedAttr)
**Return type:**
[Shape](#max.graph.shape.Shape)
### `is_static()` {#max.graph.shape.Shape.is_static} > static is\_static(shape)
**Parameters:**
shape ([Shape](#max.graph.shape.Shape))
**Return type:**
[TypeGuard](https://docs.python.org/3/library/typing.html#typing.TypeGuard)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[StaticDim](dim.md#max.graph.dim.StaticDim)]]
### `parameters` {#max.graph.shape.Shape.parameters} > property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](dim.md#max.graph.dim.SymbolicDim)] Lists the symbolic dimension names on which this shape depends. ### `rank` {#max.graph.shape.Shape.rank} > property rank ### `static_dims` {#max.graph.shape.Shape.static_dims} > property static\_dims: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Returns all static dims in the shape as a list of integers. ### `to_mlir()` {#max.graph.shape.Shape.to_mlir} > to\_mlir()
**Return type:**
ShapeAttr
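`is_static` and `static_dims` are simple predicates over the dimension list; a pure-Python sketch, assuming ints stand in for `StaticDim` and strings for `SymbolicDim` (both function names are hypothetical):

```python
def is_static_sketch(dims):
    """True when every dimension is concrete (modeled as an int)."""
    return all(isinstance(d, int) for d in dims)

def static_dims_sketch(dims):
    """Only the concrete dimensions, mirroring Shape.static_dims."""
    return [d for d in dims if isinstance(d, int)]

print(is_static_sketch([2, 3]))            # True
print(is_static_sketch([2, "seq_len"]))    # False
print(static_dims_sketch([2, "seq_len"]))  # [2]
```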
--- ## type Library for graph value types. ## `BufferType` {#max.graph.type.BufferType} > class max.graph.type.BufferType(dtype, shape, device) A symbolic buffer type. This is a reference to a tensor that can be mutated in place.
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType)) * shape ([Shape](shape.md#max.graph.shape.Shape)) * device ([DeviceRef](#max.graph.type.DeviceRef))
### `as_tensor()` {#max.graph.type.BufferType.as_tensor} > as\_tensor() Returns the analogous tensor type.
**Return type:**
[TensorType](#max.graph.type.TensorType)
### `from_mlir()` {#max.graph.type.BufferType.from_mlir} > classmethod from\_mlir(type) Constructs a buffer type from an MLIR type.
**Parameters:**
type – The MLIR Type object to parse into a buffer type.
**Returns:**
The buffer type represented by the MLIR Type value.
**Return type:**
[BufferType](#max.graph.type.BufferType)
### `to_mlir()` {#max.graph.type.BufferType.to_mlir} > to\_mlir() Converts to an `mlir.Type` instance.
**Returns:**
An `mlir.Type` in the specified Context.
**Return type:**
BufferType
## `ConvInputLayout` {#max.graph.type.ConvInputLayout} > class max.graph.type.ConvInputLayout(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `NCHW` {#max.graph.type.ConvInputLayout.NCHW} > NCHW = 'NCHW' ### `NHWC` {#max.graph.type.ConvInputLayout.NHWC} > NHWC = 'NHWC' ### `from_mlir()` {#max.graph.type.ConvInputLayout.from_mlir} > static from\_mlir(attr) Constructs a layout from an attribute.
**Parameters:**
attr (StringAttr) – The MLIR Attribute object to parse into a layout.
**Returns:**
The ConvInputLayout represented by the Attribute value.
**Return type:**
[ConvInputLayout](#max.graph.type.ConvInputLayout)
### `to_mlir()` {#max.graph.type.ConvInputLayout.to_mlir} > to\_mlir() Returns an mlir Attribute representing this Layout. This attribute is used for certain convolution ops.
**Returns:**
An Attribute representing the layout.
**Return type:**
StringAttr
## `DeviceKind` {#max.graph.type.DeviceKind} > class max.graph.type.DeviceKind(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) A device type representation. ### `CPU` {#max.graph.type.DeviceKind.CPU} > CPU = 'cpu' ### `GPU` {#max.graph.type.DeviceKind.GPU} > GPU = 'gpu' ### `from_string()` {#max.graph.type.DeviceKind.from_string} > static from\_string(txt)
**Parameters:**
txt ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
[DeviceKind](#max.graph.type.DeviceKind)
## `DeviceRef` {#max.graph.type.DeviceRef} > class max.graph.type.DeviceRef(device\_type, id=0) A symbolic device representation. DeviceRef type representation consists of a DeviceKind and an id. This is a direct representation of the device attribute in mlir. The following example demonstrates how to create and use device references:

```python
from max.graph import DeviceRef

# Create a GPU device with the default id
gpu_device = DeviceRef.GPU()
print(gpu_device)  # Outputs: gpu:0

# Create a CPU device with specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device)  # Outputs: cpu:1
```
**Parameters:**
* device\_type ([DeviceKind](#max.graph.type.DeviceKind)) * id ([int](https://docs.python.org/3/library/functions.html#int))
### `CPU()` {#max.graph.type.DeviceRef.CPU} > static CPU(id=0) Static method for creating a CPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[DeviceRef](#max.graph.type.DeviceRef)
### `GPU()` {#max.graph.type.DeviceRef.GPU} > static GPU(id=0) Static method for creating a GPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[DeviceRef](#max.graph.type.DeviceRef)
### `device_type` {#max.graph.type.DeviceRef.device_type} > device\_type: [DeviceKind](#max.graph.type.DeviceKind) ### `from_device()` {#max.graph.type.DeviceRef.from_device} > static from\_device(device)
**Parameters:**
device ([Device](../driver.md#max.driver.Device) | [DeviceRef](#max.graph.type.DeviceRef))
**Return type:**
[DeviceRef](#max.graph.type.DeviceRef)
### `from_mlir()` {#max.graph.type.DeviceRef.from_mlir} > static from\_mlir(attr) Returns a device from an mlir attribute.
**Parameters:**
attr (DeviceRefAttr)
**Return type:**
[DeviceRef](#max.graph.type.DeviceRef)
### `id` {#max.graph.type.DeviceRef.id} > id: [int](https://docs.python.org/3/library/functions.html#int) ### `is_cpu()` {#max.graph.type.DeviceRef.is_cpu} > is\_cpu() Returns true if the device is a CPU device.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `is_gpu()` {#max.graph.type.DeviceRef.is_gpu} > is\_gpu() Returns true if the device is a GPU device.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `to_device()` {#max.graph.type.DeviceRef.to_device} > to\_device() Convert device reference to a concrete driver Device.
**Return type:**
[Device](../driver.md#max.driver.Device)
### `to_mlir()` {#max.graph.type.DeviceRef.to_mlir} > to\_mlir() Returns an mlir attribute representing the device.
**Return type:**
DeviceRefAttr
## `FilterLayout` {#max.graph.type.FilterLayout} > class max.graph.type.FilterLayout(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `CFRS` {#max.graph.type.FilterLayout.CFRS} > CFRS = 'CFRS' ### `FCQRS` {#max.graph.type.FilterLayout.FCQRS} > FCQRS = 'FCQRS' ### `FCRS` {#max.graph.type.FilterLayout.FCRS} > FCRS = 'FCRS' ### `QRSCF` {#max.graph.type.FilterLayout.QRSCF} > QRSCF = 'QRSCF' ### `RSCF` {#max.graph.type.FilterLayout.RSCF} > RSCF = 'RSCF' ### `from_mlir()` {#max.graph.type.FilterLayout.from_mlir} > static from\_mlir(attr) Constructs a layout from an attribute.
**Parameters:**
attr (LayoutAttr) – The MLIR Attribute object to parse into a layout.
**Returns:**
The FilterLayout represented by the Attribute value.
**Return type:**
[FilterLayout](#max.graph.type.FilterLayout)
### `to_mlir()` {#max.graph.type.FilterLayout.to_mlir} > to\_mlir() Returns an mlir Attribute representing this Layout. This attribute is used in tensor type metadata for certain ops.
**Returns:**
An Attribute representing the layout.
**Return type:**
LayoutAttr
## `TensorType` {#max.graph.type.TensorType} > class max.graph.type.TensorType(dtype, shape, device, \_layout=None) A symbolic [`TensorType`](#max.graph.type.TensorType). This is not an eager tensor type! This contains no actual data, but instead represents the type of a value at some point in time during model execution. Most internal values in a model will be tensors. This type represents their element type (`dtype`) and dimensions (`dims`) at a specific point during model computation. It allows us to do some optimistic optimizations and shape inference during graph construction, and to provide more detailed shape information to the compiler for further optimization passes. The following example shows how to create a tensor type with static dimensions and access its properties:

```python
from max.graph import TensorType
from max.dtype import DType

# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3))

print(tensor_type.dtype)  # Outputs: DType.float32
print(tensor_type.shape)  # Outputs: [2, 3]
```

It can also represent a fully dynamic rank tensor. The presence of dynamic rank tensors in a graph will often degrade performance dramatically and prevents many classes of optimizations. An optional device (`device`) can also be provided to indicate the explicit device the tensor is associated with.
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType)) * shape ([Shape](shape.md#max.graph.shape.Shape)) * device ([DeviceRef](#max.graph.type.DeviceRef)) * \_layout ([FilterLayout](#max.graph.type.FilterLayout) | None)
### `as_buffer()` {#max.graph.type.TensorType.as_buffer} > as\_buffer() Returns the analogous buffer type.
**Return type:**
[BufferType](#max.graph.type.BufferType)
### `from_mlir()` {#max.graph.type.TensorType.from_mlir} > classmethod from\_mlir(type) Constructs a tensor type from an MLIR type.
**Parameters:**
type – The MLIR Type object to parse into a tensor type.
**Returns:**
The tensor type represented by the MLIR Type value.
**Return type:**
[TensorType](#max.graph.type.TensorType)
### `to_mlir()` {#max.graph.type.TensorType.to_mlir} > to\_mlir() Converts to an `mlir.Type` instance.
**Returns:**
An `mlir.Type` in the specified Context.
**Return type:**
TensorType
## `Type` {#max.graph.type.Type} > class max.graph.type.Type Represents any possible type for Graph values. Every Value in the Graph has a Type, and that type is represented by a `Type`. This type may be inspected to get finer-grained types and learn more about an individual Value. The following example shows how to work with types in a graph:

```python
from max.graph import Graph, TensorType
from max.dtype import DType

with Graph() as g:
    # Create a tensor constant with a specific type
    tensor_type = TensorType(DType.float32, [2, 3])

    # The type can be inspected to get information about the value
    print(f"Tensor element type: {tensor_type.dtype}")  # Outputs: DType.float32
    print(f"Tensor shape: {tensor_type.shape}")  # Outputs: [2, 3]
```

### `from_mlir()` {#max.graph.type.Type.from_mlir} > static from\_mlir(t) Constructs a type from an MLIR type.
**Parameters:**
t (MlirType) – The MLIR Type object to parse into a type.
**Returns:**
The type represented by the MLIR Type value.
**Return type:**
[Type](#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]
### `to_mlir()` {#max.graph.type.Type.to_mlir} > to\_mlir() Converts to an `mlir.Type` instance.
**Returns:**
An `mlir.Type` in the specified Context.
**Return type:**
MlirType
--- ## weights Weights are the learned parameters that store a neural network’s knowledge. They’re multi-dimensional arrays (tensors) of numerical values that determine how the model transforms inputs into outputs. These weights contain all the information needed for a model to perform its task, whether that’s text generation, image classification, or any other capability. ## `GGUFWeights` {#max.graph.weights.GGUFWeights} > class max.graph.weights.GGUFWeights(source, tensors=None, prefix='', allocated=None) Implementation for loading weights from GGUF (GPT-Generated Unified Format) files. `GGUFWeights` provides an interface to load model weights from GGUF files, which are optimized for quantized large language models. GGUF is the successor to GGML format and is commonly used in the `llama.cpp` ecosystem for efficient storage and loading of quantized models.

```python
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import GGUFWeights
from max.graph.quantization import QuantizationEncoding

gguf_path = Path("model-q4_k.gguf")
weights = GGUFWeights(gguf_path)

# Check if a weight exists
if weights.model.layers[0].attention.wq.exists():
    # Allocate quantized attention weight
    wq_weight = weights.model.layers[0].attention.wq.allocate(
        dtype=DType.uint8,  # GGUF quantized weights use uint8
        device=DeviceRef.CPU(),
    )

# Access weight data with quantization info
weight_data = weights.model.layers[0].attention.wq.data()
print(f"Quantization: {weight_data.quantization_encoding}")
print(f"Shape: {weight_data.shape}")

# Allocate with quantization validation
ffn_weight = weights.model.layers[0].feed_forward.w1.allocate(
    quantization_encoding=QuantizationEncoding.Q4_K,
    device=DeviceRef.GPU(0),
)

# Iterate through all weights in a layer
for name, weight in weights.model.layers[0].items():
    if weight.exists():
        print(f"Found weight: {name}")
```
**Parameters:**
* source (PathLike\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | gguf.GGUFReader) * tensors ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), gguf.ReaderTensor] | None) * prefix ([str](https://docs.python.org/3/library/stdtypes.html#str)) * allocated ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)] | None)
### `allocate()` {#max.graph.weights.GGUFWeights.allocate} > allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0) Creates and optionally validates a new Weight.
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType) | None) * shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) * quantization\_encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | None) * device ([DeviceRef](type.md#max.graph.type.DeviceRef))
**Return type:**
[Weight](Weight.md#max.graph.Weight)
### `allocated_weights` {#max.graph.weights.GGUFWeights.allocated_weights} > property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)] Gets the values of all weights that were allocated previously. ### `data()` {#max.graph.weights.GGUFWeights.data} > data() Get weight data with metadata.

```python
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")

# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
```
**Returns:**
A WeightData object containing the tensor data along with metadata like name, dtype, shape, and quantization encoding.
**Raises:**
[KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If no weight exists at the current hierarchical name.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `exists()` {#max.graph.weights.GGUFWeights.exists} > exists() Check if a weight with this exact name exists.

```python
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
```
**Returns:**
True if a weight with the current hierarchical name exists in the loaded weights, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `items()` {#max.graph.weights.GGUFWeights.items} > items() Iterate through all allocable weights that start with the prefix. ### `name` {#max.graph.weights.GGUFWeights.name} > property name: [str](https://docs.python.org/3/library/stdtypes.html#str) The current weight name or prefix. ## `SafetensorWeights` {#max.graph.weights.SafetensorWeights} > class max.graph.weights.SafetensorWeights(filepaths, \*, tensors=None, tensors\_to\_file\_idx=None, prefix='', allocated=None, \_st\_weight\_map=None, \_st\_file\_handles=None) Implementation for loading weights from safetensors files. SafetensorWeights provides a secure and efficient way to load model weights from safetensors format files. Safetensors is designed by Hugging Face for safe serialization that prevents arbitrary code execution and supports memory-mapped loading for fast access.

```python
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import SafetensorWeights

# Load weights from safetensors files
weight_files = [Path("model.safetensors")]
weights = SafetensorWeights(weight_files)

# Check if a weight exists
if weights.model.embeddings.weight.exists():
    # Allocate the embedding weight
    embedding_weight = weights.model.embeddings.weight.allocate(
        dtype=DType.float32,
        device=DeviceRef.CPU(),
    )

# Access weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float16
)
```
**Parameters:**
* filepaths (Sequence\[PathLike\[[str](https://docs.python.org/3/library/stdtypes.html#str)]])
* tensors (Set\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | None)
* tensors\_to\_file\_idx (Mapping\[[str](https://docs.python.org/3/library/stdtypes.html#str), [int](https://docs.python.org/3/library/functions.html#int)] | None)
* prefix ([str](https://docs.python.org/3/library/stdtypes.html#str))
* allocated ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)] | None)
* \_st\_weight\_map ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Buffer](../driver.md#max.driver.Buffer)])
* \_st\_file\_handles ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[PathLike](https://docs.python.org/3/library/os.html#os.PathLike)\[[str](https://docs.python.org/3/library/stdtypes.html#str)], SafeTensor])
### `allocate()` {#max.graph.weights.SafetensorWeights.allocate}

> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)

Creates a Weight that can be added to a graph.
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType) | None)
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None)
* quantization\_encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | None)
* device ([DeviceRef](type.md#max.graph.type.DeviceRef))
**Return type:**
[Weight](Weight.md#max.graph.Weight)
### `allocate_as_bytes()` {#max.graph.weights.SafetensorWeights.allocate_as_bytes}

> allocate\_as\_bytes(dtype=None)

Creates a Weight that can be added to the graph. The weight has a uint8 representation instead of the original data type, so the last dimension of the shape is scaled by the number of bytes needed to represent the original data type. For example, \[512, 256] float32 weights become \[512, 1024] uint8 weights. Scalar weights are interpreted as weights with shape \[1].
**Parameters:**
dtype ([DType](../dtype.md#max.dtype.DType) | None)
**Return type:**
[Weight](Weight.md#max.graph.Weight)
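The shape arithmetic above can be sanity-checked with a plain NumPy view (an illustrative analogy only; MAX performs the equivalent reinterpretation internally):

```python
import numpy as np

# Viewing float32 data as raw bytes leaves the leading dimensions alone
# and multiplies the last dimension by the element size (4 for float32).
w = np.zeros((512, 256), dtype=np.float32)
as_bytes = w.view(np.uint8)

print(as_bytes.shape)  # (512, 1024)
```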
### `allocated_weights` {#max.graph.weights.SafetensorWeights.allocated_weights}

> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]

Gets the values of all weights that were allocated previously.

### `data()` {#max.graph.weights.SafetensorWeights.data}

> data()

Get weight data with metadata.

```python
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")

# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
```
**Returns:**
A WeightData object containing the tensor data along with metadata like name, dtype, shape, and quantization encoding.
**Raises:**
[KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If no weight exists at the current hierarchical name.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `exists()` {#max.graph.weights.SafetensorWeights.exists}

> exists()

Check if a weight with this exact name exists.

```python
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
```
**Returns:**
True if a weight with the current hierarchical name exists in the loaded weights, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `items()` {#max.graph.weights.SafetensorWeights.items}

> items()

Iterate through all allocable weights that start with the prefix.

### `name` {#max.graph.weights.SafetensorWeights.name}

> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)

The current weight name or prefix.

## `WeightData` {#max.graph.weights.WeightData}

> class max.graph.weights.WeightData(data, name, dtype, shape, quantization\_encoding=None)

Container for weight tensor data with metadata. `WeightData` encapsulates a weight tensor along with its metadata, providing utilities for type conversion and format compatibility. It supports the DLPack protocol for efficient tensor sharing between frameworks.
**Parameters:**
* data ([DLPackArray](../driver.md#max.driver.DLPackArray))
* name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* dtype ([DType](../dtype.md#max.dtype.DType))
* shape ([Shape](shape.md#max.graph.shape.Shape))
* quantization\_encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | None)
### `astype()` {#max.graph.weights.WeightData.astype}

> astype(dtype)

Convert the weight data to a different dtype. This method performs actual data conversion, unlike `view()`, which reinterprets the underlying bytes. Special handling is provided for bfloat16 conversions using PyTorch when available.

```python
# Convert float32 weights to float16 for reduced memory
weight_data = weights.model.layer.weight.data()
fp16_data = weight_data.astype(DType.float16)
```
**Parameters:**
dtype ([DType](../dtype.md#max.dtype.DType)) – Target data type for conversion.
**Returns:**
A new WeightData instance with the converted data.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
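The convert-versus-reinterpret distinction mirrors the one in NumPy, which makes for a convenient self-contained illustration (NumPy stand-ins here, not the MAX `WeightData` API):

```python
import numpy as np

x = np.array([1.0, 2.0], dtype=np.float32)

# astype performs a real conversion: the values survive the dtype change.
converted = x.astype(np.float16)

# view reinterprets the same bytes: each 4-byte float32 is re-read as
# two 2-byte float16 bit patterns, so shape and values both change.
reinterpreted = x.view(np.float16)

print(converted)            # [1. 2.]
print(reinterpreted.shape)  # (4,)
```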
### `data` {#max.graph.weights.WeightData.data}

> data: [DLPackArray](../driver.md#max.driver.DLPackArray)

The weight tensor as a DLPack array.

### `dtype` {#max.graph.weights.WeightData.dtype}

> dtype: [DType](../dtype.md#max.dtype.DType)

Data type of the tensor (for example, `DType.float32`, `DType.uint8`).

### `from_numpy()` {#max.graph.weights.WeightData.from_numpy}

> classmethod from\_numpy(arr, name)

Create WeightData from a numpy array.
**Parameters:**
* arr ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Numpy array containing the weight data.
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Name to assign to this weight.
**Returns:**
A new WeightData instance with dtype and shape inferred from the numpy array.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `name` {#max.graph.weights.WeightData.name}

> name: [str](https://docs.python.org/3/library/stdtypes.html#str)

Hierarchical name of the weight (for example, `"model.layers.0.weight"`).

### `quantization_encoding` {#max.graph.weights.WeightData.quantization_encoding}

> quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) = None

Optional quantization scheme applied to the weight.

### `shape` {#max.graph.weights.WeightData.shape}

> shape: [Shape](shape.md#max.graph.shape.Shape)

Shape of the tensor as a Shape object.

## `Weights` {#max.graph.weights.Weights}

> class max.graph.weights.Weights(\*args, \*\*kwargs)

Protocol for managing and accessing model weights hierarchically. The Weights protocol provides a convenient interface for loading and organizing neural network weights. It supports hierarchical naming through attribute and index access, making it easy to work with complex model architectures. Weights in MAX are tensors backed by external memory (buffers or memory-mapped files) that remain separate from the compiled graph.

```python
from max.graph import Graph
from max.dtype import DType

# Create a graph and get its weights interface
graph = Graph("my_model")
weights = graph.weights()

# Allocate weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float32,
    shape=(768, 768)
)
# Creates weight named "transformer.layers.0.attention.weight"

# Check if a weight exists before allocating
if weights.transformer.layers[0].mlp.weight.exists():
    mlp_weight = weights.transformer.layers[0].mlp.weight.allocate(
        dtype=DType.float16,
        shape=(768, 3072)
    )
```

### `allocate()` {#max.graph.weights.Weights.allocate}

> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)

Create a Weight object for this tensor.
```python
# Allocate a weight with specific configuration
weight = weights.model.layers[0].weight.allocate(
    dtype=DType.float16,      # Convert to half precision
    shape=(768, 768),
    device=DeviceRef.GPU(0)   # Place on first GPU
)

# Add to graph
with graph:
    weight_tensor = graph.add_weight(weight)
```
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType) | None) – Data type for the weight. If `None`, uses the original dtype.
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) – Shape of the weight tensor. If `None`, uses the original shape.
* quantization\_encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | None) – Quantization scheme to apply (for example, `Q4_K`, `Q8_0`).
* device ([DeviceRef](type.md#max.graph.type.DeviceRef)) – Target device for the weight (CPU or GPU).
**Returns:**
A Weight object that can be added to a graph using `graph.add_weight()`.
**Return type:**
[Weight](Weight.md#max.graph.Weight)
### `allocated_weights` {#max.graph.weights.Weights.allocated_weights}

> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]

Get all previously allocated weights. This only includes weights that were explicitly allocated using the [`allocate()`](#max.graph.weights.Weights.allocate) method, not all available weights.
**Returns:**
A dictionary mapping weight names to their numpy arrays for all weights that have been allocated through this interface.
### `data()` {#max.graph.weights.Weights.data}

> data()

Get weight data with metadata.

```python
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")

# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
```
**Returns:**
A WeightData object containing the tensor data along with metadata like name, dtype, shape, and quantization encoding.
**Raises:**
[KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If no weight exists at the current hierarchical name.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `exists()` {#max.graph.weights.Weights.exists}

> exists()

Check if a weight with this exact name exists.

```python
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
```
**Returns:**
True if a weight with the current hierarchical name exists in the loaded weights, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `items()` {#max.graph.weights.Weights.items}

> items()

Iterate through all weights that start with the current prefix.

```python
# Iterate through all weights in a specific layer
for name, weight in weights.transformer.layers[0].items():
    print(f"Found weight: {name}")
```
**Yields:**
Tuples of (name, weight\_accessor) for each weight under the current prefix. The name is relative to the current prefix.
**Parameters:**
self (\_Self)
**Return type:**
[Iterator](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterator)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), \_Self]]
### `name` {#max.graph.weights.Weights.name}

> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)

Get the current weight name or prefix.
**Returns:**
The hierarchical name built from attribute and index access. For example, if accessed as `weights.model.layers[0]`, returns “model.layers.0”.
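The way attribute and index access build that hierarchical name can be sketched with a toy accessor (a plain-Python illustration of the naming convention, not the MAX implementation):

```python
class PrefixNode:
    """Toy accessor that builds dotted names via attribute/index access."""

    def __init__(self, prefix=""):
        self._prefix = prefix

    def __getattr__(self, attr):
        # Append ".attr" to the current prefix (or start with "attr").
        name = f"{self._prefix}.{attr}" if self._prefix else attr
        return PrefixNode(name)

    def __getitem__(self, idx):
        # Index access becomes a dotted numeric segment.
        return PrefixNode(f"{self._prefix}.{idx}")

    @property
    def name(self):
        return self._prefix

weights = PrefixNode()
print(weights.model.layers[0].name)  # model.layers.0
```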
## `WeightsFormat` {#max.graph.weights.WeightsFormat}

> class max.graph.weights.WeightsFormat(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

Enumeration of supported weight file formats. MAX supports multiple weight formats to accommodate different model sources and use cases.

### `gguf` {#max.graph.weights.WeightsFormat.gguf}

> gguf = 'gguf'

GGUF (GPT-Generated Unified Format) for quantized models.

File extension: `.gguf`

Optimized for quantized large language models, particularly those from the llama.cpp ecosystem. Supports multiple quantization schemes (`Q4_K`, `Q5_K`, `Q8_0`, etc.) and includes model metadata in the file.

### `safetensors` {#max.graph.weights.WeightsFormat.safetensors}

> safetensors = 'safetensors'

Safetensors format for secure and efficient tensor storage.

File extension: `.safetensors`

Designed by Hugging Face for safe serialization that prevents arbitrary code execution. Uses memory-mapped files for fast loading and supports sharding across multiple files.

## `load_weights()` {#max.graph.weights.load_weights}

> max.graph.weights.load\_weights(paths)

Loads neural network weights from checkpoint files. Automatically detects checkpoint formats based on file extensions and returns the appropriate Weights implementation, creating a seamless interface for loading weights from different formats.

Supported formats:

* Safetensors: `.safetensors`
* PyTorch: `.bin`, `.pt`, `.pth`
* GGUF: `.gguf`

The following example shows how to load weights from a Safetensors file:

```python
from pathlib import Path
from max.graph.weights import load_weights

# Load multi-file checkpoints
sharded_paths = [
    Path("model-00001-of-00003.safetensors"),
    Path("model-00002-of-00003.safetensors"),
    Path("model-00003-of-00003.safetensors")
]

weights = load_weights(sharded_paths)

layer_weight = weights.model.layers[23].mlp.gate_proj.weight.allocate(
    dtype=DType.float32,
    shape=[4096, 14336],
    device=DeviceRef.GPU(0)
)
```
**Parameters:**
paths ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]) – List of pathlib.Path objects pointing to checkpoint files. For multi-file checkpoints (e.g., sharded Safetensors), provide all file paths in the list. For single-file checkpoints, provide a list with one path.
**Return type:**
[Weights](#max.graph.weights.Weights)
## `weights_format()` {#max.graph.weights.weights_format}

> max.graph.weights.weights\_format(weight\_paths)

Detect the format of weight files based on their extensions. This function examines the file extensions of all provided paths to determine the weight format. All files must have the same format; mixed formats are not supported.

```python
from pathlib import Path

# Detect format for safetensor files
paths = [Path("model-00001.safetensors"), Path("model-00002.safetensors")]
format = weights_format(paths)
print(format)  # WeightsFormat.safetensors
```
**Parameters:**
weight\_paths ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]) – List of file paths containing model weights. All files must have the same extension/format.
**Returns:**
The detected WeightsFormat enum value.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If weight\_paths is empty, contains mixed formats, or has unsupported file extensions.
**Return type:**
[WeightsFormat](#max.graph.weights.WeightsFormat)
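The documented behavior reduces to mapping file suffixes onto formats and requiring that they all agree. A minimal sketch under those assumptions (the helper name and string return values are hypothetical; the real function returns `WeightsFormat` enum members):

```python
from pathlib import Path

# Assumed suffix-to-format mapping, mirroring the documented formats.
_SUFFIX_TO_FORMAT = {
    ".safetensors": "safetensors",
    ".gguf": "gguf",
    ".bin": "pytorch", ".pt": "pytorch", ".pth": "pytorch",
}

def detect_weights_format(paths):
    # Empty input, unknown extensions, and mixed formats are all errors,
    # matching the ValueError conditions documented above.
    if not paths:
        raise ValueError("no weight paths provided")
    formats = {_SUFFIX_TO_FORMAT.get(Path(p).suffix) for p in paths}
    if None in formats:
        raise ValueError("unsupported file extension")
    if len(formats) > 1:
        raise ValueError("mixed weight formats are not supported")
    return formats.pop()

print(detect_weights_format(["model-00001.safetensors", "model-00002.safetensors"]))
# safetensors
```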
---

## max

The MAX Python API reference.

The MAX API provides a high-performance graph compiler and runtime library that executes AI models with incredible speed on a wide range of hardware. MAX offers a layered architecture that lets you work at the level of abstraction that best fits your needs. From deploying production-ready models with a few lines of code to building custom neural networks from scratch, each layer builds upon the others so you can move between levels seamlessly as requirements evolve.

For an introduction, see the [Model developer guide](/max/develop/).

## Packages and modules

* [`diagnostics.gpu`](/max/api/python/diagnostics/gpu): GPU monitoring and performance diagnostics utilities.
* [`driver`](/max/api/python/driver): Low-level device management and tensor operations.
* [`dtype`](/max/api/python/dtype): Unified data type system supporting various numeric formats.
* [`engine`](/max/api/python/engine): Model execution runtime with automatic optimization.
* [`entrypoints`](/max/api/python/entrypoints): Command-line tools and serving infrastructure.
* [`functional`](/max/api/python/functional): Functional tensor operations (relu, softmax, etc.).
* [`graph`](/max/api/python/graph): Computational graph construction with 100+ operations for complete model control.
* [`interfaces`](/max/api/python/interfaces): Universal interfaces for consistent API integration.
* [`kv_cache`](/max/api/python/kv_cache): KV cache management for efficient attention computation.
* [`nn`](/max/api/python/nn): High-level neural network building blocks with automatic graph compilation.
* [`pipelines`](/max/api/python/pipelines): Pre-built, optimized model architectures for immediate deployment.
* [`profiler`](/max/api/python/profiler): Performance profiling and tracing utilities.
* [`random`](/max/api/python/random): Random tensor generation utilities.
* [`tensor`](/max/api/python/tensor): Tensor class with eager execution.
* [`torch`](/max/api/python/torch): PyTorch integration for custom operations and interoperability.

---

## interfaces

Universal interfaces between all aspects of the MAX Inference Stack.

## `AudioGenerationInputs` {#max.interfaces.AudioGenerationInputs}

> class max.interfaces.AudioGenerationInputs(batch)

Input data structure for audio generation pipelines. This class represents the input data required for audio generation operations within the pipeline framework. It extends PipelineInputs and provides type-safe generic support for different audio generation context types.
**Parameters:**
batch ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), AudioGenerationContextType])
### `batch` {#max.interfaces.AudioGenerationInputs.batch}

> batch: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), AudioGenerationContextType]

A dictionary mapping RequestID to AudioGenerationContextType instances. This batch structure allows for processing multiple audio generation requests simultaneously while maintaining request-specific context and configuration data.

## `AudioGenerationMetadata` {#max.interfaces.AudioGenerationMetadata}

> class max.interfaces.AudioGenerationMetadata(\*, sample\_rate=None, duration=None, chunk\_id=None, timestamp=None, final\_chunk=None, model\_name=None, request\_id=None, tokens\_generated=None, processing\_time=None, echo=None)

Represents metadata associated with audio generation. This class will eventually replace the metadata dictionary used throughout the AudioGenerationOutput object, providing a structured and type-safe alternative for audio generation metadata.
**Parameters:**
* sample\_rate ([int](https://docs.python.org/3/library/functions.html#int) | None) – The sample rate of the generated audio in Hz.
* duration ([float](https://docs.python.org/3/library/functions.html#float) | None) – The duration of the generated audio in seconds.
* chunk\_id ([int](https://docs.python.org/3/library/functions.html#int) | None) – Identifier for the audio chunk (useful for streaming).
* timestamp ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Timestamp when the audio was generated.
* final\_chunk ([bool](https://docs.python.org/3/library/functions.html#bool) | None) – Whether this is the final chunk in a streaming sequence.
* model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Name of the model used for generation.
* request\_id ([RequestID](#max.interfaces.RequestID) | None) – Unique identifier for the generation request.
* tokens\_generated ([int](https://docs.python.org/3/library/functions.html#int) | None) – Number of tokens generated for this audio.
* processing\_time ([float](https://docs.python.org/3/library/functions.html#float) | None) – Time taken to process this audio chunk in seconds.
* echo ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Echo of the input prompt or identifier for verification.
### `chunk_id` {#max.interfaces.AudioGenerationMetadata.chunk_id}

> chunk\_id: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)

### `duration` {#max.interfaces.AudioGenerationMetadata.duration}

> duration: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None)

### `echo` {#max.interfaces.AudioGenerationMetadata.echo}

> echo: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)

### `final_chunk` {#max.interfaces.AudioGenerationMetadata.final_chunk}

> final\_chunk: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None)

### `model_name` {#max.interfaces.AudioGenerationMetadata.model_name}

> model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)

### `processing_time` {#max.interfaces.AudioGenerationMetadata.processing_time}

> processing\_time: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None)

### `request_id` {#max.interfaces.AudioGenerationMetadata.request_id}

> request\_id: [RequestID](#max.interfaces.RequestID) | [None](https://docs.python.org/3/library/constants.html#None)

### `sample_rate` {#max.interfaces.AudioGenerationMetadata.sample_rate}

> sample\_rate: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)

### `timestamp` {#max.interfaces.AudioGenerationMetadata.timestamp}

> timestamp: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)

### `to_dict()` {#max.interfaces.AudioGenerationMetadata.to_dict}

> to\_dict()

Convert the metadata to a dictionary format.
**Returns:**
Dictionary representation of the metadata.
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]
### `tokens_generated` {#max.interfaces.AudioGenerationMetadata.tokens_generated}

> tokens\_generated: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)

## `AudioGenerationOutput` {#max.interfaces.AudioGenerationOutput}

> class max.interfaces.AudioGenerationOutput(final\_status, steps\_executed, audio\_data=\<factory>, buffer\_speech\_tokens=None, metadata=\<factory>)

Represents a response from the audio generation API. This class encapsulates the result of an audio generation request, including the final status, generated audio data, and optional buffered speech tokens.
**Parameters:**
* final\_status ([GenerationStatus](#max.interfaces.GenerationStatus))
* steps\_executed ([int](https://docs.python.org/3/library/functions.html#int))
* audio\_data ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]])
* buffer\_speech\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None)
* metadata ([AudioGenerationMetadata](#max.interfaces.AudioGenerationMetadata))
### `audio_data` {#max.interfaces.AudioGenerationOutput.audio_data}

> audio\_data: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]

The generated audio data, if available.

### `buffer_speech_tokens` {#max.interfaces.AudioGenerationOutput.buffer_speech_tokens}

> buffer\_speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None)

Buffered speech tokens, if available.

### `final_status` {#max.interfaces.AudioGenerationOutput.final_status}

> final\_status: [GenerationStatus](#max.interfaces.GenerationStatus)

The final status of the generation process.

### `is_done` {#max.interfaces.AudioGenerationOutput.is_done}

> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)

Indicates whether the audio generation process is complete.
**Returns:**
`True` if generation is done, `False` otherwise.
**Return type:**
[`bool`](https://docs.python.org/3/library/functions.html#bool)
### `metadata` {#max.interfaces.AudioGenerationOutput.metadata}

> metadata: [AudioGenerationMetadata](#max.interfaces.AudioGenerationMetadata)

Metadata associated with the audio generation, such as chunk information, prompt details, or other relevant context.

### `steps_executed` {#max.interfaces.AudioGenerationOutput.steps_executed}

> steps\_executed: [int](https://docs.python.org/3/library/functions.html#int)

The number of steps previously executed.

## `AudioGenerationRequest` {#max.interfaces.AudioGenerationRequest}

> class max.interfaces.AudioGenerationRequest(request\_id: 'RequestID', model: 'str', input: 'str | None' = None, audio\_prompt\_tokens: 'list\[int]' = \<factory>, audio\_prompt\_transcription: 'str' = '', sampling\_params: 'SamplingParams' = \<factory>, \_assistant\_message\_override: 'str | None' = None, prompt: 'list\[int] | str | None' = None, streaming: 'bool' = True, buffer\_speech\_tokens: 'npt.NDArray\[np.integer\[Any]] | None' = None)
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID))
* model ([str](https://docs.python.org/3/library/stdtypes.html#str))
* input ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
* audio\_prompt\_tokens ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)])
* audio\_prompt\_transcription ([str](https://docs.python.org/3/library/stdtypes.html#str))
* sampling\_params ([SamplingParams](#max.interfaces.SamplingParams))
* \_assistant\_message\_override ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
* prompt ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [str](https://docs.python.org/3/library/stdtypes.html#str) | None)
* streaming ([bool](https://docs.python.org/3/library/functions.html#bool))
* buffer\_speech\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None)
### `audio_prompt_tokens` {#max.interfaces.AudioGenerationRequest.audio_prompt_tokens}

> audio\_prompt\_tokens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]

The prompt speech IDs to use for audio generation.

### `audio_prompt_transcription` {#max.interfaces.AudioGenerationRequest.audio_prompt_transcription}

> audio\_prompt\_transcription: [str](https://docs.python.org/3/library/stdtypes.html#str) = ''

The audio prompt transcription to use for audio generation.

### `buffer_speech_tokens` {#max.interfaces.AudioGenerationRequest.buffer_speech_tokens}

> buffer\_speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None) = None

An optional field potentially containing the last N speech tokens generated by the model from a previous request. When this field is specified, this tensor is used to buffer the tokens sent to the audio decoder.

### `input` {#max.interfaces.AudioGenerationRequest.input}

> input: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None

The text to generate audio for. The maximum length is 4096 characters.

### `model` {#max.interfaces.AudioGenerationRequest.model}

> model: [str](https://docs.python.org/3/library/stdtypes.html#str)

The name of the model to be used for generating audio chunks. This should match the available models on the server and determines the behavior and capabilities of the response generation.
### `prompt` {#max.interfaces.AudioGenerationRequest.prompt}

> prompt: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None

Optionally provide a preprocessed list of token IDs or a prompt string to pass directly into the model as input. This replaces automatically generating TokenGeneratorRequestMessages from the input, audio prompt tokens, and audio prompt transcription fields.

### `sampling_params` {#max.interfaces.AudioGenerationRequest.sampling_params}

> sampling\_params: [SamplingParams](#max.interfaces.SamplingParams)

Request sampling configuration options.

### `streaming` {#max.interfaces.AudioGenerationRequest.streaming}

> streaming: [bool](https://docs.python.org/3/library/functions.html#bool) = True

Whether to stream the audio generation.

## `BaseContext` {#max.interfaces.BaseContext}

> class max.interfaces.BaseContext(\*args, \*\*kwargs)

Core interface for request lifecycle management across all of MAX, including serving, scheduling, and pipelines. This protocol provides a unified, minimal contract for request state and status handling throughout the MAX stack. Each pipeline variant (e.g., text generation, embeddings, image generation) is expected to extend this interface by creating its own modality-specific context class that implements this protocol and adds functionality relevant to its particular use case. The minimal interface ensures that all context types can be handled uniformly by the scheduling and serving infrastructure, while allowing pipeline-specific implementations to add their own state management, input validation, and result handling.

### `is_done` {#max.interfaces.BaseContext.is_done}

> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)

Whether the request has completed generation.
### `request_id` {#max.interfaces.BaseContext.request_id} > property request\_id: [RequestID](#max.interfaces.RequestID) Unique identifier for the request. ### `status` {#max.interfaces.BaseContext.status} > property status: [GenerationStatus](#max.interfaces.GenerationStatus) Current generation status of the request. ## `BatchProcessorInputs` {#max.interfaces.BatchProcessorInputs} > class max.interfaces.BatchProcessorInputs(logits, logit\_offsets, context\_batch) Arguments for a batch logits processor. * logits: The model logits, a float32 tensor with shape (N\_batch, vocab\_size). N\_batch is the number of logits returned by the model for each sequence in the batch. * logit\_offsets: If the model returns multiple logits, this is a tensor with shape (batch\_size + 1, 1) that contains the offsets of each sequence in the batch. Otherwise, this is None. * context\_batch: The batch of contexts containing the inputs to the model.
**Parameters:**
* logits (md.Buffer) * logit\_offsets (md.Buffer | None) * context\_batch (Sequence\[[TextGenerationContext](#max.interfaces.TextGenerationContext)])
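To make the shapes concrete, here is a minimal sketch of a batch logits processor, using plain Python lists of shape (N\_batch, vocab\_size) as a stand-in for the `md.Buffer` logits tensor. The token ID and values are purely illustrative, and real processors would also consult `logit_offsets` and `context_batch`:

```python
import math

def ban_token(logits: list[list[float]], token_id: int) -> None:
    # Stand-in batch logits processor: logits has shape (N_batch, vocab_size).
    # Setting a logit to -inf gives that token zero probability after softmax.
    for row in logits:
        row[token_id] = -math.inf

batch_logits = [[0.1, 0.5, 0.2], [0.3, 0.0, 0.9]]  # N_batch=2, vocab_size=3
ban_token(batch_logits, token_id=1)
print(batch_logits[0])  # [0.1, -inf, 0.2]
```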
### `context_batch` {#max.interfaces.BatchProcessorInputs.context_batch} > context\_batch: Sequence\[[TextGenerationContext](#max.interfaces.TextGenerationContext)] ### `logit_offsets` {#max.interfaces.BatchProcessorInputs.logit_offsets} > logit\_offsets: md.Buffer | [None](https://docs.python.org/3/library/constants.html#None) ### `logits` {#max.interfaces.BatchProcessorInputs.logits} > logits: md.Buffer ## `BatchType` {#max.interfaces.BatchType} > class max.interfaces.BatchType(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Type of batch. ### `CE` {#max.interfaces.BatchType.CE} > CE = 'CE' Context encoding batch. ### `TG` {#max.interfaces.BatchType.TG} > TG = 'TG' Token generation batch. ## `EmbeddingsContext` {#max.interfaces.EmbeddingsContext} > class max.interfaces.EmbeddingsContext(\*args, \*\*kwargs) Protocol defining the interface for embeddings generation contexts. An `EmbeddingsContext` represents model inputs for embeddings generation pipelines, managing the state and parameters needed for generating embeddings from input text. Unlike text generation contexts, this focuses on single-step embedding generation without iterative token generation concerns. This protocol includes only the fields necessary for embeddings generation, excluding text generation specific features like: * End-of-sequence token handling (eos\_token\_ids) * Grammar matchers for structured output (matcher) * JSON schema constraints (json\_schema) * Log probability tracking (log\_probabilities) * Token generation iteration state ### `model_name` {#max.interfaces.EmbeddingsContext.model_name} > property model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the embeddings model to use.
**Returns:**
A string identifying the specific embeddings model for this request.
### `tokens` {#max.interfaces.EmbeddingsContext.tokens} > property tokens: [TokenBuffer](#max.interfaces.TokenBuffer) The input tokens to be embedded.
**Returns:**
A NumPy array of token IDs representing the input text to generate embeddings for.
## `EmbeddingsGenerationInputs` {#max.interfaces.EmbeddingsGenerationInputs} > class max.interfaces.EmbeddingsGenerationInputs(batches: 'list\[dict\[RequestID, EmbeddingsContext]]')
**Parameters:**
batches ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), [EmbeddingsContext](#max.interfaces.EmbeddingsContext)]])
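The `batch` property merges these per-step batches into a single request-ID-to-context mapping. A plausible sketch of such a merge over hypothetical stand-in data (the exact collision semantics in MAX are not documented here; this sketch assumes later batches win):

```python
# Hypothetical stand-in data: each batch maps request IDs to contexts
batches = [{"req-1": "ctx-1"}, {"req-2": "ctx-2"}]

# One way such a merge could work: flatten all batches into one dict
merged = {rid: ctx for batch in batches for rid, ctx in batch.items()}
print(merged)  # {'req-1': 'ctx-1', 'req-2': 'ctx-2'}
```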
### `batch` {#max.interfaces.EmbeddingsGenerationInputs.batch} > property batch: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), [EmbeddingsContext](#max.interfaces.EmbeddingsContext)] Returns merged batches. ### `batches` {#max.interfaces.EmbeddingsGenerationInputs.batches} > batches: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), [EmbeddingsContext](#max.interfaces.EmbeddingsContext)]] ## `EmbeddingsGenerationOutput` {#max.interfaces.EmbeddingsGenerationOutput} > class max.interfaces.EmbeddingsGenerationOutput(embeddings) Response structure for embedding generation.
**Parameters:**
embeddings ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – The generated embeddings as a NumPy array.
### `embeddings` {#max.interfaces.EmbeddingsGenerationOutput.embeddings} > embeddings: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] The generated embeddings as a NumPy array. ### `is_done` {#max.interfaces.EmbeddingsGenerationOutput.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Indicates whether the embedding generation process is complete.
**Returns:**
Always True, as embedding generation is a single-step operation.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
## `GenerationOutput` {#max.interfaces.GenerationOutput} > class max.interfaces.GenerationOutput(\*, request\_id, final\_status, output) Output container for image generation pipeline operations. This class holds a list of generated images in OpenResponses API format, along with request tracking and status information. It implements the PipelineOutput protocol by providing the required is\_done property. Example: ```python import numpy as np from max.interfaces.generation import GenerationOutput from max.interfaces.request import RequestID from max.interfaces.request.open_responses import OutputImageContent from max.interfaces.status import GenerationStatus img_array1 = np.random.rand(512, 512, 3).astype(np.float32) img_array2 = np.random.rand(512, 512, 3).astype(np.float32) result = GenerationOutput( request_id=RequestID(value="req-123"), final_status=GenerationStatus.END_OF_SEQUENCE, output=[ OutputImageContent.from_numpy(img_array1, format="png"), OutputImageContent.from_numpy(img_array2, format="jpeg"), ] ) # Or create from URLs result_from_urls = GenerationOutput( request_id=RequestID(value="req-456"), final_status=GenerationStatus.END_OF_SEQUENCE, output=[ OutputImageContent( type="output_image", image_url="https://example.com/image1.png", format="png" ) ] ) # Check if generation is complete if result.is_done: print(f"Generated {len(result.output)} images") ```
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID)) * final\_status ([GenerationStatus](#max.interfaces.GenerationStatus)) * output ([list](https://docs.python.org/3/library/stdtypes.html#list)\[OutputImageContent])
### `final_status` {#max.interfaces.GenerationOutput.final_status} > final\_status: [GenerationStatus](#max.interfaces.GenerationStatus) The final status of the generation process. ### `is_done` {#max.interfaces.GenerationOutput.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Indicates whether the pipeline operation has completed.
**Returns:**
True if the generation is done (status is not ACTIVE), False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `output` {#max.interfaces.GenerationOutput.output} > output: [list](https://docs.python.org/3/library/stdtypes.html#list)\[OutputImageContent] List of OutputImageContent objects representing generated images. ### `request_id` {#max.interfaces.GenerationOutput.request_id} > request\_id: [RequestID](#max.interfaces.RequestID) The unique identifier for the generation request. ## `GenerationStatus` {#max.interfaces.GenerationStatus} > class max.interfaces.GenerationStatus(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Enum representing the status of a generation process in the MAX API. ### `ACTIVE` {#max.interfaces.GenerationStatus.ACTIVE} > ACTIVE = 'active' The generation process is ongoing. ### `CANCELLED` {#max.interfaces.GenerationStatus.CANCELLED} > CANCELLED = 'cancelled' The generation process has been cancelled by the user. ### `END_OF_SEQUENCE` {#max.interfaces.GenerationStatus.END_OF_SEQUENCE} > END\_OF\_SEQUENCE = 'end\_of\_sequence' The generation process has reached the end of the sequence. ### `MAXIMUM_LENGTH` {#max.interfaces.GenerationStatus.MAXIMUM_LENGTH} > MAXIMUM\_LENGTH = 'maximum\_length' The generation process has reached the maximum allowed length. ### `is_done` {#max.interfaces.GenerationStatus.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Returns True if the generation process is complete (not ACTIVE).
**Returns:**
True if the status is not ACTIVE, indicating completion.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
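The `is_done` rule can be mirrored with a small stand-in enum (the real class lives in `max.interfaces`; this sketch only reproduces the documented values and the not-ACTIVE test):

```python
from enum import Enum

class GenerationStatus(str, Enum):
    # Stand-in mirroring the documented enum values
    ACTIVE = "active"
    CANCELLED = "cancelled"
    END_OF_SEQUENCE = "end_of_sequence"
    MAXIMUM_LENGTH = "maximum_length"

    @property
    def is_done(self) -> bool:
        # "Done" means any terminal status, i.e. anything other than ACTIVE
        return self is not GenerationStatus.ACTIVE

print(GenerationStatus.END_OF_SEQUENCE.is_done)  # True
print(GenerationStatus.ACTIVE.is_done)           # False
```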
## `ImageContentPart` {#max.interfaces.ImageContentPart} > class max.interfaces.ImageContentPart(\*, type='image')
**Parameters:**
type ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['image'])
### `model_config` {#max.interfaces.ImageContentPart.model_config} > model\_config: ClassVar\[ConfigDict] = {'frozen': True} Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. ### `type` {#max.interfaces.ImageContentPart.type} > type: Literal\['image'] ## `ImageMetadata` {#max.interfaces.ImageMetadata} > class max.interfaces.ImageMetadata(\*, start\_idx, end\_idx, pixel\_values, image\_hash=None) Metadata about an image in the prompt. Each image corresponds to a range in the text token array \[start\_idx, end\_idx).
**Parameters:**
* start\_idx ([int](https://docs.python.org/3/library/functions.html#int)) * end\_idx ([int](https://docs.python.org/3/library/functions.html#int)) * pixel\_values ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * image\_hash ([int](https://docs.python.org/3/library/functions.html#int) | None)
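The `pixel_values` field can carry bfloat16 data stored as uint16 bits because NumPy has no native bfloat16 dtype. A struct-based sketch of that bit-level reinterpretation, for illustration only (MAX's actual conversion path and rounding mode may differ; bfloat16 is simply the high 16 bits of a float32):

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    # bfloat16 keeps the sign, exponent, and top 7 mantissa bits of float32,
    # i.e. the high 16 bits of its 32-bit representation (truncating round)
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_bits_to_f32(b: int) -> float:
    # Reinterpret the stored uint16 bits as the high half of a float32
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

# Values exactly representable in bfloat16 survive the round trip
print(bf16_bits_to_f32(f32_to_bf16_bits(1.5)))   # 1.5
print(bf16_bits_to_f32(f32_to_bf16_bits(-2.0)))  # -2.0
```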
### `end_idx` {#max.interfaces.ImageMetadata.end_idx} > end\_idx: [int](https://docs.python.org/3/library/functions.html#int) One past the index of the last image special token for this image. ### `image_hash` {#max.interfaces.ImageMetadata.image_hash} > image\_hash: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Hash of the image, for use in prefix caching. ### `pixel_values` {#max.interfaces.ImageMetadata.pixel_values} > pixel\_values: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] Pixel values for the image. Can be various dtypes depending on the vision model: * float32: Original precision * uint16: BFloat16 bits stored as uint16 (a workaround for NumPy’s lack of native bfloat16 support). Reinterpreted as bfloat16 on GPU. ### `start_idx` {#max.interfaces.ImageMetadata.start_idx} > start\_idx: [int](https://docs.python.org/3/library/functions.html#int) Index of the first image special token for this image. ## `LoRAOperation` {#max.interfaces.LoRAOperation} > class max.interfaces.LoRAOperation(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Enum for different LoRA operations. ### `LOAD` {#max.interfaces.LoRAOperation.LOAD} > LOAD = 'load' ### `UNLOAD` {#max.interfaces.LoRAOperation.UNLOAD} > UNLOAD = 'unload' ## `LoRARequest` {#max.interfaces.LoRARequest} > class max.interfaces.LoRARequest(operation, lora\_name, lora\_path=None) Container for LoRA adapter requests.
**Parameters:**
* operation ([LoRAOperation](#max.interfaces.LoRAOperation)) * lora\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * lora\_path ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `lora_name` {#max.interfaces.LoRARequest.lora_name} > lora\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `lora_path` {#max.interfaces.LoRARequest.lora_path} > lora\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) ### `operation` {#max.interfaces.LoRARequest.operation} > operation: [LoRAOperation](#max.interfaces.LoRAOperation) ## `LoRAResponse` {#max.interfaces.LoRAResponse} > class max.interfaces.LoRAResponse(status, message) Response from LoRA operations.
**Parameters:**
* status ([LoRAStatus](#max.interfaces.LoRAStatus)) * message ([str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)])
### `message` {#max.interfaces.LoRAResponse.message} > message: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] ### `status` {#max.interfaces.LoRAResponse.status} > status: [LoRAStatus](#max.interfaces.LoRAStatus) ## `LoRAStatus` {#max.interfaces.LoRAStatus} > class max.interfaces.LoRAStatus(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Enum for LoRA operation status. ### `LOAD_ERROR` {#max.interfaces.LoRAStatus.LOAD_ERROR} > LOAD\_ERROR = 'load\_error' ### `LOAD_INVALID_ADAPTER` {#max.interfaces.LoRAStatus.LOAD_INVALID_ADAPTER} > LOAD\_INVALID\_ADAPTER = 'load\_invalid\_adapter' ### `LOAD_INVALID_PATH` {#max.interfaces.LoRAStatus.LOAD_INVALID_PATH} > LOAD\_INVALID\_PATH = 'load\_invalid\_path' ### `LOAD_NAME_EXISTS` {#max.interfaces.LoRAStatus.LOAD_NAME_EXISTS} > LOAD\_NAME\_EXISTS = 'load\_name\_exists' ### `SUCCESS` {#max.interfaces.LoRAStatus.SUCCESS} > SUCCESS = 'success' ### `UNLOAD_ERROR` {#max.interfaces.LoRAStatus.UNLOAD_ERROR} > UNLOAD\_ERROR = 'unload\_error' ### `UNLOAD_NAME_NONEXISTENT` {#max.interfaces.LoRAStatus.UNLOAD_NAME_NONEXISTENT} > UNLOAD\_NAME\_NONEXISTENT = 'unload\_name\_nonexistent' ### `UNSPECIFIED_ERROR` {#max.interfaces.LoRAStatus.UNSPECIFIED_ERROR} > UNSPECIFIED\_ERROR = 'unspecified\_error' ## `LoRAType` {#max.interfaces.LoRAType} > class max.interfaces.LoRAType(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Enumeration of LoRA types. ### `A` {#max.interfaces.LoRAType.A} > A = 'lora\_A' Represents the LoRA A matrix (high-rank tensor to low-rank tensor). ### `B` {#max.interfaces.LoRAType.B} > B = 'lora\_B' Represents the LoRA B matrix (low-rank tensor to high-rank tensor). ### `BIAS` {#max.interfaces.LoRAType.BIAS} > BIAS = 'lora.bias' Represents the LoRA bias matrix (added to matrix B). ### `B_KV` {#max.interfaces.LoRAType.B_KV} > B\_KV = 'lora\_B\_kv' Represents the combined K and V LoRA B matrices for QKV fusion. ## `LogProbabilities` {#max.interfaces.LogProbabilities} > class max.interfaces.LogProbabilities(token\_log\_probabilities, top\_log\_probabilities) Log probabilities for an individual output token. This is a data-only class that serves as a serializable data structure for transferring log probability information. It does not provide any functionality for calculating or manipulating log probabilities; it is purely for data storage and serialization.
**Parameters:**
* token\_log\_probabilities ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)]) * top\_log\_probabilities ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [float](https://docs.python.org/3/library/functions.html#float)]])
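Since `LogProbabilities` is a pure data container, a dataclass stand-in shows the intended shape (the field layout follows the documentation above; the token IDs and log-probability values are hypothetical):

```python
from dataclasses import dataclass

@dataclass
class LogProbabilities:
    # Stand-in mirroring the documented fields
    token_log_probabilities: list[float]           # one entry per output token
    top_log_probabilities: list[dict[int, float]]  # per token: top token IDs -> log prob

lp = LogProbabilities(
    token_log_probabilities=[-0.11, -2.30],
    top_log_probabilities=[{42: -0.11, 7: -1.90}, {13: -2.30}],
)
print(lp.token_log_probabilities[0])  # -0.11
```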
### `token_log_probabilities` {#max.interfaces.LogProbabilities.token_log_probabilities} > token\_log\_probabilities: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)] Log probabilities of each token. ### `top_log_probabilities` {#max.interfaces.LogProbabilities.top_log_probabilities} > top\_log\_probabilities: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [float](https://docs.python.org/3/library/functions.html#float)]] Top tokens and their corresponding log probabilities. ## `MAXPullQueue` {#max.interfaces.MAXPullQueue} > class max.interfaces.MAXPullQueue(\*args, \*\*kwargs) Protocol for a minimal, non-blocking pull queue interface in MAX. This protocol defines the contract for a queue that supports non-blocking get operations for retrieving items. It is generic over the item type and designed for scenarios where the caller must be immediately notified if no items are available rather than waiting for items to arrive. The protocol is intended for consumer-side queue operations where immediate feedback about queue state is critical for proper flow control and error handling. ### `get_nowait()` {#max.interfaces.MAXPullQueue.get_nowait} > get\_nowait() Remove and return an item from the queue without blocking. This method is expected to raise queue.Empty if no item is available to retrieve from the queue.
**Returns:**
The item removed from the queue.
**Return type:**
PullItemType
**Raises:**
[queue.Empty](https://docs.python.org/3/library/queue.html#queue.Empty) – If the queue is empty and no item can be retrieved.
## `MAXPushQueue` {#max.interfaces.MAXPushQueue} > class max.interfaces.MAXPushQueue(\*args, \*\*kwargs) Protocol for a minimal, non-blocking push queue interface in MAX. This protocol defines the contract for a queue that supports non-blocking put operations for adding items. It is generic over the item type and designed for scenarios where the caller must be immediately notified of success or failure rather than waiting for space to become available. The protocol is intended for producer-side queue operations where immediate feedback is critical for proper flow control and error handling. ### `put_nowait()` {#max.interfaces.MAXPushQueue.put_nowait} > put\_nowait(item) Attempt to put an item into the queue without blocking. This method is designed to immediately fail (typically by raising an exception) if the item cannot be added to the queue at the time of the call. Unlike the traditional ‘put’ method in many queue implementations—which may block until space becomes available or the transfer is completed—this method never waits. It is intended for use cases where the caller must be notified of failure to enqueue immediately, rather than waiting for space.
**Parameters:**
item (PushItemType) – The item to be added to the queue.
**Return type:**
None
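The standard library's `queue.Queue` already has this fail-fast shape: its `put_nowait` and `get_nowait` methods match the `MAXPushQueue` and `MAXPullQueue` contracts, including raising `queue.Empty` on an empty pull (MAX's own queue implementations may of course differ internally):

```python
import queue

q: queue.Queue[int] = queue.Queue(maxsize=1)
q.put_nowait(42)        # succeeds immediately: there is room
item = q.get_nowait()   # returns 42 without blocking
print(item)             # 42

try:
    q.get_nowait()      # queue now empty: fails immediately instead of waiting
except queue.Empty:
    print("empty")      # empty
```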
## `OpenResponsesRequest` {#max.interfaces.OpenResponsesRequest} > class max.interfaces.OpenResponsesRequest(request\_id, body) General request container for OpenResponses API requests. This class wraps an OpenResponsesRequestBody and adheres to the Request schema. All request fields are accessed directly from the body.
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID)) * body (OpenResponsesRequestBody)
### `body` {#max.interfaces.OpenResponsesRequest.body} > body: OpenResponsesRequestBody The complete OpenResponses request body. ### `from_fastapi_request()` {#max.interfaces.OpenResponsesRequest.from_fastapi_request} > async classmethod from\_fastapi\_request(request) Create an OpenResponsesRequest from a FastAPI/Starlette Request. Extracts the request\_id from request.state.request\_id and parses the request body as an OpenResponsesRequestBody.
**Parameters:**
request (FastAPIRequestProtocol) – A request object with state.request\_id and body() method. Compatible with FastAPI/Starlette Request objects.
**Returns:**
An OpenResponsesRequest instance.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If request.state.request\_id is not set. * pydantic.ValidationError – If the request body is invalid.
**Return type:**
[OpenResponsesRequest](#max.interfaces.OpenResponsesRequest)
## `Pipeline` {#max.interfaces.Pipeline} > class max.interfaces.Pipeline Abstract base class for pipeline operations. This generic abstract class defines the interface for pipeline operations that transform inputs of type PipelineInputsType into outputs of type PipelineOutputsDict\[PipelineOutputType]. All concrete pipeline implementations must inherit from this class and implement the execute method. Type Parameters: : PipelineInputsType: The type of inputs this pipeline accepts, must inherit from PipelineInputs PipelineOutputType: The type of outputs this pipeline produces, must be a subclass of PipelineOutput ```python class MyPipeline(Pipeline[MyInputs, MyOutput]): def execute(self, inputs: MyInputs) -> dict[RequestID, MyOutput]: # Implementation here pass ``` ### `execute()` {#max.interfaces.Pipeline.execute} > abstract execute(inputs) Execute the pipeline operation with the given inputs. This method must be implemented by all concrete pipeline classes. It takes inputs of the specified type and returns outputs according to the pipeline’s processing logic.
**Parameters:**
inputs (PipelineInputsType) – The input data for the pipeline operation, must be of type PipelineInputsType
**Returns:**
The results of the pipeline operation, as a dictionary mapping RequestID to PipelineOutputType
**Raises:**
[NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If not implemented by a concrete subclass
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), PipelineOutputType]
### `release()` {#max.interfaces.Pipeline.release} > abstract release(request\_id) Release any resources or state associated with a specific request. This method should be implemented by concrete pipeline classes to perform cleanup or resource deallocation for the given request ID. It is typically called when a request has completed processing and its associated resources (such as memory, cache, or temporary files) are no longer needed.
**Parameters:**
request\_id ([RequestID](#max.interfaces.RequestID)) – The unique identifier of the request to release resources for.
**Returns:**
None
**Raises:**
[NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If not implemented by a concrete subclass.
**Return type:**
None
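A runnable sketch of the `execute`/`release` contract, with minimal stand-ins for `Pipeline` and `RequestID` (the real generic base classes live in `max.interfaces`; the uppercasing "pipeline" is purely illustrative):

```python
from abc import ABC, abstractmethod

RequestID = str  # stand-in; MAX defines its own RequestID type

class Pipeline(ABC):
    # Minimal stand-in for the documented abstract interface
    @abstractmethod
    def execute(self, inputs) -> dict[RequestID, object]: ...

    @abstractmethod
    def release(self, request_id: RequestID) -> None: ...

class UppercasePipeline(Pipeline):
    def __init__(self) -> None:
        self._scratch: dict[RequestID, object] = {}

    def execute(self, inputs: dict[RequestID, str]) -> dict[RequestID, str]:
        # Produce one output per request in the batch
        return {rid: text.upper() for rid, text in inputs.items()}

    def release(self, request_id: RequestID) -> None:
        # Drop any per-request state once the request completes
        self._scratch.pop(request_id, None)

pipe = UppercasePipeline()
print(pipe.execute({"req-1": "hello"}))  # {'req-1': 'HELLO'}
pipe.release("req-1")
```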
## `PipelineInputs` {#max.interfaces.PipelineInputs} > class max.interfaces.PipelineInputs Base class representing inputs to a pipeline operation. This class serves as a marker interface for all pipeline input types. Concrete implementations should inherit from this class and define the specific input data structures required for their pipeline operations. ```python class MyPipelineInputs(PipelineInputs): def __init__(self, data: str, config: dict): self.data = data self.config = config ``` ## `PipelineOutput` {#max.interfaces.PipelineOutput} > class max.interfaces.PipelineOutput(\*args, \*\*kwargs) Protocol representing the output of a pipeline operation. Subclasses must implement the is\_done property to indicate whether the pipeline operation has completed. ### `is_done` {#max.interfaces.PipelineOutput.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Indicates whether the pipeline operation has completed.
**Returns:**
True if the operation is done, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
## `PipelineTask` {#max.interfaces.PipelineTask} > class max.interfaces.PipelineTask(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Enum representing the types of pipeline tasks supported. ### `AUDIO_GENERATION` {#max.interfaces.PipelineTask.AUDIO_GENERATION} > AUDIO\_GENERATION = 'audio\_generation' Task for generating audio. ### `EMBEDDINGS_GENERATION` {#max.interfaces.PipelineTask.EMBEDDINGS_GENERATION} > EMBEDDINGS\_GENERATION = 'embeddings\_generation' Task for generating embeddings. ### `PIXEL_GENERATION` {#max.interfaces.PipelineTask.PIXEL_GENERATION} > PIXEL\_GENERATION = 'pixel\_generation' Task for generating pixels. ### `SPEECH_TOKEN_GENERATION` {#max.interfaces.PipelineTask.SPEECH_TOKEN_GENERATION} > SPEECH\_TOKEN\_GENERATION = 'speech\_token\_generation' Task for generating speech tokens. ### `TEXT_GENERATION` {#max.interfaces.PipelineTask.TEXT_GENERATION} > TEXT\_GENERATION = 'text\_generation' Task for generating text. ### `output_type` {#max.interfaces.PipelineTask.output_type} > property output\_type: [type](https://docs.python.org/3/library/functions.html#type)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), [SchedulerResult](#max.interfaces.SchedulerResult)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] Get the output type for the pipeline task.
**Returns:**
The output type for the pipeline task.
**Return type:**
[type](https://docs.python.org/3/library/functions.html#type)
## `PipelineTokenizer` {#max.interfaces.PipelineTokenizer} > class max.interfaces.PipelineTokenizer(\*args, \*\*kwargs) Interface for LLM tokenizers. ### `decode()` {#max.interfaces.PipelineTokenizer.decode} > async decode(encoded, \*\*kwargs) Decodes response tokens to text.
**Parameters:**
encoded (TokenizerEncoded) – Encoded response tokens.
**Returns:**
Un-encoded response text.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `encode()` {#max.interfaces.PipelineTokenizer.encode} > async encode(prompt, add\_special\_tokens) Encodes text prompts as tokens.
**Parameters:**
* prompt ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Un-encoded prompt text. * add\_special\_tokens ([bool](https://docs.python.org/3/library/functions.html#bool))
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the prompt exceeds the configured maximum length.
**Return type:**
TokenizerEncoded
### `eos` {#max.interfaces.PipelineTokenizer.eos} > property eos: [int](https://docs.python.org/3/library/functions.html#int) The end of sequence token for this tokenizer. ### `expects_content_wrapping` {#max.interfaces.PipelineTokenizer.expects_content_wrapping} > property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool) If true, this tokenizer expects messages to be wrapped as a dict. Text messages are formatted as: ```json { "role": "user", "content": [{ "type": "text", "text": "text content" }] } ``` instead of: ```json { "role": "user", "content": "text_content" } ``` NOTE: Multimodal messages omit the content property. Both `image_urls` and `image` content parts are converted to: ```json { "type": "image" } ``` Their content is provided as byte arrays through the top-level property on the request object, i.e., `RequestType.images`. ### `new_context()` {#max.interfaces.PipelineTokenizer.new_context} > async new\_context(request) Creates a new context from a request object. This is sent to the worker process once and then cached locally.
**Parameters:**
request (RequestType) – Incoming request.
**Returns:**
Initialized context.
**Return type:**
UnboundContextType
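Because `encode` and `decode` are async, callers await them. A toy tokenizer sketch shows the round trip (the codepoint encoding scheme is purely illustrative and not how any real tokenizer works; it only mirrors the async shape of the interface):

```python
import asyncio

class ToyTokenizer:
    # Stand-in with the same async encode/decode shape as PipelineTokenizer
    async def encode(self, prompt: str, add_special_tokens: bool) -> list[int]:
        return [ord(ch) for ch in prompt]

    async def decode(self, encoded: list[int], **kwargs) -> str:
        return "".join(chr(tok) for tok in encoded)

async def round_trip(text: str) -> str:
    tok = ToyTokenizer()
    ids = await tok.encode(text, add_special_tokens=False)
    return await tok.decode(ids)

print(asyncio.run(round_trip("hi")))  # hi
```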
## `PipelinesFactory` {#max.interfaces.PipelinesFactory} > max.interfaces.PipelinesFactory Type alias for factory functions that create pipeline instances. Factory functions should return a Pipeline with properly typed inputs and outputs that are bound to the PipelineInputs and PipelineOutput base classes respectively. This ensures type safety while maintaining flexibility for different pipeline implementations. **Example:**

```python
def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutput]:
    return MyTextGenerationPipeline()

factory: PipelinesFactory = create_text_pipeline
```

alias of [`Callable`](graph/ops.md#max.graph.ops.Callable)\[\[], [`Pipeline`](#max.interfaces.Pipeline)\[`PipelineInputsType`, `PipelineOutputType`]] ## `PixelGenerationContext` {#max.interfaces.PixelGenerationContext} > class max.interfaces.PixelGenerationContext(\*args, \*\*kwargs) Protocol defining the interface for pixel generation contexts. A `PixelGenerationContext` represents model inputs for pixel generation pipelines, managing the state and parameters needed for generating images or videos. ### `guidance_scale` {#max.interfaces.PixelGenerationContext.guidance_scale} > property guidance\_scale: [float](https://docs.python.org/3/library/functions.html#float) Classifier-free guidance scale (1.0 to disable CFG). ### `height` {#max.interfaces.PixelGenerationContext.height} > property height: [int](https://docs.python.org/3/library/functions.html#int) Height of generated output in pixels. ### `latents` {#max.interfaces.PixelGenerationContext.latents} > property latents: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] The latents for the context.
### `num_images_per_prompt` {#max.interfaces.PixelGenerationContext.num_images_per_prompt} > property num\_images\_per\_prompt: [int](https://docs.python.org/3/library/functions.html#int) Number of images to generate. ### `num_inference_steps` {#max.interfaces.PixelGenerationContext.num_inference_steps} > property num\_inference\_steps: [int](https://docs.python.org/3/library/functions.html#int) Number of denoising steps. ### `tokens` {#max.interfaces.PixelGenerationContext.tokens} > property tokens: [TokenBuffer](#max.interfaces.TokenBuffer) The token buffer for the context. ### `width` {#max.interfaces.PixelGenerationContext.width} > property width: [int](https://docs.python.org/3/library/functions.html#int) Width of generated output in pixels. ## `PixelGenerationInputs` {#max.interfaces.PixelGenerationInputs} > class max.interfaces.PixelGenerationInputs(batch) Input data structure for pixel generation pipelines. This class represents the input data required for pixel generation operations within the pipeline framework. It extends PipelineInputs and provides type-safe generic support for different pixel generation context types.
**Parameters:**
batch ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), PixelGenerationContextType])
### `batch` {#max.interfaces.PixelGenerationInputs.batch} > batch: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), PixelGenerationContextType] A dictionary mapping RequestID to PixelGenerationContextType instances. This batch structure allows for processing multiple pixel generation requests simultaneously while maintaining request-specific context and configuration data. ## `PixelGenerationOutput` {#max.interfaces.PixelGenerationOutput} > class max.interfaces.PixelGenerationOutput(request\_id, final\_status, pixel\_data=\<factory\>) Represents a response from the pixel generation API. This class encapsulates the result of a pixel generation request, including the request ID, final status, and generated pixel data.
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID)) * final\_status ([GenerationStatus](#max.interfaces.GenerationStatus)) * pixel\_data ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]])
### `final_status` {#max.interfaces.PixelGenerationOutput.final_status} > final\_status: [GenerationStatus](#max.interfaces.GenerationStatus) The final status of the generation process. ### `is_done` {#max.interfaces.PixelGenerationOutput.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Indicates whether the pixel generation process is complete.
**Returns:**
True if the generation is done, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `pixel_data` {#max.interfaces.PixelGenerationOutput.pixel_data} > pixel\_data: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] The generated pixel data, if available. ### `request_id` {#max.interfaces.PixelGenerationOutput.request_id} > request\_id: [RequestID](#max.interfaces.RequestID) The unique identifier for the generation request. ## `ProcessorInputs` {#max.interfaces.ProcessorInputs} > class max.interfaces.ProcessorInputs(logits: 'md.Buffer', context: 'TextGenerationContext')
**Parameters:**
* logits (md.Buffer) * context ([TextGenerationContext](#max.interfaces.TextGenerationContext))
### `context` {#max.interfaces.ProcessorInputs.context} > context: [TextGenerationContext](#max.interfaces.TextGenerationContext) ### `logits` {#max.interfaces.ProcessorInputs.logits} > logits: md.Buffer ## `Request` {#max.interfaces.Request} > class max.interfaces.Request(request\_id) Base class representing a generic request within the MAX API. This class provides a unique identifier for each request, ensuring that all requests can be tracked and referenced consistently throughout the system. Subclasses can extend this class to include additional fields specific to their request types.
**Parameters:**
request\_id ([RequestID](#max.interfaces.RequestID))
### `request_id` {#max.interfaces.Request.request_id} > request\_id: [RequestID](#max.interfaces.RequestID) ## `RequestID` {#max.interfaces.RequestID} > class max.interfaces.RequestID(value=\) A unique immutable identifier for a request. When instantiated without arguments, automatically generates a new UUID4-based ID.
**Parameters:**
value ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The string identifier. If not provided, generates a UUID4 hex string.
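The auto-generated-ID behavior described here can be mimicked with a plain frozen dataclass (a sketch of the documented semantics, not MAX's implementation; `RequestIDSketch` is a hypothetical name):

```python
import uuid
from dataclasses import dataclass, field

@dataclass(frozen=True)
class RequestIDSketch:
    """Immutable request ID that defaults to a fresh UUID4 hex string."""
    value: str = field(default_factory=lambda: uuid.uuid4().hex)

a = RequestIDSketch()          # auto-generated UUID4 hex identifier
b = RequestIDSketch("req-42")  # explicitly provided identifier
```

`frozen=True` gives the immutability the docstring describes, and `default_factory` ensures each no-argument instantiation gets a new ID rather than a shared default.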
### `value` {#max.interfaces.RequestID.value} > value: [str](https://docs.python.org/3/library/stdtypes.html#str) ## `SamplingParams` {#max.interfaces.SamplingParams} > class max.interfaces.SamplingParams(top\_k=-1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=\, logits\_processors=None) Request-specific sampling parameters that are only known at run time.
**Parameters:**
* top\_k ([int](https://docs.python.org/3/library/functions.html#int)) * top\_p ([float](https://docs.python.org/3/library/functions.html#float)) * min\_p ([float](https://docs.python.org/3/library/functions.html#float)) * temperature ([float](https://docs.python.org/3/library/functions.html#float)) * frequency\_penalty ([float](https://docs.python.org/3/library/functions.html#float)) * presence\_penalty ([float](https://docs.python.org/3/library/functions.html#float)) * repetition\_penalty ([float](https://docs.python.org/3/library/functions.html#float)) * max\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) * min\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int)) * ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool)) * stop ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | None) * stop\_token\_ids ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | None) * detokenize ([bool](https://docs.python.org/3/library/functions.html#bool)) * seed ([int](https://docs.python.org/3/library/functions.html#int)) * logits\_processors ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[[ProcessorInputs](#max.interfaces.ProcessorInputs)], None]] | None)
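As a rough illustration of how `temperature`, `top_k`, and `top_p` interact during sampling (a simplified NumPy sketch of the documented semantics, not MAX's sampling kernel; `sample_sketch` is a hypothetical name):

```python
import numpy as np

def sample_sketch(logits, temperature=1.0, top_k=-1, top_p=1.0, rng=None):
    """Simplified temperature / top-k / top-p sampling over a 1-D logits array."""
    rng = rng or np.random.default_rng(0)
    if temperature == 0:                     # greedy: take the most probable token
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))  # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]          # token IDs, most probable first
    if top_k > 0:
        order = order[:top_k]                # keep only the K most probable tokens
    keep = np.cumsum(probs[order]) <= top_p  # cumulative probability within top_p
    keep[0] = True                           # always keep at least the best token
    order = order[keep]
    renormed = probs[order] / probs[order].sum()
    return int(rng.choice(order, p=renormed))

token = sample_sketch(np.array([2.0, 1.0, 0.1, -1.0]),
                      temperature=0.7, top_k=3, top_p=0.9)
```

Note the order of operations assumed here: temperature scales the logits first, `top_k` restricts the candidate set, and `top_p` is then applied to those top-k tokens, matching the field descriptions below.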
### `detokenize` {#max.interfaces.SamplingParams.detokenize} > detokenize: [bool](https://docs.python.org/3/library/functions.html#bool) = True Whether to detokenize the output tokens into text. ### `frequency_penalty` {#max.interfaces.SamplingParams.frequency_penalty} > frequency\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text: tokens will receive a penalty proportional to the count of appearances. ### `from_input_and_generation_config()` {#max.interfaces.SamplingParams.from_input_and_generation_config} > classmethod from\_input\_and\_generation\_config(input\_params, sampling\_params\_defaults) Create SamplingParams with defaults from HuggingFace’s GenerationConfig. This method creates a SamplingParams instance by combining three sources of values, in priority order (highest to lowest): 1. User-provided values in input\_params (non-None) 2. Model’s GenerationConfig values (only if explicitly set in the model’s config) 3. SamplingParams class defaults
**Parameters:**
* input\_params ([SamplingParamsInput](#max.interfaces.SamplingParamsInput)) – Dataclass containing user-specified parameter values. Values of None will be replaced with model defaults or class defaults. * sampling\_params\_defaults ([SamplingParamsGenerationConfigDefaults](#max.interfaces.SamplingParamsGenerationConfigDefaults)) – SamplingParamsGenerationConfigDefaults containing default sampling parameters extracted from the model’s GenerationConfig.
**Returns:**
A new SamplingParams instance with model-aware defaults.
**Return type:**
[SamplingParams](#max.interfaces.SamplingParams)
**Example:** ```pycon >>> sampling_defaults = model_config.sampling_params_defaults >>> params = SamplingParams.from_input_and_generation_config( ... SamplingParamsInput(temperature=0.7), # User override ... sampling_params_defaults=sampling_defaults ... ) ``` ### `ignore_eos` {#max.interfaces.SamplingParams.ignore_eos} > ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) = False If True, the response will ignore the EOS token and continue generating until the maximum token count is reached or a stop string is hit. ### `log_sampling_info()` {#max.interfaces.SamplingParams.log_sampling_info} > log\_sampling\_info() Log comprehensive sampling parameter information. Displays all sampling parameters in a consistent visual format similar to pipeline configuration logging.
**Return type:**
None
### `logits_processors` {#max.interfaces.SamplingParams.logits_processors} > logits\_processors: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[[ProcessorInputs](#max.interfaces.ProcessorInputs)], [None](https://docs.python.org/3/library/constants.html#None)]] | [None](https://docs.python.org/3/library/constants.html#None) = None Callables to post-process the model logits. See `LogitsProcessor` for examples. ### `max_new_tokens` {#max.interfaces.SamplingParams.max_new_tokens} > max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None The maximum number of new tokens to generate in the response. When set to an integer value, generation will stop after this many tokens. When None (default), the model may generate tokens until it reaches its internal limits or other stopping criteria are met. ### `min_new_tokens` {#max.interfaces.SamplingParams.min_new_tokens} > min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) = 0 The minimum number of tokens to generate in the response. ### `min_p` {#max.interfaces.SamplingParams.min_p} > min\_p: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in \[0, 1]. Set to 0 to disable this. ### `presence_penalty` {#max.interfaces.SamplingParams.presence_penalty} > presence\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once by applying a constant penalty. 
### `repetition_penalty` {#max.interfaces.SamplingParams.repetition_penalty} > repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 1.0 The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in the generated text at least once by dividing the logits by the repetition penalty. ### `seed` {#max.interfaces.SamplingParams.seed} > seed: [int](https://docs.python.org/3/library/functions.html#int) The seed to use for the random number generator. Defaults to a cryptographically secure random value. ### `stop` {#max.interfaces.SamplingParams.stop} > stop: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of detokenized sequences that can be used as stop criteria when generating a new sequence. ### `stop_token_ids` {#max.interfaces.SamplingParams.stop_token_ids} > stop\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of token IDs that are used as stopping criteria when generating a new sequence. ### `temperature` {#max.interfaces.SamplingParams.temperature} > temperature: [float](https://docs.python.org/3/library/functions.html#float) = 1 Controls the randomness of the model’s output; higher values produce more diverse responses. For greedy sampling, set temperature to 0. ### `top_k` {#max.interfaces.SamplingParams.top_k} > top\_k: [int](https://docs.python.org/3/library/functions.html#int) = -1 Limits the sampling to the K most probable tokens. This defaults to -1 (sample all tokens); for greedy sampling, set to 1.
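The frequency, presence, and repetition penalties described in the fields above can be sketched as a single logits transform (illustrative only; `apply_penalties_sketch` is a hypothetical name, and the real kernel may differ in details such as how the repetition penalty treats negative logits):

```python
import numpy as np

def apply_penalties_sketch(logits, generated, frequency_penalty=0.0,
                           presence_penalty=0.0, repetition_penalty=1.0):
    """Penalize tokens already present in `generated` (a list of token IDs)."""
    logits = logits.astype(float)
    ids, counts = np.unique(generated, return_counts=True)
    logits[ids] -= frequency_penalty * counts  # scales with appearance count
    logits[ids] -= presence_penalty            # flat penalty for any appearance
    logits[ids] /= repetition_penalty          # divide logits of seen tokens
    return logits

out = apply_penalties_sketch(np.array([2.0, 2.0, 2.0]), generated=[0, 0, 1],
                             frequency_penalty=0.5, presence_penalty=0.1,
                             repetition_penalty=2.0)
```

Token 2 never appeared in the generated text, so its logit is untouched, while tokens 0 and 1 are penalized in proportion to (frequency), by a constant for (presence), and by division (repetition).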
### `top_p` {#max.interfaces.SamplingParams.top_p} > top\_p: [float](https://docs.python.org/3/library/functions.html#float) = 1 Only use the tokens whose cumulative probability is within the top\_p threshold. This applies to the top\_k tokens. ## `SamplingParamsGenerationConfigDefaults` {#max.interfaces.SamplingParamsGenerationConfigDefaults} > class max.interfaces.SamplingParamsGenerationConfigDefaults(temperature=None, top\_p=None, top\_k=None, repetition\_penalty=None, max\_new\_tokens=None, min\_new\_tokens=None, do\_sample=None) Default sampling parameter values extracted from a model’s GenerationConfig. This class encapsulates sampling parameter defaults that come from a HuggingFace model’s GenerationConfig. These defaults have middle priority when creating SamplingParams instances: Priority order (highest to lowest): 1. User-provided values (SamplingParamsInput) 2. Model’s GenerationConfig values (this class) 3. SamplingParams class defaults All fields default to None, indicating that the model’s GenerationConfig does not explicitly set that parameter. When None, SamplingParams will fall back to its own class defaults. **Example:** ```pycon >>> # Extract from model config >>> gen_config = model_config.generation_config >>> defaults = SamplingParamsGenerationConfigDefaults( ... temperature=0.7, ... top_k=50, ... max_new_tokens=512 ... ) >>> # Use with SamplingParams >>> params = SamplingParams.from_input_and_generation_config( ... SamplingParamsInput(), ... sampling_params_defaults=defaults ... ) ```
**Parameters:**
* temperature ([float](https://docs.python.org/3/library/functions.html#float) | None) * top\_p ([float](https://docs.python.org/3/library/functions.html#float) | None) * top\_k ([int](https://docs.python.org/3/library/functions.html#int) | None) * repetition\_penalty ([float](https://docs.python.org/3/library/functions.html#float) | None) * max\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) * min\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) * do\_sample ([bool](https://docs.python.org/3/library/functions.html#bool) | None)
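The three-tier priority described above (user input over GenerationConfig over class defaults) amounts to a None-filtered merge, sketched here with plain dicts standing in for the dataclasses (not the MAX implementation):

```python
# Hypothetical stand-in for SamplingParams class defaults.
CLASS_DEFAULTS = {"temperature": 1.0, "top_k": -1, "max_new_tokens": None}

def resolve_sampling_params(user_input, generation_config_defaults):
    """Layer the three sources: class defaults < model config < user input."""
    resolved = dict(CLASS_DEFAULTS)
    resolved.update({k: v for k, v in generation_config_defaults.items()
                     if v is not None})  # only values the model config explicitly set
    resolved.update({k: v for k, v in user_input.items()
                     if v is not None})  # user-provided values win
    return resolved

params = resolve_sampling_params(
    user_input={"temperature": 0.7, "top_k": None, "max_new_tokens": None},
    generation_config_defaults={"temperature": 0.9, "top_k": 50,
                                "max_new_tokens": None},
)
```

Here the user's `temperature=0.7` overrides the model's 0.9, the model's `top_k=50` fills in for the unspecified user value, and `max_new_tokens` falls back to the class default because neither source set it.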
### `do_sample` {#max.interfaces.SamplingParamsGenerationConfigDefaults.do_sample} > do\_sample: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None) = None If False, use greedy sampling. ### `max_new_tokens` {#max.interfaces.SamplingParamsGenerationConfigDefaults.max_new_tokens} > max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Maximum number of new tokens from the model’s GenerationConfig, if explicitly set. ### `min_new_tokens` {#max.interfaces.SamplingParamsGenerationConfigDefaults.min_new_tokens} > min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Minimum number of new tokens from the model’s GenerationConfig, if explicitly set. ### `repetition_penalty` {#max.interfaces.SamplingParamsGenerationConfigDefaults.repetition_penalty} > repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None Repetition penalty value from the model’s GenerationConfig, if explicitly set. ### `temperature` {#max.interfaces.SamplingParamsGenerationConfigDefaults.temperature} > temperature: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None Temperature value from the model’s GenerationConfig, if explicitly set. ### `top_k` {#max.interfaces.SamplingParamsGenerationConfigDefaults.top_k} > top\_k: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Top-k sampling value from the model’s GenerationConfig, if explicitly set. 
### `top_p` {#max.interfaces.SamplingParamsGenerationConfigDefaults.top_p} > top\_p: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None Top-p (nucleus sampling) value from the model’s GenerationConfig, if explicitly set. ### `values_to_update` {#max.interfaces.SamplingParamsGenerationConfigDefaults.values_to_update} > property values\_to\_update: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [float](https://docs.python.org/3/library/functions.html#float) | [int](https://docs.python.org/3/library/functions.html#int)] Extract non-None field values as a dictionary.
**Returns:**
A dictionary mapping field names to their values, excluding any fields that are None. This dictionary can be used to update SamplingParams default values.
**Example:** ```pycon >>> defaults = SamplingParamsGenerationConfigDefaults( ... temperature=0.7, ... top_k=50 ... ) >>> defaults.values_to_update {'temperature': 0.7, 'top_k': 50} ``` ## `SamplingParamsInput` {#max.interfaces.SamplingParamsInput} > class max.interfaces.SamplingParamsInput(top\_k=None, top\_p=None, min\_p=None, temperature=None, frequency\_penalty=None, presence\_penalty=None, repetition\_penalty=None, max\_new\_tokens=None, min\_new\_tokens=None, ignore\_eos=None, stop=None, stop\_token\_ids=None, detokenize=None, seed=None, logits\_processors=None) Input dataclass for creating SamplingParams instances. All fields are optional, allowing partial specification with None values indicating “use default”. This enables static type checking while maintaining the flexibility to specify only the parameters you want to override.
**Parameters:**
* top\_k ([int](https://docs.python.org/3/library/functions.html#int) | None) * top\_p ([float](https://docs.python.org/3/library/functions.html#float) | None) * min\_p ([float](https://docs.python.org/3/library/functions.html#float) | None) * temperature ([float](https://docs.python.org/3/library/functions.html#float) | None) * frequency\_penalty ([float](https://docs.python.org/3/library/functions.html#float) | None) * presence\_penalty ([float](https://docs.python.org/3/library/functions.html#float) | None) * repetition\_penalty ([float](https://docs.python.org/3/library/functions.html#float) | None) * max\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) * min\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) * ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool) | None) * stop ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | None) * stop\_token\_ids ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | None) * detokenize ([bool](https://docs.python.org/3/library/functions.html#bool) | None) * seed ([int](https://docs.python.org/3/library/functions.html#int) | None) * logits\_processors ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[[ProcessorInputs](#max.interfaces.ProcessorInputs)], None]] | None)
### `detokenize` {#max.interfaces.SamplingParamsInput.detokenize} > detokenize: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `frequency_penalty` {#max.interfaces.SamplingParamsInput.frequency_penalty} > frequency\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `ignore_eos` {#max.interfaces.SamplingParamsInput.ignore_eos} > ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `logits_processors` {#max.interfaces.SamplingParamsInput.logits_processors} > logits\_processors: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[[ProcessorInputs](#max.interfaces.ProcessorInputs)], [None](https://docs.python.org/3/library/constants.html#None)]] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `max_new_tokens` {#max.interfaces.SamplingParamsInput.max_new_tokens} > max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `min_new_tokens` {#max.interfaces.SamplingParamsInput.min_new_tokens} > min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `min_p` {#max.interfaces.SamplingParamsInput.min_p} > min\_p: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `presence_penalty` {#max.interfaces.SamplingParamsInput.presence_penalty} > presence\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `repetition_penalty` {#max.interfaces.SamplingParamsInput.repetition_penalty} > repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `seed` {#max.interfaces.SamplingParamsInput.seed} > seed: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `stop` {#max.interfaces.SamplingParamsInput.stop} > stop: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `stop_token_ids` {#max.interfaces.SamplingParamsInput.stop_token_ids} > stop\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `temperature` {#max.interfaces.SamplingParamsInput.temperature} > temperature: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `top_k` {#max.interfaces.SamplingParamsInput.top_k} > top\_k: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `top_p` {#max.interfaces.SamplingParamsInput.top_p} > top\_p: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None ## `Scheduler` {#max.interfaces.Scheduler} > class max.interfaces.Scheduler Abstract base class defining the interface for schedulers. ### `run_iteration()` {#max.interfaces.Scheduler.run_iteration} > abstract run\_iteration() The core scheduler routine that creates and executes batches.
This method should implement the core scheduling logic including: * Batch creation and management * Request scheduling ## `SchedulerResult` {#max.interfaces.SchedulerResult} > class max.interfaces.SchedulerResult(is\_done, result) Structure representing the result of a scheduler operation for a specific pipeline execution. This class encapsulates the outcome of a pipeline operation as managed by the scheduler, including both the execution status and any resulting data from the pipeline. The scheduler uses this structure to communicate the state of pipeline operations back to clients, whether the operation is still running, has completed successfully, or was cancelled. The generic type parameter allows this result to work with different types of pipeline outputs while maintaining type safety.
**Parameters:**
* is\_done ([bool](https://docs.python.org/3/library/functions.html#bool)) * result (PipelineOutputType | None)
### `cancelled()` {#max.interfaces.SchedulerResult.cancelled} > classmethod cancelled() Create a SchedulerResult representing a cancelled pipeline operation.
**Returns:**
A SchedulerResult marked as done, with no result data.
**Return type:**
[SchedulerResult](#max.interfaces.SchedulerResult)
### `create()` {#max.interfaces.SchedulerResult.create} > classmethod create(result) Create a SchedulerResult representing a pipeline operation with some result.
**Parameters:**
result (PipelineOutputType) – The pipeline output data.
**Returns:**
A SchedulerResult with a result.
**Return type:**
[SchedulerResult](#max.interfaces.SchedulerResult)
### `is_done` {#max.interfaces.SchedulerResult.is_done} > is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) The current status of the pipeline operation from the scheduler’s perspective. ### `result` {#max.interfaces.SchedulerResult.result} > result: PipelineOutputType | [None](https://docs.python.org/3/library/constants.html#None) The pipeline output data, if any. May be None for cancelled operations or during intermediate states of streaming operations. ## `SharedMemoryArray` {#max.interfaces.SharedMemoryArray} > class max.interfaces.SharedMemoryArray(name, shape, dtype) Wrapper for numpy array stored in shared memory. This class is used as a placeholder in pixel\_values during serialization. It will be encoded as a dict with \_\_shm\_\_ flag and decoded back to a numpy array.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * shape ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), ...]) * dtype ([str](https://docs.python.org/3/library/stdtypes.html#str))
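The encode/decode round-trip this wrapper describes can be sketched with the standard library's `multiprocessing.shared_memory` (the helper names are hypothetical; this is not MAX's actual codec):

```python
import numpy as np
from multiprocessing import shared_memory

def encode_to_shm(array):
    """Copy an array into shared memory; return a record carrying the __shm__ flag."""
    shm = shared_memory.SharedMemory(create=True, size=array.nbytes)
    np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)[:] = array
    record = {"__shm__": True, "name": shm.name,
              "shape": array.shape, "dtype": str(array.dtype)}
    shm.close()  # this handle is done; the named segment persists
    return record

def decode_from_shm(record):
    """Reattach to the named segment, copy the array out, and free the segment."""
    shm = shared_memory.SharedMemory(name=record["name"])
    out = np.ndarray(record["shape"], dtype=record["dtype"], buffer=shm.buf).copy()
    shm.close()
    shm.unlink()  # release the shared segment once the consumer has its copy
    return out

pixels = np.arange(12, dtype=np.float32).reshape(3, 4)
restored = decode_from_shm(encode_to_shm(pixels))
```

Passing only the `(name, shape, dtype)` record across process boundaries avoids serializing the pixel buffer itself, which is the point of a shared-memory placeholder.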
## `TextContentPart` {#max.interfaces.TextContentPart} > class max.interfaces.TextContentPart(\*, type='text', text)
**Parameters:**
* type ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['text']) * text ([str](https://docs.python.org/3/library/stdtypes.html#str))
### `model_config` {#max.interfaces.TextContentPart.model_config} > model\_config: ClassVar\[ConfigDict] = {'frozen': True} Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. ### `text` {#max.interfaces.TextContentPart.text} > text: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `type` {#max.interfaces.TextContentPart.type} > type: Literal\['text'] ## `TextGenerationContext` {#max.interfaces.TextGenerationContext} > class max.interfaces.TextGenerationContext(\*args, \*\*kwargs) Protocol defining the interface for text generation contexts in token generation. A `TextGenerationContext` represents model inputs for text generation pipelines, managing the state of tokens throughout the generation process. It handles token arrays, generation status, sampling parameters, and various indices that track different stages of token processing. ### `compute_num_available_steps()` {#max.interfaces.TextGenerationContext.compute_num_available_steps} > compute\_num\_available\_steps(max\_seq\_len) Compute the maximum number of generation steps available. This method calculates how many tokens can be generated without exceeding the specified maximum sequence length limit.
**Parameters:**
max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum allowed sequence length for this context.
**Returns:**
The number of generation steps that can be executed before reaching the sequence length limit.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
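The computation described here amounts to the remaining room in the sequence. A sketch, where `current_length` is a hypothetical stand-in for the context's current token count and off-by-one details may differ from the real method:

```python
def compute_num_available_steps_sketch(current_length, max_seq_len):
    """Remaining tokens that fit before the sequence-length limit is reached."""
    return max(0, max_seq_len - current_length)

steps = compute_num_available_steps_sketch(current_length=1000, max_seq_len=1024)
```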
### `eos_token_ids` {#max.interfaces.TextGenerationContext.eos_token_ids} > property eos\_token\_ids: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] The set of end-of-sequence token IDs that can terminate generation.
**Returns:**
A set of token IDs that, when generated, will signal the end of the sequence and terminate the generation process.
### `get_min_token_logit_mask()` {#max.interfaces.TextGenerationContext.get_min_token_logit_mask} > get\_min\_token\_logit\_mask(num\_steps) Get token indices that should be masked in the output logits. This method is primarily used to implement the `min_tokens` constraint, where certain tokens (typically EOS tokens) are masked to prevent early termination before the minimum token count is reached.
**Parameters:**
num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – The number of generation steps to compute masks for.
**Returns:**
A list of NumPy arrays, where each array contains token indices that should be masked (set to negative infinity) in the logits for the corresponding generation step.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int32]]]
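The masking behavior can be sketched as follows: for each upcoming step that would still be under the minimum token count, the EOS token IDs are collected so their logits can be set to negative infinity (a simplified stand-in with hypothetical names; the real method's step accounting may differ):

```python
import numpy as np

def min_token_logit_masks_sketch(generated_so_far, min_tokens,
                                 eos_token_ids, num_steps):
    """For each upcoming step, list the token IDs to mask while under min_tokens."""
    masks = []
    for step in range(num_steps):
        if generated_so_far + step < min_tokens:
            # Still under the minimum: block EOS tokens this step.
            masks.append(np.array(sorted(eos_token_ids), dtype=np.int32))
        else:
            masks.append(np.array([], dtype=np.int32))  # constraint satisfied
    return masks

masks = min_token_logit_masks_sketch(generated_so_far=3, min_tokens=5,
                                     eos_token_ids={2}, num_steps=4)
```

With 3 tokens already generated and `min_tokens=5`, the first two steps mask the EOS token and the remaining steps leave the logits untouched.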
### `is_initial_prompt` {#max.interfaces.TextGenerationContext.is_initial_prompt} > property is\_initial\_prompt: [bool](https://docs.python.org/3/library/functions.html#bool) Whether this context contains only the initial prompt. This property indicates if the context has not yet been updated with any generated tokens and still contains only the original input.
**Returns:**
`True` if no tokens have been generated yet, `False` if generation has begun and tokens have been added.
### `json_schema` {#max.interfaces.TextGenerationContext.json_schema} > property json\_schema: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) The JSON schema for constrained decoding, if configured. When set, this schema constrains token generation to produce valid JSON output that conforms to the specified structure.
**Returns:**
The JSON schema string, or `None` if no schema constraint is active.
### `jump_ahead()` {#max.interfaces.TextGenerationContext.jump_ahead} > jump\_ahead(new\_token) Jump ahead in generation by adding a token and updating indices. This method is used in speculative decoding scenarios to quickly advance the generation state when draft tokens are accepted.
**Parameters:**
new\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The token ID to add when jumping ahead in the sequence.
**Return type:**
None
### `log_probabilities` {#max.interfaces.TextGenerationContext.log_probabilities} > property log\_probabilities: [int](https://docs.python.org/3/library/functions.html#int) The number of top tokens to return log probabilities for. When greater than 0, the system returns log probabilities for the top N most likely tokens at each generation step.
**Returns:**
The number of top tokens to include in log probability output. Returns 0 if log probabilities are disabled.
### `log_probabilities_echo` {#max.interfaces.TextGenerationContext.log_probabilities_echo} > property log\_probabilities\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) Whether to include input tokens in the returned log probabilities. When `True`, log probabilities will be computed and returned for input (prompt) tokens in addition to generated tokens.
**Returns:**
`True` if input tokens should be included in log probability output, `False` otherwise.
### `matcher` {#max.interfaces.TextGenerationContext.matcher} > property matcher: [Any](https://docs.python.org/3/library/typing.html#typing.Any) | [None](https://docs.python.org/3/library/constants.html#None) The grammar matcher for structured output generation, if configured. The matcher enforces structural constraints (like JSON schema) during generation to ensure valid formatted output.
**Returns:**
The grammar matcher instance, or None if no structured generation is configured for this context.
:::note Note The matcher type depends on the structured generation backend used (e.g., outlines or guidance). In the future, this should be replaced with a Protocol for better type safety. ::: ### `max_length` {#max.interfaces.TextGenerationContext.max_length} > property max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) The maximum allowed length for this sequence. When set, generation will stop when this length is reached, regardless of other stopping criteria.
**Returns:**
The maximum sequence length limit, or `None` if no limit is set.
### `min_tokens` {#max.interfaces.TextGenerationContext.min_tokens} > property min\_tokens: [int](https://docs.python.org/3/library/functions.html#int) The minimum number of new tokens that must be generated. Generation will continue until at least this many new tokens have been produced, even if other stopping criteria are met (e.g., EOS tokens).
**Returns:**
The minimum number of new tokens to generate.
### `realize_future_token()` {#max.interfaces.TextGenerationContext.realize_future_token} > realize\_future\_token(new\_token, log\_probabilities=None) Overwrite the placeholder future token with the actual token. This is primarily used for overlap scheduling.
**Parameters:**
* new\_token ([int](https://docs.python.org/3/library/functions.html#int)) * log\_probabilities ([LogProbabilities](#max.interfaces.LogProbabilities) | None)
**Return type:**
None
### `reset()` {#max.interfaces.TextGenerationContext.reset} > reset() Resets the context’s state by combining all tokens into a new prompt. This method is used when a request is evicted, meaning the context must be re-encoded in the following context encoding (CE) iteration.
**Return type:**
None
### `sampling_params` {#max.interfaces.TextGenerationContext.sampling_params} > property sampling\_params: [SamplingParams](#max.interfaces.SamplingParams) The sampling parameters configured for this generation request. These parameters control how tokens are selected during generation, including temperature, top-k/top-p filtering, and stopping criteria.
**Returns:**
The `SamplingParams` instance containing all sampling configuration for this context.
### `set_matcher()` {#max.interfaces.TextGenerationContext.set_matcher} > set\_matcher(matcher) Set a grammar matcher for constrained decoding. This method configures structured output generation by installing a grammar matcher that enforces format constraints during token generation.
**Parameters:**
matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any)) – The grammar matcher instance to use for constraining output. The specific type depends on the structured generation backend.
**Return type:**
None
### `to_generation_output()` {#max.interfaces.TextGenerationContext.to_generation_output} > to\_generation\_output() Convert this context to a TextGenerationOutput object. This method provides a standardized way to extract the final output of the text generation process from the context, including generated text, tokens, and any associated metadata.
**Returns:**
The output object containing the results of the text generation for this context.
**Return type:**
[TextGenerationOutput](#max.interfaces.TextGenerationOutput)
### `tokens` {#max.interfaces.TextGenerationContext.tokens} > property tokens: [TokenBuffer](#max.interfaces.TokenBuffer) The token buffer for the context. ### `update()` {#max.interfaces.TextGenerationContext.update} > update(new\_token, log\_probabilities=None) Update the context with a newly generated token, and update status. This method adds a generated token to the context, updating the token array, associated metadata, and log probabilities (if provided). It is also responsible for updating the context’s generation status and determining if the generation sequence is complete, either due to reaching an end-of-sequence condition or meeting stopping criteria.
**Parameters:**
* new\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The token ID to add to the generation sequence. * log\_probabilities ([LogProbabilities](#max.interfaces.LogProbabilities) | None) – Optional log probability data for the new token and alternatives. Used for analysis and debugging.
**Return type:**
None
### `update_with_future_token()` {#max.interfaces.TextGenerationContext.update_with_future_token} > update\_with\_future\_token() Append a placeholder future token to the generated tokens. This is primarily used for overlap scheduling.
**Return type:**
None
## `TextGenerationInputs` {#max.interfaces.TextGenerationInputs} > class max.interfaces.TextGenerationInputs(batches, num\_steps, input\_tokens=-1, batch\_type=BatchType.TG) Input parameters for text generation pipeline operations. This class encapsulates the batch of contexts and number of steps required for token generation in a single input object, replacing the previous pattern of passing batch and num\_steps as separate parameters.
**Parameters:**
* batches ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]]) * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * input\_tokens ([int](https://docs.python.org/3/library/functions.html#int)) * batch\_type ([BatchType](#max.interfaces.BatchType))
### `batch_echo` {#max.interfaces.TextGenerationInputs.batch_echo} > property batch\_echo: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[bool](https://docs.python.org/3/library/functions.html#bool)] List indicating whether echo is enabled for each context in the batch. ### `batch_top_log_probs` {#max.interfaces.TextGenerationInputs.batch_top_log_probs} > property batch\_top\_log\_probs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] List of requested top log probabilities per context in the batch. ### `batch_type` {#max.interfaces.TextGenerationInputs.batch_type} > batch\_type: [BatchType](#max.interfaces.BatchType) = 'TG' Type of batch. ### `batches` {#max.interfaces.TextGenerationInputs.batches} > batches: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]] Variable list of batches, with each batch being a list of contexts. There can be multiple batches when using data parallelism, in which each batch is mapped to a different device replica. ### `enable_echo` {#max.interfaces.TextGenerationInputs.enable_echo} > property enable\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) Return True if any context in the batch has echo enabled. ### `enable_log_probs` {#max.interfaces.TextGenerationInputs.enable_log_probs} > property enable\_log\_probs: [bool](https://docs.python.org/3/library/functions.html#bool) Return True if any context in the batch requests log probabilities. ### `flat_batch` {#max.interfaces.TextGenerationInputs.flat_batch} > property flat\_batch: [list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType] Flattened list of contexts across all replicas. 
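The aggregate views above (`flat_batch`, `enable_echo`, `batch_top_log_probs`) can be sketched with plain lists standing in for contexts (`FakeContext` and these free functions are hypothetical stand-ins, not MAX classes):

```python
from dataclasses import dataclass

@dataclass
class FakeContext:
    log_probabilities: int          # top-N request; 0 means disabled
    log_probabilities_echo: bool    # include prompt tokens in logprobs

def flat_batch(batches: list[list[FakeContext]]) -> list[FakeContext]:
    # Flatten the replica-per-batch structure into one list of contexts.
    return [ctx for batch in batches for ctx in batch]

def enable_echo(batches: list[list[FakeContext]]) -> bool:
    # True if any context anywhere in the batch has echo enabled.
    return any(ctx.log_probabilities_echo for ctx in flat_batch(batches))

def batch_top_log_probs(batches: list[list[FakeContext]]) -> list[int]:
    return [ctx.log_probabilities for ctx in flat_batch(batches)]
```

Under data parallelism each inner list maps to one device replica, and the flattened view is what batch-wide predicates iterate over.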
### `input_tokens` {#max.interfaces.TextGenerationInputs.input_tokens} > input\_tokens: [int](https://docs.python.org/3/library/functions.html#int) = -1 Number of input tokens. ### `num_steps` {#max.interfaces.TextGenerationInputs.num_steps} > num\_steps: [int](https://docs.python.org/3/library/functions.html#int) Number of steps to run for. ## `TextGenerationOutput` {#max.interfaces.TextGenerationOutput} > class max.interfaces.TextGenerationOutput(\*, request\_id, tokens, final\_status, log\_probabilities=None) Represents the output of a text generation operation, combining token IDs, final generation status, request ID, and optional log probabilities for each token.
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID)) * tokens ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * final\_status ([GenerationStatus](#max.interfaces.GenerationStatus)) * log\_probabilities ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[LogProbabilities](#max.interfaces.LogProbabilities)] | None)
### `final_status` {#max.interfaces.TextGenerationOutput.final_status} > final\_status: [GenerationStatus](#max.interfaces.GenerationStatus) The final status of the generation process. ### `is_done` {#max.interfaces.TextGenerationOutput.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Indicates whether the text generation process is complete.
**Returns:**
True if the generation is done, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `log_probabilities` {#max.interfaces.TextGenerationOutput.log_probabilities} > log\_probabilities: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[LogProbabilities](#max.interfaces.LogProbabilities)] | [None](https://docs.python.org/3/library/constants.html#None) = None Optional list of log probabilities for each token. ### `request_id` {#max.interfaces.TextGenerationOutput.request_id} > request\_id: [RequestID](#max.interfaces.RequestID) The unique identifier for the generation request. ### `tokens` {#max.interfaces.TextGenerationOutput.tokens} > tokens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] List of generated token IDs. ## `TextGenerationRequest` {#max.interfaces.TextGenerationRequest} > class max.interfaces.TextGenerationRequest(request\_id: 'RequestID', model\_name: 'str', prompt: 'str | Sequence\[int] | None' = None, messages: 'list\[TextGenerationRequestMessage]' = \<factory>, images: 'list\[bytes]' = \<factory>, tools: 'list\[TextGenerationRequestTool] | None' = None, response\_format: 'TextGenerationResponseFormat | None' = None, timestamp\_ns: 'int' = 0, request\_path: 'str' = '/', logprobs: 'int' = 0, echo: 'bool' = False, stop: 'str | list\[str] | None' = None, chat\_template\_options: 'dict\[str, Any] | None' = None, sampling\_params: 'SamplingParams' = \<factory>, target\_endpoint: 'str | None' = None)
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID)) * model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * prompt ([str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)] | None) * messages ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestMessage](#max.interfaces.TextGenerationRequestMessage)]) * images ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[bytes](https://docs.python.org/3/library/stdtypes.html#bytes)]) * tools ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestTool](#max.interfaces.TextGenerationRequestTool)] | None) * response\_format ([TextGenerationResponseFormat](#max.interfaces.TextGenerationResponseFormat) | None) * timestamp\_ns ([int](https://docs.python.org/3/library/functions.html#int)) * request\_path ([str](https://docs.python.org/3/library/stdtypes.html#str)) * logprobs ([int](https://docs.python.org/3/library/functions.html#int)) * echo ([bool](https://docs.python.org/3/library/functions.html#bool)) * stop ([str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | None) * chat\_template\_options ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None) * sampling\_params ([SamplingParams](#max.interfaces.SamplingParams)) * target\_endpoint ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `chat_template_options` {#max.interfaces.TextGenerationRequest.chat_template_options} > chat\_template\_options: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [None](https://docs.python.org/3/library/constants.html#None) = None Optional dictionary of options to pass when applying the chat template. ### `echo` {#max.interfaces.TextGenerationRequest.echo} > echo: [bool](https://docs.python.org/3/library/functions.html#bool) = False If set to True, the response will include the original prompt along with the generated output. This can be useful for debugging or when you want to see how the input relates to the output. ### `images` {#max.interfaces.TextGenerationRequest.images} > images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[bytes](https://docs.python.org/3/library/stdtypes.html#bytes)] A list of image byte arrays that can be included as part of the request. This field is optional and may be used for multimodal inputs where images are relevant to the prompt or task. ### `logprobs` {#max.interfaces.TextGenerationRequest.logprobs} > logprobs: [int](https://docs.python.org/3/library/functions.html#int) = 0 The number of top log probabilities to return for each generated token. A value of 0 means that log probabilities will not be returned. Useful for analyzing model confidence in its predictions. ### `messages` {#max.interfaces.TextGenerationRequest.messages} > messages: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestMessage](#max.interfaces.TextGenerationRequestMessage)] A list of messages for chat-based interactions. This is used in chat completion APIs, where each message represents a turn in the conversation. If provided, the model will generate responses based on these messages. 
### `model_name` {#max.interfaces.TextGenerationRequest.model_name} > model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the model to be used for generating tokens. This should match the available models on the server and determines the behavior and capabilities of the response generation. ### `number_of_images` {#max.interfaces.TextGenerationRequest.number_of_images} > property number\_of\_images: [int](https://docs.python.org/3/library/functions.html#int) Returns the total number of image-type contents across all provided messages.
**Returns:**
Total count of image-type contents found in messages.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `prompt` {#max.interfaces.TextGenerationRequest.prompt} > prompt: [str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None The prompt to be processed by the model. This field supports legacy completion APIs and can accept either a string or a sequence of integers representing token IDs. If not provided, the model may generate output based on the messages field. ### `request_path` {#max.interfaces.TextGenerationRequest.request_path} > request\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) = '/' The endpoint path for the request. This is typically used for routing and logging requests within the server infrastructure. ### `response_format` {#max.interfaces.TextGenerationRequest.response_format} > response\_format: [TextGenerationResponseFormat](#max.interfaces.TextGenerationResponseFormat) | [None](https://docs.python.org/3/library/constants.html#None) = None Specifies the desired format for the model’s output. When set, it enables structured generation, which adheres to the json\_schema provided. ### `sampling_params` {#max.interfaces.TextGenerationRequest.sampling_params} > sampling\_params: [SamplingParams](#max.interfaces.SamplingParams) Token sampling configuration parameters for the request. ### `stop` {#max.interfaces.TextGenerationRequest.stop} > stop: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None Optional list of stop expressions (see [https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop)).
### `target_endpoint` {#max.interfaces.TextGenerationRequest.target_endpoint} > target\_endpoint: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None Optional target endpoint identifier for routing the request to a specific service or model instance. This should be used in disaggregate serving scenarios, when you want to dynamically route to a specific instance. If not specified, the request will be routed to the default endpoint. ### `timestamp_ns` {#max.interfaces.TextGenerationRequest.timestamp_ns} > timestamp\_ns: [int](https://docs.python.org/3/library/functions.html#int) = 0 The time (in nanoseconds) when the request was received by the server. This can be useful for performance monitoring and logging purposes. ### `tools` {#max.interfaces.TextGenerationRequest.tools} > tools: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestTool](#max.interfaces.TextGenerationRequestTool)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of tools that can be invoked during the generation process. This allows the model to utilize external functionalities or APIs to enhance its responses. ## `TextGenerationRequestFunction` {#max.interfaces.TextGenerationRequestFunction} > class max.interfaces.TextGenerationRequestFunction Represents a function definition for a text generation request. ### `description` {#max.interfaces.TextGenerationRequestFunction.description} > description: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) A human-readable description of the function’s purpose. ### `name` {#max.interfaces.TextGenerationRequestFunction.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the function to be invoked. 
### `parameters` {#max.interfaces.TextGenerationRequestFunction.parameters} > parameters: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] A dictionary describing the function’s parameters, typically following a JSON schema. ## `TextGenerationRequestMessage` {#max.interfaces.TextGenerationRequestMessage} > class max.interfaces.TextGenerationRequestMessage(\*, role, content)
**Parameters:**
* role ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['system', 'user', 'assistant', 'tool', 'function']) * content ([str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContentPart](#max.interfaces.TextContentPart) | [ImageContentPart](#max.interfaces.ImageContentPart)])
### `content` {#max.interfaces.TextGenerationRequestMessage.content} > content: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[MessageContent] ### `flatten_content()` {#max.interfaces.TextGenerationRequestMessage.flatten_content} > flatten\_content()
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str)]
### `model_config` {#max.interfaces.TextGenerationRequestMessage.model_config} > model\_config: ClassVar\[ConfigDict] = {'from\_attributes': True, 'frozen': True} Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. ### `number_of_images` {#max.interfaces.TextGenerationRequestMessage.number_of_images} > property number\_of\_images: [int](https://docs.python.org/3/library/functions.html#int) Returns the number of ImageContentPart instances in the message content. ### `role` {#max.interfaces.TextGenerationRequestMessage.role} > role: MessageRole ### `validate_content_format()` {#max.interfaces.TextGenerationRequestMessage.validate_content_format} > classmethod validate\_content\_format(v)
**Parameters:**
v ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContentPart](#max.interfaces.TextContentPart) | [ImageContentPart](#max.interfaces.ImageContentPart)]
## `TextGenerationRequestTool` {#max.interfaces.TextGenerationRequestTool} > class max.interfaces.TextGenerationRequestTool Represents a tool definition for a text generation request. ### `function` {#max.interfaces.TextGenerationRequestTool.function} > function: [TextGenerationRequestFunction](#max.interfaces.TextGenerationRequestFunction) The function definition associated with the tool, including its name, description, and parameters. ### `type` {#max.interfaces.TextGenerationRequestTool.type} > type: [str](https://docs.python.org/3/library/stdtypes.html#str) The type of the tool, typically indicating the tool’s category or usage. ## `TextGenerationResponseFormat` {#max.interfaces.TextGenerationResponseFormat} > class max.interfaces.TextGenerationResponseFormat Represents the response format specification for a text generation request. ### `json_schema` {#max.interfaces.TextGenerationResponseFormat.json_schema} > json\_schema: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] A JSON schema dictionary that defines the structure and validation rules for the generated response. ### `type` {#max.interfaces.TextGenerationResponseFormat.type} > type: [str](https://docs.python.org/3/library/stdtypes.html#str) The type of response format, e.g., “json\_object”. ## `TokenBuffer` {#max.interfaces.TokenBuffer} > class max.interfaces.TokenBuffer(array) A dynamically resizable container for managing token sequences. TokenBuffer provides efficient storage and access to token sequences during text generation. It maintains the prompt tokens (initial input) and generated tokens (model output) separately, while handling automatic memory management as new tokens are added. TokenBuffer organizes tokens across three related views: 1. The full stored sequence (all), split into prompt and generated tokens. 2. 
The processing window (active versus processed and pending tokens). 3. The streaming window over newly generated tokens consumed by callers.

The first diagram shows how prompt and generated tokens share a single backing array. Later diagrams explain how processing and streaming walk over that array during generation:

```default
+-------------------- all --------------------+
+-----------------+---------------------------+
|     prompt      |         generated         |
+-----------------+---------------------------+
0   prompt_length ^
                  0          generated_length ^
0                                   len(self) ^
```

This includes three attributes for accessing tokens:

* all: The slice of the array containing all valid tokens.
* prompt: The slice of the array containing the prompt tokens.
* generated: The slice of the array containing the generated tokens.

Along with three attributes for tracking their lengths:

* prompt\_length: The number of tokens in the prompt.
* generated\_length: The number of generated tokens.
* len(self): The total number of valid tokens in the buffer.

Processing window (what the model will process next):

```default
+------------------------------ all ---------------------------+
+-------------------+---------------------------+---------------+
|     processed     |           active          |    pending    |
+-------------------+---------------------------+---------------+
0  processed_length ^
                    0             active_length ^
0                              current_position ^
0                                                     len(self) ^
```

In the above, processed tracks tokens that have already been processed, active tracks tokens that are scheduled to be processed in the next batch, and pending tracks tokens that have not yet been processed and are not scheduled for the next batch (this commonly occurs during chunked prefill). This includes one attribute for accessing tokens:

* active: The slice of the array containing the tokens scheduled for processing in the next batch.

Along with three additional attributes for tracking their lengths:

* processed\_length: The number of tokens that have already been processed.
* active\_length: The number of tokens currently scheduled for processing in the next batch.
* current\_position: The global index marking the end of the current active processing window.

This processing view is updated by methods such as rewind\_processing, skip\_processing, chunk, and advance\_chunk/advance\_with\_token, which control how much of the existing sequence is reprocessed or advanced at each step.

It also maintains a completion window over the generated tokens for completion streaming:

```default
+------------- generated -------------+
+------------+------------------------+
|  streamed  |  ready to stream next  |
+------------+------------------------+
|    (1)     |          (2)           |
```

Generated tokens are conceptually split into:

1. **streamed**: tokens that have already been returned to the caller.
2. **ready to stream**: the newest generated tokens that have not yet been returned.

Each call to consume\_recently\_generated\_tokens() returns the (2) region and advances the boundary between (1) and (2), so subsequent calls only see newly generated tokens.

Together, these three views let TokenBuffer support efficient prompt handling, chunked processing, and incremental streaming while exposing a small, consistent public API.
**Parameters:**
array ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]])
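The prompt/generated split and the streaming boundary can be illustrated with a minimal list-backed stand-in (a sketch of the documented invariants; `MiniBuffer` is hypothetical, not the real TokenBuffer):

```python
class MiniBuffer:
    """Toy model of TokenBuffer's prompt/generated and streaming views."""

    def __init__(self, prompt: list[int]) -> None:
        self._tokens = list(prompt)
        self.prompt_length = len(prompt)
        self._streamed = 0  # boundary between streamed and ready-to-stream

    def __len__(self) -> int:
        return len(self._tokens)

    @property
    def prompt(self) -> list[int]:
        return self._tokens[:self.prompt_length]

    @property
    def generated(self) -> list[int]:
        return self._tokens[self.prompt_length:]

    @property
    def generated_length(self) -> int:
        return len(self) - self.prompt_length

    def append(self, token: int) -> None:
        self._tokens.append(token)

    def consume_recently_generated_tokens(self) -> list[int]:
        # Return region (2) and advance the (1)/(2) boundary.
        new = self.generated[self._streamed:]
        self._streamed = self.generated_length
        return new
```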
### `active` {#max.interfaces.TokenBuffer.active} > property active: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]] Return the tokens queued for the next processing step. ### `active_length` {#max.interfaces.TokenBuffer.active_length} > property active\_length: [int](https://docs.python.org/3/library/functions.html#int) Count of tokens currently scheduled for processing. ### `actively_chunked` {#max.interfaces.TokenBuffer.actively_chunked} > property actively\_chunked: [bool](https://docs.python.org/3/library/functions.html#bool) Check if the buffer has active chunk limits applied.
**Returns:**
True if chunk limits are active, False otherwise.
### `advance_chunk()` {#max.interfaces.TokenBuffer.advance_chunk} > advance\_chunk() Move to the next set of tokens after a limited chunk. Call this after maybe\_chunk when you have finished working with the current active tokens and want the remaining tokens in the sequence to become active.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If called before maybe\_chunk has limited the active tokens (i.e., when no chunk is currently active).
**Return type:**
None
### `advance_with_token()` {#max.interfaces.TokenBuffer.advance_with_token} > advance\_with\_token(token, mark\_previous\_as\_processed=True) Add a new token to the buffer.
**Parameters:**
* token ([int](https://docs.python.org/3/library/functions.html#int)) – The token ID to add. * mark\_previous\_as\_processed ([bool](https://docs.python.org/3/library/functions.html#bool)) – If False, expands the set of active tokens instead of shifting forward. This is useful for speculative execution scenarios where multiple tokens may be generated.
**Return type:**
None
### `all` {#max.interfaces.TokenBuffer.all} > property all: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]] Return every valid token currently stored (prompt + generated). Use this when downstream components need the full sequence for scoring, logging, or serialization. ### `apply_processing_offset()` {#max.interfaces.TokenBuffer.apply_processing_offset} > apply\_processing\_offset(value) Set the processing offset.
**Parameters:**
value ([int](https://docs.python.org/3/library/functions.html#int)) – The new processing offset.
**Return type:**
None
### `array` {#max.interfaces.TokenBuffer.array} > array: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]] In-place storage holding the prompt plus any generated tokens. ### `chunk()` {#max.interfaces.TokenBuffer.chunk} > chunk(chunk\_size) Limit the upcoming processing step to at most n tokens.
**Parameters:**
chunk\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum number of tokens to process.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If chunk\_size is not between 1 and the current number of active tokens.
**Return type:**
None
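The chunk/advance\_chunk arithmetic on the processed | active | pending window can be sketched as follows (`ChunkedWindow` is a hypothetical stand-in mirroring the documented behavior, not the real TokenBuffer):

```python
class ChunkedWindow:
    """Toy model of the processed | active | pending processing view."""

    def __init__(self, total: int) -> None:
        self.total = total
        self.processed_length = 0
        self.active_length = total  # everything is active by default
        self._chunked = False

    @property
    def current_position(self) -> int:
        # End of the active window: processed_length + active_length.
        return self.processed_length + self.active_length

    def chunk(self, chunk_size: int) -> None:
        # Limit the next processing step; the remainder becomes pending.
        if not 1 <= chunk_size <= self.active_length:
            raise ValueError("chunk_size out of range")
        self.active_length = chunk_size
        self._chunked = True

    def advance_chunk(self) -> None:
        if not self._chunked:
            raise ValueError("no chunk is currently active")
        # Previously active tokens become processed; pending become active.
        self.processed_length += self.active_length
        self.active_length = self.total - self.processed_length
        self._chunked = False
```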
### `consume_recently_generated_tokens()` {#max.interfaces.TokenBuffer.consume_recently_generated_tokens} > consume\_recently\_generated\_tokens() Return newly generated tokens since the last consumption.
**Returns:**
A slice containing tokens ready to stream to the caller.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no new tokens are available.
**Return type:**
[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]]
### `current_position` {#max.interfaces.TokenBuffer.current_position} > property current\_position: [int](https://docs.python.org/3/library/functions.html#int) Global index marking the end of the current active processing window. Equal to processed\_length + active\_length; represents the index of the next token to be processed, which may be less than the total length when processing is limited by chunking. ### `generated` {#max.interfaces.TokenBuffer.generated} > property generated: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]] Return all tokens produced after the prompt. Use this slice for stop checks, repetition penalties, or any logic that should consider only newly generated content. ### `generated_length` {#max.interfaces.TokenBuffer.generated_length} > property generated\_length: [int](https://docs.python.org/3/library/functions.html#int) Number of tokens generated after the prompt. ### `has_outstanding_generated_tokens` {#max.interfaces.TokenBuffer.has_outstanding_generated_tokens} > property has\_outstanding\_generated\_tokens: [bool](https://docs.python.org/3/library/functions.html#bool) Indicates whether there are generated tokens that have not yet been consumed.
**Returns:**
True if there are outstanding generated tokens to be streamed or processed; False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `overwrite_last_token()` {#max.interfaces.TokenBuffer.overwrite_last_token} > overwrite\_last\_token(token) Overwrite the last token in the buffer.
**Parameters:**
token ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
None
### `processed_length` {#max.interfaces.TokenBuffer.processed_length} > property processed\_length: [int](https://docs.python.org/3/library/functions.html#int) Number of tokens that have already been processed. ### `prompt` {#max.interfaces.TokenBuffer.prompt} > property prompt: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]] Return only the original prompt tokens. Helpful for echo suppression, prompt-side metrics, or offset calculations that should exclude generated output. ### `prompt_length` {#max.interfaces.TokenBuffer.prompt_length} > property prompt\_length: [int](https://docs.python.org/3/library/functions.html#int) Number of tokens that belong to the prompt. ### `reset_as_new_prompt()` {#max.interfaces.TokenBuffer.reset_as_new_prompt} > reset\_as\_new\_prompt() Treat the current sequence as a fresh prompt. Marks all existing tokens as prompt tokens so the next generation pass starts from this state.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the buffer state is invalid.
**Return type:**
None
### `rewind_processing()` {#max.interfaces.TokenBuffer.rewind_processing} > rewind\_processing(n) Re-expose n earlier tokens so they can be processed again.
**Parameters:**
n ([int](https://docs.python.org/3/library/functions.html#int)) – Number of tokens to move back into the active window.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If n is negative.
**Return type:**
None
### `skip_processing()` {#max.interfaces.TokenBuffer.skip_processing} > skip\_processing(n) Advance the active window start by n tokens.
**Parameters:**
n ([int](https://docs.python.org/3/library/functions.html#int)) – Number of tokens to drop from the active window.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If n exceeds the number of available tokens to process, or if skipping n tokens would leave 0 active tokens.
**Return type:**
None
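The window arithmetic above (current\_position equals processed\_length plus the active length, with skip\_processing and rewind\_processing moving the window boundary) can be sketched with a minimal stand-in class. This is an illustrative mock of the index bookkeeping only, not the MAX TokenBuffer implementation:

```python
class WindowSketch:
    """Illustrative stand-in for TokenBuffer's chunked processing window.

    Mirrors only the index arithmetic described above; it is NOT the
    MAX TokenBuffer implementation.
    """

    def __init__(self, num_tokens: int, chunk_size: int) -> None:
        self.num_tokens = num_tokens
        self.processed_length = 0  # tokens already processed
        # Chunking may expose only part of the remaining tokens.
        self.active_length = min(chunk_size, num_tokens)

    @property
    def current_position(self) -> int:
        # Next token index to process; may be less than num_tokens
        # when chunking limits the active window.
        return self.processed_length + self.active_length

    def skip_processing(self, n: int) -> None:
        # Advance the active window start by n tokens.
        if n >= self.active_length:
            raise ValueError("skipping n tokens would leave 0 active tokens")
        self.processed_length += n
        self.active_length -= n

    def rewind_processing(self, n: int) -> None:
        # Re-expose n earlier tokens so they can be processed again.
        if n < 0:
            raise ValueError("n is negative")
        n = min(n, self.processed_length)
        self.processed_length -= n
        self.active_length += n


buf = WindowSketch(num_tokens=6, chunk_size=4)
assert buf.current_position == 4  # window stops short of all 6 tokens
buf.skip_processing(2)
buf.rewind_processing(1)
```

Because the active window is capped by the chunk size, current\_position can be less than the total token count, matching the chunked-processing behavior described above.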
## `VLMTextGenerationContext` {#max.interfaces.VLMTextGenerationContext} > class max.interfaces.VLMTextGenerationContext(\*args, \*\*kwargs) Protocol defining the interface for VLM input contexts. ### `compute_image_aligned_idx()` {#max.interfaces.VLMTextGenerationContext.compute_image_aligned_idx} > compute\_image\_aligned\_idx(idx) Possibly aligns an index value downward if it lies in the middle of an image.
**Parameters:**
idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `image_idx` {#max.interfaces.VLMTextGenerationContext.image_idx} > property image\_idx: [int](https://docs.python.org/3/library/functions.html#int) Index of the next unencoded image in the prompt. ### `images` {#max.interfaces.VLMTextGenerationContext.images} > property images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](#max.interfaces.ImageMetadata)] Returns the images in the context. ### `needs_vision_encoding` {#max.interfaces.VLMTextGenerationContext.needs_vision_encoding} > property needs\_vision\_encoding: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether vision encoding is needed for this context. ### `next_images` {#max.interfaces.VLMTextGenerationContext.next_images} > property next\_images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](#max.interfaces.ImageMetadata)] Returns the images that are not yet encoded. ## `drain_queue()` {#max.interfaces.drain_queue} > max.interfaces.drain\_queue(pull\_queue, max\_items=None) Remove and return items from the queue without blocking. This method is expected to return an empty list if the queue is empty. If max\_items is specified, at most that many items will be returned.
**Parameters:**
* pull\_queue ([MAXPullQueue](#max.interfaces.MAXPullQueue)\[PullItemType]) – The queue to drain items from. * max\_items ([int](https://docs.python.org/3/library/functions.html#int) | None) – Maximum number of items to return. If None, returns all items.
**Returns:**
List of items removed from the queue, limited by max\_items if specified.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[PullItemType]
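The non-blocking drain contract above (an empty list on an empty queue, at most max\_items returned) can be mimicked with the standard-library queue module. This sketch uses a plain queue.Queue in place of a MAXPullQueue and is an analogy, not the MAX implementation:

```python
import queue


def drain_queue_sketch(pull_queue, max_items=None):
    """Non-blocking drain: return up to max_items items, or all if None."""
    items = []
    while max_items is None or len(items) < max_items:
        try:
            items.append(pull_queue.get_nowait())
        except queue.Empty:
            break  # queue exhausted: return whatever was collected
    return items


q = queue.Queue()
for i in range(5):
    q.put(i)
print(drain_queue_sketch(q, max_items=3))  # [0, 1, 2]
print(drain_queue_sketch(q))               # [3, 4]
print(drain_queue_sketch(q))               # []
```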
## `get_blocking()` {#max.interfaces.get_blocking} > max.interfaces.get\_blocking(pull\_queue) Get the next item from the queue. If no item is available, this method will spin until one is.
**Parameters:**
pull\_queue ([MAXPullQueue](#max.interfaces.MAXPullQueue)\[PullItemType])
**Return type:**
PullItemType
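The spin-until-available behavior can be sketched the same way. This stand-in busy-waits on a standard queue.Queue (the real method may spin differently) and is illustrative only:

```python
import queue


def get_blocking_sketch(pull_queue):
    """Spin until the queue yields an item, then return it."""
    while True:
        try:
            return pull_queue.get_nowait()
        except queue.Empty:
            continue  # busy-wait until an item arrives


q = queue.Queue()
q.put("item")
print(get_blocking_sketch(q))  # item
```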
## `msgpack_numpy_decoder()` {#max.interfaces.msgpack_numpy_decoder} > max.interfaces.msgpack\_numpy\_decoder(type\_, copy=False) Create a decoder function for the specified type.
**Parameters:**
* type\_ ([Any](https://docs.python.org/3/library/typing.html#typing.Any)) – The type to decode into. * copy ([bool](https://docs.python.org/3/library/functions.html#bool)) – Copy numpy arrays if true. Defaults to True. Copy is set to True by default because most downstream uses of deserialized tensors are MAX driver tensors, which require owned numpy arrays. This is a constraint imposed by dlpack and numpy, where we cannot create a buffer from read-only data. While skipping copies by default offers a performance benefit during deserialization, it often just moves the work downstream to an implicit copy during Buffer.from\_numpy. As a result, it is easier to make the copy explicit here and maintain the pattern that all numpy arrays used in MAX are owned by the current process.
**Returns:**
A pickleable decoder instance that decodes bytes into the specified type
**Return type:**
MsgpackNumpyDecoder
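The ownership constraint described for copy can be observed directly in numpy: an array viewing an immutable bytes payload is read-only and does not own its memory, while a copy does. This is a generic numpy illustration, independent of MAX:

```python
import numpy as np

raw = b"\x01\x00\x00\x00\x02\x00\x00\x00"  # e.g. a deserialized payload
view = np.frombuffer(raw, dtype=np.int32)  # zero-copy view over the bytes
print(view.flags.writeable)   # False: backed by read-only bytes
print(view.flags.owndata)     # False: memory belongs to `raw`

owned = view.copy()           # what an eager copy on decode provides
print(owned.flags.writeable)  # True
print(owned.flags.owndata)    # True
```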
## `msgpack_numpy_encoder()` {#max.interfaces.msgpack_numpy_encoder} > max.interfaces.msgpack\_numpy\_encoder(use\_shared\_memory=False, shared\_memory\_threshold=0) Create an encoder function that handles numpy arrays.
**Parameters:**
* use\_shared\_memory ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to attempt shared memory conversion for numpy arrays * shared\_memory\_threshold ([int](https://docs.python.org/3/library/functions.html#int)) – Minimum size in bytes for shared memory conversion. If 0, all arrays are candidates for conversion.
**Returns:**
A pickleable encoder instance that encodes objects into bytes
**Return type:**
MsgpackNumpyEncoder
--- ## kv_cache KV cache management for efficient attention computation during inference. This package provides implementations for managing key-value caches used in transformer models. The paged attention implementation enables efficient memory management by partitioning cache memory into pages, allowing for better memory utilization and support for prefix caching. ## Functions * [`load_kv_manager`](/max/api/python/kv_cache/registry): Load and initialize a KV cache manager. * [`estimate_kv_cache_size`](/max/api/python/kv_cache/registry): Estimate KV cache memory requirements. * [`infer_optimal_batch_size`](/max/api/python/kv_cache/registry): Infer optimal batch size based on available cache memory. * [`available_port`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Find an available TCP port for transfer engine communication. ## Modules * [`registry`](/max/api/python/kv_cache/registry): KV cache manager factory functions and utilities. ## Packages * [`paged_kv_cache`](/max/api/python/kv_cache/paged_kv_cache): Paged attention KV cache implementation. ## Classes * [`PagedKVCacheManager`](/max/api/python/kv_cache/paged_kv_cache/cache_manager): Manager for paged KV cache with data and tensor parallelism support. * [`KVTransferEngine`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Manages KV cache transfers between devices in distributed settings. * [`KVTransferEngineMetadata`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Metadata for KV cache transfer engine configuration. * [`TransferReqData`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Data structure for KV cache transfer requests. --- ## cache_manager ## `PagedKVCacheManager` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager} > class max.kv\_cache.paged\_kv\_cache.cache\_manager.PagedKVCacheManager(params, session, total\_num\_pages, total\_num\_host\_pages=0, enable\_runtime\_checks=False) Paged KVCache manager with data and tensor parallelism support.
```python
kv_manager.claim(ctx1.request_id, replica_idx=0)
kv_manager.claim(ctx2.request_id, replica_idx=1)

# Allocate blocks for these requests
kv_manager.alloc(ctx1, replica_idx=0, num_steps=10)
kv_manager.alloc(ctx2, replica_idx=1, num_steps=10)

# Get KVCache inputs to feed to graph
kv_cache_inputs = kv_manager.get_runtime_inputs(
    [[ctx1, ctx2]], num_steps=10
)

# Run model...

# Update requests with newly generated tokens
ctx1.update(42)
ctx2.update(42)

# Commit newly written blocks to prefix cache
kv_manager.step([[ctx1, ctx2]])

# Release metadata and KV blocks for these requests
kv_manager.release(ctx1.request_id, replica_idx=0)
kv_manager.release(ctx2.request_id, replica_idx=1)
```
**Parameters:**
* params ([KVCacheParams](../../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * session ([InferenceSession](../../engine.md#max.engine.InferenceSession)) * total\_num\_pages ([int](https://docs.python.org/3/library/functions.html#int)) * total\_num\_host\_pages ([int](https://docs.python.org/3/library/functions.html#int)) * enable\_runtime\_checks ([bool](https://docs.python.org/3/library/functions.html#bool))
### `alloc()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.alloc} > alloc(data, replica\_idx, num\_steps=1) Allocates blocks for a request to run for N steps. This method allocates blocks needed by a request to run for N steps. When prefix caching is enabled, some of the allocated blocks may be retrieved from the prefix cache.
**Parameters:**
* data ([TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)) – The text generation context for the request. The request ID must already be assigned to a replica via claim. * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – The number of steps to reserve blocks for. Default: 1. * replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Raises:**
InsufficientBlocksError – If there are insufficient free blocks to satisfy the allocation.
**Return type:**
None
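To make the accounting concrete, the number of new pages a request needs for N further steps follows simple ceiling arithmetic. This sketch is illustrative only (page\_size and pages\_held are hypothetical inputs, and it ignores prefix-cache hits, which can reduce the count):

```python
import math


def pages_needed_sketch(current_len, num_steps, page_size, pages_held):
    """Pages to allocate so the sequence fits after num_steps new tokens.

    Illustrative arithmetic only; not the MAX allocator.
    """
    total_pages = math.ceil((current_len + num_steps) / page_size)
    return max(0, total_pages - pages_held)


# 110 tokens fit in ceil(110/16) = 7 pages, all already held:
print(pages_needed_sketch(current_len=100, num_steps=10, page_size=16, pages_held=7))  # 0
# 130 tokens need ceil(130/16) = 9 pages, so one more is required:
print(pages_needed_sketch(current_len=120, num_steps=10, page_size=16, pages_held=8))  # 1
```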
### `claim()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.claim} > claim(request\_id, replica\_idx) Reserve a sequence ID for the given request ID.
**Parameters:**
* request\_id ([RequestID](../../interfaces.md#max.interfaces.RequestID)) * replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
None
### `contains()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.contains} > contains(request\_id, replica\_idx)
**Parameters:**
* request\_id ([RequestID](../../interfaces.md#max.interfaces.RequestID)) * replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `get_device_tensors()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_device_tensors} > get\_device\_tensors(replica\_idx)
**Parameters:**
replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../../driver.md#max.driver.Buffer)]
### `get_metrics()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_metrics} > get\_metrics(replica\_idx)
**Parameters:**
replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
KVCacheMetrics
### `get_num_host_pages()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_num_host_pages} > get\_num\_host\_pages(replica\_idx)
**Parameters:**
replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `get_num_pages()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_num_pages} > get\_num\_pages(replica\_idx)
**Parameters:**
replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `get_num_used_host_pages()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_num_used_host_pages} > get\_num\_used\_host\_pages(replica\_idx)
**Parameters:**
replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `get_num_used_pages()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_num_used_pages} > get\_num\_used\_pages(replica\_idx)
**Parameters:**
replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `get_pct_used_blocks_after_allocation()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_pct_used_blocks_after_allocation} > get\_pct\_used\_blocks\_after\_allocation(ctx, replica\_idx, num\_steps=1) Get the percentage of blocks used after allocating for a request.
**Parameters:**
* ctx ([TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)) – The request context containing sequence information and token indices. * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Number of additional steps to allocate blocks for. Defaults to 1. * replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Returns:**
The percentage of total blocks used after allocating for the request.
**Return type:**
[float](https://docs.python.org/3/library/functions.html#float)
### `get_req_blocks()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_req_blocks} > get\_req\_blocks(request\_id, replica\_idx)
**Parameters:**
* request\_id ([RequestID](../../interfaces.md#max.interfaces.RequestID)) * replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
### `get_runtime_inputs()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_runtime_inputs} > get\_runtime\_inputs(batches, num\_steps=1) Get the graph inputs for per-replica batches of requests. This method will raise a RuntimeError if any request has insufficient blocks already allocated to it to run for the given number of steps.
**Parameters:**
* batches ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)]]) – Per-replica batches of requests * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Number of steps to run for
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[RaggedKVCacheInputs]
### `increment_cache_lengths()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.increment_cache_lengths} > increment\_cache\_lengths(kv\_cache\_inputs, prev\_model\_inputs)
**Parameters:**
* kv\_cache\_inputs ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[RaggedKVCacheInputs]) * prev\_model\_inputs ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
**Return type:**
[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[RaggedKVCacheInputs]
### `infer_optimal_batch_size()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.infer_optimal_batch_size} > classmethod infer\_optimal\_batch\_size(params, max\_seq\_len, available\_cache\_memory, devices, \*\*kwargs)
**Parameters:**
* params ([KVCacheParamInterface](../../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) * devices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Device](../../driver.md#max.driver.Device)]) * kwargs ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `release()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.release} > release(request\_id, replica\_idx)
**Parameters:**
* request\_id ([RequestID](../../interfaces.md#max.interfaces.RequestID)) * replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
None
### `reset_metrics()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.reset_metrics} > reset\_metrics()
**Return type:**
None
### `reset_prefix_cache()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.reset_prefix_cache} > reset\_prefix\_cache()
**Return type:**
None
### `step()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.step} > step(batches) Commit new tokens into the prefix cache for per-replica batches.
**Parameters:**
batches ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)]])
**Return type:**
None
--- ## paged_kv_cache Paged attention KV cache implementation with support for distributed inference. This package provides the core implementation of paged KV cache management, including cache managers, transfer engines for distributed settings, and tensor parallelism support. ## Modules * [`cache_manager`](/max/api/python/kv_cache/paged_kv_cache/cache_manager): Core paged KV cache manager implementation. * [`tp_cache_manager`](/max/api/python/kv_cache/paged_kv_cache/tp_cache_manager): Tensor parallelism cache manager and input symbols. * [`transfer_engine`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): KV cache transfer engine for distributed inference. ## Classes * [`PagedKVCacheManager`](/max/api/python/kv_cache/paged_kv_cache/cache_manager): Manager for paged KV cache with data and tensor parallelism support. * [`KVTransferEngine`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Manages KV cache transfers between devices. * [`KVTransferEngineMetadata`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Metadata for transfer engine configuration. * [`TransferReqData`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Transfer request data structure. --- ## tp_cache_manager PagedAttention-enabled KV cache for the Transformer leveraging the mo.opaque pattern. --- ## transfer_engine KVCache Transfer Engine ## `KVTransferEngine` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine} > class max.kv\_cache.paged\_kv\_cache.transfer\_engine.KVTransferEngine(name, tensors, \*, total\_num\_pages) KVCache Transfer Engine with support for Data Parallelism (DP) and Tensor Parallelism (TP). The engine accepts a 2D list of tensors: list\[list\[Buffer]] where the outer list represents DP replicas and the inner list represents TP shards within each replica. The TransferEngine communicates with other TransferEngines in other threads or processes. However, individual TransferEngines themselves are not thread-safe. 
It is intended to be used by MAX’s single-threaded scheduler.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * tensors (Sequence\[Sequence\[[Buffer](../../driver.md#max.driver.Buffer)]]) * total\_num\_pages ([int](https://docs.python.org/3/library/functions.html#int))
### `bytes_per_page` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.bytes_per_page} > bytes\_per\_page: [int](https://docs.python.org/3/library/functions.html#int) Bytes per page for each tensor. ### `cleanup()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.cleanup} > cleanup() Release all resources associated with the transfer engine. Should be called before the transfer engine is garbage collected. Moving this logic into the \_\_del\_\_ destructor causes a UCX error for unknown reasons.
**Return type:**
None
### `cleanup_transfer()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.cleanup_transfer} > cleanup\_transfer(transfer\_req) Cleanup a transfer. This should be called after a transfer is complete.
**Parameters:**
transfer\_req ([TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)) – The transfer request to cleanup.
**Return type:**
None
### `completed_recv_transfers` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.completed_recv_transfers} > completed\_recv\_transfers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [int](https://docs.python.org/3/library/functions.html#int)]] Map of agent names to completed recv transfers. ### `connect()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.connect} > connect(remote) Connect to a remote engine (all replicas).
**Parameters:**
remote ([KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)) – Metadata for the remote engine (all replicas).
**Return type:**
None
### `dp` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.dp} > dp: [int](https://docs.python.org/3/library/functions.html#int) Number of DP replicas. ### `inflight_send_transfers` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.inflight_send_transfers} > inflight\_send\_transfers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)] Map of transfer names to send transfer request data. ### `initiate_send_transfer()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.initiate_send_transfer} > initiate\_send\_transfer(remote\_metadata, src\_idxs, dst\_idxs, src\_replica\_idx, dst\_replica\_idx) Initiate a transfer from current engine to remote engine. The same page indices are broadcast to all TP shards within the source and destination replicas.
**Parameters:**
* remote\_metadata ([KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)) – Metadata for the remote engine. * src\_idxs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – List of indices of the source pages in the current engine. * dst\_idxs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – List of indices of the destination pages in the remote engine. * src\_replica\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Index of the source replica to transfer from. * dst\_replica\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Index of the destination replica to transfer to.
**Return type:**
[TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)
### `is_complete()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.is_complete} > is\_complete(transfer\_req) Checks if a given send or recv transfer is completed. :::caution Caution This method is prone to infinite loops. For the transfer to progress, the remote engine MUST call wait\_recv\_complete. As such, the following code will hang:

```python
transfer_req = engine_1.write_to(...)
while not engine_1.is_complete(transfer_req):
    pass
while not engine_2.is_complete(transfer_req):
    pass
```

Instead do:

```python
transfer_req = engine_1.write_to(...)
while not engine_1.is_complete(transfer_req) or not engine_2.is_complete(transfer_req):
    pass
```

:::
**Parameters:**
transfer\_req ([TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)) – The transfer request.
**Returns:**
True if all transfers have completed; False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `memory_type` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.memory_type} > memory\_type: MemoryType Type of memory being managed (e.g. DRAM). ### `metadata` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.metadata} > property metadata: [KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata) Get metadata for all replicas.
**Returns:**
Metadata for the entire engine (all replicas).
### `name` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) Name of transfer engine / nixl agent. ### `remote_agent_to_engine` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.remote_agent_to_engine} > remote\_agent\_to\_engine: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str)] Map of remote agent names to their engine names. ### `remote_connections` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.remote_connections} > remote\_connections: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)] Map of remote engine names to their metadata. ### `sync_and_release()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.sync_and_release} > sync\_and\_release(transfer\_req) Wait for a transfer to complete and release the transfer after it completes.
**Parameters:**
transfer\_req ([TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData))
**Return type:**
None
### `tensor_agents` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.tensor_agents} > tensor\_agents: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorAgent](#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent)]] 2D list of TensorAgent objects, indexed as \[replica]\[tp\_shard].
### `total_num_pages` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.total_num_pages} > total\_num\_pages: [int](https://docs.python.org/3/library/functions.html#int) Total number of pages in each tensor (same across all replicas). ### `tp` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.tp} > tp: [int](https://docs.python.org/3/library/functions.html#int) Number of TP shards per replica. ## `KVTransferEngineMetadata` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata} > class max.kv\_cache.paged\_kv\_cache.transfer\_engine.KVTransferEngineMetadata(\*, name, total\_num\_pages, bytes\_per\_page, memory\_type, hostname, agents\_meta) Metadata associated with a transfer engine. This is safe to send between threads/processes.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * total\_num\_pages ([int](https://docs.python.org/3/library/functions.html#int)) * bytes\_per\_page ([int](https://docs.python.org/3/library/functions.html#int)) * memory\_type (MemoryType) * hostname ([str](https://docs.python.org/3/library/stdtypes.html#str)) * agents\_meta ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorAgentMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata)]])
### `agents_meta` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.agents_meta} > agents\_meta: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorAgentMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata)]] Metadata for each replica’s agents, indexed as \[replica]\[tp\_shard].
### `bytes_per_page` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.bytes_per_page} > bytes\_per\_page: [int](https://docs.python.org/3/library/functions.html#int) Bytes per page for each tensor. ### `hostname` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.hostname} > hostname: [str](https://docs.python.org/3/library/stdtypes.html#str) Hostname of the machine that the transfer engine is running on. ### `memory_type` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.memory_type} > memory\_type: MemoryType Memory type of the transfer engine. ### `name` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) Base name of the transfer engine. ### `total_num_pages` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.total_num_pages} > total\_num\_pages: [int](https://docs.python.org/3/library/functions.html#int) Total number of pages in each tensor. ## `TensorAgent` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent} > class max.kv\_cache.paged\_kv\_cache.transfer\_engine.TensorAgent(agent, agent\_name, tensor, base\_addr, ucx\_backend, device\_id, agent\_metadata, reg\_dlist) Manages a single tensor and its associated NIXL agent for transfers. This class holds both the runtime state (live objects) and can generate the serializable metadata for communication between engines.
**Parameters:**
* agent (Agent) * agent\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * tensor ([Buffer](../../driver.md#max.driver.Buffer)) * base\_addr ([int](https://docs.python.org/3/library/functions.html#int)) * ucx\_backend ([int](https://docs.python.org/3/library/functions.html#int)) * device\_id ([int](https://docs.python.org/3/library/functions.html#int)) * agent\_metadata ([bytes](https://docs.python.org/3/library/stdtypes.html#bytes)) * reg\_dlist (RegistrationDescriptorList)
### `agent` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.agent} > agent: Agent NIXL agent for this tensor. ### `agent_metadata` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.agent_metadata} > agent\_metadata: [bytes](https://docs.python.org/3/library/stdtypes.html#bytes) Metadata for this agent. ### `agent_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.agent_name} > agent\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) Name of this agent. ### `base_addr` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.base_addr} > base\_addr: [int](https://docs.python.org/3/library/functions.html#int) Base memory address for this tensor. ### `create_agent()` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.create_agent} > classmethod create\_agent(agent\_name, listen\_port, tensor, total\_num\_pages, elts\_per\_page, memory\_type)
**Parameters:**
* agent\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * listen\_port ([int](https://docs.python.org/3/library/functions.html#int)) * tensor ([Buffer](../../driver.md#max.driver.Buffer)) * total\_num\_pages ([int](https://docs.python.org/3/library/functions.html#int)) * elts\_per\_page ([int](https://docs.python.org/3/library/functions.html#int)) * memory\_type (MemoryType)
**Return type:**
[TensorAgent](#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent)
### `device_id` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.device_id} > device\_id: [int](https://docs.python.org/3/library/functions.html#int) Device ID for this tensor. ### `reg_dlist` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.reg_dlist} > reg\_dlist: RegistrationDescriptorList Registration descriptor list for this tensor. ### `tensor` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.tensor} > tensor: [Buffer](../../driver.md#max.driver.Buffer) Tensor for this agent. ### `to_metadata()` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.to_metadata} > to\_metadata() Convert to serializable metadata for communication.
**Return type:**
[TensorAgentMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata)
### `ucx_backend` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.ucx_backend} > ucx\_backend: [int](https://docs.python.org/3/library/functions.html#int) UCX backend for this tensor. ## `TensorAgentMetadata` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata} > class max.kv\_cache.paged\_kv\_cache.transfer\_engine.TensorAgentMetadata(\*, agent\_name, metadata, base\_addr, device\_id) Metadata for a single tensor/agent in the transfer engine. This is used for serialization and communication between engines.
**Parameters:**
* agent\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * metadata ([bytes](https://docs.python.org/3/library/stdtypes.html#bytes)) * base\_addr ([int](https://docs.python.org/3/library/functions.html#int)) * device\_id ([int](https://docs.python.org/3/library/functions.html#int))
### `agent_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.agent_name} > agent\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) Name of this agent. ### `base_addr` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.base_addr} > base\_addr: [int](https://docs.python.org/3/library/functions.html#int) Base memory address for this tensor. ### `device_id` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.device_id} > device\_id: [int](https://docs.python.org/3/library/functions.html#int) Device ID for this tensor. ### `metadata` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.metadata} > metadata: [bytes](https://docs.python.org/3/library/stdtypes.html#bytes) Metadata for this agent. ## `TransferReqData` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData} > class max.kv\_cache.paged\_kv\_cache.transfer\_engine.TransferReqData(\*, dst\_name, src\_name, transfer\_name, transfer\_ids, src\_idxs, dst\_idxs, src\_replica\_idx, dst\_replica\_idx) Metadata associated with a transfer request. This is safe to send between threads/processes.
**Parameters:**
* dst\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * src\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * transfer\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * transfer\_ids ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * src\_idxs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * dst\_idxs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * src\_replica\_idx ([int](https://docs.python.org/3/library/functions.html#int)) * dst\_replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
### `dst_idxs` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.dst_idxs} > dst\_idxs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Length of destination indices can differ from len(transfer\_ids). ### `dst_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.dst_name} > dst\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) Base name of destination engine. ### `dst_replica_idx` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.dst_replica_idx} > dst\_replica\_idx: [int](https://docs.python.org/3/library/functions.html#int) Index of the destination replica this transfer is to. ### `src_idxs` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.src_idxs} > src\_idxs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Length of source indices can differ from len(transfer\_ids). ### `src_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.src_name} > src\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) Base name of source engine. ### `src_replica_idx` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.src_replica_idx} > src\_replica\_idx: [int](https://docs.python.org/3/library/functions.html#int) Index of the source replica this transfer is from. ### `transfer_ids` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.transfer_ids} > transfer\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Transfer IDs (one per TP shard in the replica). ### `transfer_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.transfer_name} > transfer\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) Transfer name. 
## `available_port()` {#max.kv_cache.paged_kv_cache.transfer_engine.available_port} > max.kv\_cache.paged\_kv\_cache.transfer\_engine.available\_port(start\_port=8000, end\_port=9000, max\_attempts=100) Find an available TCP port in the given range.
**Parameters:**
* start\_port ([int](https://docs.python.org/3/library/functions.html#int)) – The lower bound of the port range (inclusive). * end\_port ([int](https://docs.python.org/3/library/functions.html#int)) – The upper bound of the port range (inclusive). * max\_attempts ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum number of attempts to find a free port.
**Returns:**
An available port number.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
**Raises:**
[RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If no available port is found after max\_attempts.
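The probing strategy described above can be sketched in plain Python. This is an illustrative stand-in, not the library's implementation, and `find_available_port` is a hypothetical helper name:

```python
import socket

def find_available_port(start_port=8000, end_port=9000, max_attempts=100):
    """Probe ports in [start_port, end_port] until one binds successfully."""
    for attempt, candidate in enumerate(range(start_port, end_port + 1)):
        if attempt >= max_attempts:
            break
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                # SO_REUSEADDR avoids false negatives from sockets in TIME_WAIT.
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(("", candidate))
                return candidate
        except OSError:
            continue  # Port is in use; try the next one.
    raise RuntimeError(f"No available port in [{start_port}, {end_port}]")

port = find_available_port()
```

A successful `bind` is only a point-in-time check: another process can grab the port between probing and actual use, so callers should still handle bind failures.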
--- ## registry ## `estimate_kv_cache_size()` {#max.kv_cache.registry.estimate_kv_cache_size} > max.kv\_cache.registry.estimate\_kv\_cache\_size(params, max\_batch\_size, max\_seq\_len, available\_cache\_memory)
**Parameters:**
* params ([KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
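The quantity being estimated follows standard KV cache arithmetic: two cached tensors (K and V) per layer, each `num_kv_heads * head_dim` elements wide per token. A back-of-envelope sketch — not the function's actual implementation, and the shape parameters are illustrative:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, dtype_bytes,
                   max_batch_size, max_seq_len):
    # 2x for the separate K and V tensors cached at every layer.
    per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return per_token * max_batch_size * max_seq_len

# Example: Llama-3-8B-like shapes (32 layers, 8 KV heads, head_dim 128, fp16).
size = kv_cache_bytes(32, 8, 128, 2, max_batch_size=1, max_seq_len=8192)
# → 1073741824 bytes (1 GiB) for a single 8192-token sequence
```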
## `infer_optimal_batch_size()` {#max.kv_cache.registry.infer_optimal_batch_size} > max.kv\_cache.registry.infer\_optimal\_batch\_size(params, max\_seq\_len, available\_cache\_memory, devices, \*\*kwargs)
**Parameters:**
* params ([KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) * devices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Device](../driver.md#max.driver.Device)]) * kwargs ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
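A rough mental model for what "optimal" means here: the largest batch whose worst-case KV cache still fits in `available_cache_memory`. A hedged sketch with illustrative shape parameters (the real function derives these from `params` and `devices`):

```python
def infer_batch_size(available_cache_memory, num_layers, num_kv_heads,
                     head_dim, dtype_bytes, max_seq_len):
    # Worst-case KV cache footprint of one sequence at full context length:
    # K and V (2x) cached for every layer, head, and token.
    per_seq = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes * max_seq_len
    return max(1, available_cache_memory // per_seq)

# 16 GiB of cache memory, fp16, 32 layers, 8 KV heads, head_dim 128,
# 8192-token context: each sequence needs 1 GiB, so the batch size is 16.
batch = infer_batch_size(16 << 30, 32, 8, 128, 2, 8192)
```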
## `load_kv_manager()` {#max.kv_cache.registry.load_kv_manager} > max.kv\_cache.registry.load\_kv\_manager(params, max\_batch\_size, max\_seq\_len, session, available\_cache\_memory) Loads a single KV cache manager from the given params.
**Parameters:**
* params ([KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * session ([InferenceSession](../engine.md#max.engine.InferenceSession)) * available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[PagedKVCacheManager](paged_kv_cache/cache_manager.md#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager)
## `load_kv_managers()` {#max.kv_cache.registry.load_kv_managers} > max.kv\_cache.registry.load\_kv\_managers(params, max\_batch\_size, max\_seq\_len, session, available\_cache\_memory) Loads (potentially multiple) KV cache managers from the given params.
**Parameters:**
* params ([KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * session ([InferenceSession](../engine.md#max.engine.InferenceSession)) * available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[PagedKVCacheManager](paged_kv_cache/cache_manager.md#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager)]
--- ## Embedding (Nn) ## `Embedding` {#max.nn.Embedding} > class max.nn.Embedding(vocab\_size, \*, dim=None, dims=None) A vector embedding. An embedding can be thought of as a lookup table for vectors by index. Given an input tensor of indices into the embedding, the result of the embedding lookup is a tensor of the same shape, but with each index replaced by the value of the vector in that location in the embedding table. The common case for embeddings is a 1-dimensional embedding: ```python from max.dtype import DType from max.tensor import Tensor from max.nn import Embedding embedding = Embedding(vocab_size=1000, dim=128) tokens = Tensor.ones([10], dtype=DType.uint64) embedded = embedding(tokens) assert embedded.shape == [10, 128] ``` However, they just as easily support multi-dimensional embeddings: ```python from max.dtype import DType from max.tensor import Tensor from max.nn import Embedding embedding = Embedding(vocab_size=1000, dims=[16, 128]) tokens = Tensor.ones([10], dtype=DType.uint64) embedded = embedding(tokens) assert embedded.shape == [10, 16, 128] ```
**Parameters:**
* vocab\_size (DimLike) * dim (DimLike | None) * dims (ShapeLike | None)
### `dim` {#max.nn.Embedding.dim} > property dim: [Dim](../graph/dim.md#max.graph.dim.Dim) The dimension of the vectors in the embedding (for a 1-dimensional embedding). Raises an error for 0- or >1-dimensional embeddings. ### `dims` {#max.nn.Embedding.dims} > property dims: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Dim](../graph/dim.md#max.graph.dim.Dim)] The dimensions of the vectors in the embedding. ### `forward()` {#max.nn.Embedding.forward} > forward(indices) Applies the vector embedding to the input tensor of indices.
**Parameters:**
indices ([Tensor](../tensor.md#max.tensor.Tensor)) – An integer-valued tensor. Values must be in the range \[0, vocab\_size) for the embedding.
**Returns:**
A dense tensor made by looking up each index in the vector embedding. For an input of shape `(*batch, indices)` and an embedding of shape `(vocab_size, *dims)`, the result will have shape `(*batch, indices, *dims)`.
**Return type:**
[Tensor](../tensor.md#max.tensor.Tensor)
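The shape rule above — each index replaced by its vector, so `(*batch, indices)` becomes `(*batch, indices, *dims)` — can be illustrated with a tiny pure-Python lookup; `embed` here is a hypothetical stand-in for the layer's gather, not MAX code:

```python
def embed(indices, table):
    # Each index is replaced by the corresponding row of the embedding
    # table; nesting (batch dims) is preserved, and the row's dims are
    # appended as trailing dimensions.
    if isinstance(indices, list):
        return [embed(i, table) for i in indices]
    return table[indices]

table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]  # vocab_size=3, dim=2
out = embed([[0, 2], [1, 1]], table)          # shape (2, 2) -> (2, 2, 2)
```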
### `vocab_size` {#max.nn.Embedding.vocab_size} > property vocab\_size: [Dim](../graph/dim.md#max.graph.dim.Dim) The vocab size of the embedding. Indices outside the range of \[0, vocab\_size) are illegal. ### `weight` {#max.nn.Embedding.weight} > weight: [Tensor](../tensor.md#max.tensor.Tensor) :::note Note For the legacy graph-based embedding layer, see [legacy/embedding](/max/api/python/nn/legacy/embedding). ::: --- ## Linear ## `Linear` {#max.nn.Linear} > class max.nn.Linear(in\_dim, out\_dim, \*, bias=True) A unary linear transformation over an input tensor. Linear is defined as f(x) = x @ W\.T + B where W is the weight tensor and B is an optional bias tensor. If W is not square then the transformation represents a dimensionality change. By convention the weight tensor is stored transposed. ```python from max.nn import Linear from max.tensor import Tensor model = Linear(5, 10) assert dict(model.parameters) == { "weight": model.weight, "bias": model.bias } result = model(Tensor.ones([5])) assert result.shape == [10] ```
**Parameters:**
* in\_dim (DimLike) * out\_dim (DimLike) * bias ([Tensor](../tensor.md#max.tensor.Tensor) | [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\[0])
### `bias` {#max.nn.Linear.bias} > bias: [Tensor](../tensor.md#max.tensor.Tensor) | [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\[0] The bias `Tensor` for the linear transformation (or 0 if bias is disabled). ### `forward()` {#max.nn.Linear.forward} > forward(x) Applies a linear transformation to the input tensor. Linear is defined as f(x) = x @ W\.T + B where W is the weight tensor and B is an optional bias tensor.
**Parameters:**
x ([Tensor](../tensor.md#max.tensor.Tensor)) – The input tensor
**Returns:**
The result of applying the linear transformation to the tensor.
**Return type:**
[Tensor](../tensor.md#max.tensor.Tensor)
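The transposed-weight convention (`f(x) = x @ W.T + B` with `W` stored as `(out_dim, in_dim)`) means each output element is a dot product of `x` with one *row* of `W`. A dependency-free sketch of that arithmetic — illustrative only, not the layer's implementation:

```python
def linear(x, W, b):
    # W is stored transposed: W[j] holds the in_dim weights that produce
    # output j, so x @ W.T is a dot of x with each row of W.
    return [sum(xi * wi for xi, wi in zip(x, row)) + bj
            for row, bj in zip(W, b)]

in_dim, out_dim = 5, 10
W = [[0.1] * in_dim for _ in range(out_dim)]  # shape (out_dim, in_dim)
b = [0.0] * out_dim
y = linear([1.0] * in_dim, W, b)              # x: (5,) -> y: (10,)
```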
### `in_dim` {#max.nn.Linear.in_dim} > property in\_dim: [Dim](../graph/dim.md#max.graph.dim.Dim) The input dimension for the transformation. ### `out_dim` {#max.nn.Linear.out_dim} > property out\_dim: [Dim](../graph/dim.md#max.graph.dim.Dim) The output dimension for the transformation. ### `weight` {#max.nn.Linear.weight} > weight: [Tensor](../tensor.md#max.tensor.Tensor) The weight `Tensor` for the linear transformation. :::note Note For the legacy graph-based linear layer, see [legacy/linear](/max/api/python/nn/legacy/linear). ::: --- ## nn APIs to build neural network components for deep learning models with Python. The MAX neural network API provides two namespaces: * **max.nn**: Eager-style execution. * **max.nn.legacy**: Legacy graph-based API (for backward compatibility). For functional operations like relu, softmax, and more, see the [`functional`](/max/api/python/functional) module. ## Core API Use these modules for all models. They provide eager-style execution with PyTorch-style syntax. * [`Embedding`](/max/api/python/nn/Embedding): Vector embedding layer for token representation. * [`Linear`](/max/api/python/nn/Linear): Linear transformation layer with weights and bias. * [`module`](/max/api/python/nn/module): Base class for all neural network modules. * [`norm`](/max/api/python/nn/norm): Normalization layers for training stability. * [`rope`](/max/api/python/nn/rope): Rotary position embeddings for sequence models. * [`sequential`](/max/api/python/nn/sequential): Containers for composing modules sequentially. ## Legacy API :::note Note The legacy API remains available for backward compatibility. For all new models, use the max.nn API. ::: The legacy API provides graph-based layer implementations. See the full reference: * [`legacy`](/max/api/python/nn/legacy): Neural network legacy API documentation. --- ## attention_with_rope An opaque KV Cache optimized attention mechanism with Rope. 
## `AttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope} > class max.nn.legacy.attention.attention\_with\_rope.AttentionWithRope(\*, rope, sharding\_strategy=None, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None, use\_qk\_norm=False, rms\_norm\_eps=1e-06) Implementation of attention that uses Rotary Position Embedding (RoPE).
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * sharding\_strategy (ShardingStrategy | None) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices (Sequence\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) * stacked\_qkv ([bool](https://docs.python.org/3/library/functions.html#bool)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * float8\_config ([Float8Config](../float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * clip\_qkv ([float](https://docs.python.org/3/library/functions.html#float) | None) * use\_qk\_norm ([bool](https://docs.python.org/3/library/functions.html#bool)) * rms\_norm\_eps ([float](https://docs.python.org/3/library/functions.html#float))
### `qkv_input_scale` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.qkv_input_scale} > property qkv\_input\_scale: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) The max of q, k, and v scale input vectors. ### `qkv_weight_scale` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.qkv_weight_scale} > property qkv\_weight\_scale: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) The max of q, k, and v scale weight vectors. ### `qkv_weight_scale_2` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.qkv_weight_scale_2} > property qkv\_weight\_scale\_2: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) The max of q, k, and v scale input vectors. ### `rope` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.rope} > rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding) ### `shard()` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.shard} > shard(devices) Create sharded views across devices (tensor-parallel). Returns one AttentionWithRope per device with appropriately sliced weights.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[AttentionWithRope](#max.nn.legacy.attention.attention_with_rope.AttentionWithRope)]
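`shard()` returns one module per device with appropriately sliced weights; for head-sharded projections the split is typically a contiguous range of attention heads per device. An illustrative sketch of that slicing (hypothetical helper, not the module's actual code):

```python
def shard_heads(weight_rows, num_heads, num_devices):
    # Tensor parallelism splits attention heads evenly across devices;
    # each shard keeps the weight rows for its contiguous head range.
    assert num_heads % num_devices == 0
    heads_per_device = num_heads // num_devices
    rows_per_head = len(weight_rows) // num_heads
    rows_per_shard = heads_per_device * rows_per_head
    return [weight_rows[i * rows_per_shard:(i + 1) * rows_per_shard]
            for i in range(num_devices)]

# 8 heads of 4 rows each, split across 2 devices -> 16 rows per shard.
shards = shard_heads(list(range(32)), num_heads=8, num_devices=2)
```

Note that `num_key_value_heads` must also divide evenly across devices for the KV projections to shard the same way.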
### `sharding_strategy` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the Module sharding strategy. ### `wqkv` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.wqkv} > property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) The concatenation of q, k, and v weight vectors. ### `wqkv_bias` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.wqkv_bias} > property wqkv\_bias: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) The concatenation of q, k, and v bias weight vectors. ## `AttentionWithRopeNoOpaque` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRopeNoOpaque} > class max.nn.legacy.attention.attention\_with\_rope.AttentionWithRopeNoOpaque(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, scale=None) Attention with RoPE without opaque KV cache. Assumes: no float8, no stacked QKV, no bias, no clip\_qkv, and no float8\_config.
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices (Sequence\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None)
### `rope` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRopeNoOpaque.rope} > rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding) ## `DataParallelAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.DataParallelAttentionWithRope} > class max.nn.legacy.attention.attention\_with\_rope.DataParallelAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None, use\_qk\_norm=False, rms\_norm\_eps=1e-06) Data-parallel implementation of Attention with RoPE. This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs\_cis, input\_row\_offsets). No collective ops are required; KV-cache remains local to each device. **Notes:** * Assumes the caller has already distributed xs, kv\_collections, freqs\_cis, and input\_row\_offsets so that index i corresponds to device i, with input\_row\_offsets\[i] rebased to start at 0.
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices (Sequence\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) * stacked\_qkv ([bool](https://docs.python.org/3/library/functions.html#bool)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * float8\_config ([Float8Config](../float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * clip\_qkv ([float](https://docs.python.org/3/library/functions.html#float) | None) * use\_qk\_norm ([bool](https://docs.python.org/3/library/functions.html#bool)) * rms\_norm\_eps ([float](https://docs.python.org/3/library/functions.html#float))
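The rebasing requirement in the note above — each device's `input_row_offsets` slice must start at 0 — amounts to subtracting the slice's first offset from every entry. A small illustrative helper (hypothetical name, not part of the API):

```python
def rebase_row_offsets(offsets_per_device):
    # Each device's ragged-batch offsets must start at 0, so subtract the
    # first offset from every entry in that device's slice.
    return [[o - offsets[0] for o in offsets]
            for offsets in offsets_per_device]

# Two devices holding rows [0, 3) and [3, 7) of a global ragged batch:
# the second device's offsets [3, 5, 7] become [0, 2, 4] locally.
rebased = rebase_row_offsets([[0, 3], [3, 5, 7]])
```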
## `GGUFQAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope} > class max.nn.legacy.attention.attention\_with\_rope.GGUFQAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, dtype, quantization\_encoding, devices=None, linear\_cls=Linear, scale=None, has\_bias=False, clip\_qkv=None) Implementation of attention with GGUF quantized weights.
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * quantization\_encoding ([QuantizationEncoding](../../../graph/quantization.md#max.graph.quantization.QuantizationEncoding)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * clip\_qkv ([float](https://docs.python.org/3/library/functions.html#float) | None)
### `rope` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope.rope} > rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding) ### `wqkv` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope.wqkv} > property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) The concatenation of q, k, and v weight vectors. ### `wqkv_bias` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope.wqkv_bias} > property wqkv\_bias: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) The concatenation of q, k, and v bias weight vectors. ## `GPTQAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.GPTQAttentionWithRope} > class max.nn.legacy.attention.attention\_with\_rope.GPTQAttentionWithRope(quantization\_config, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, scale=None, linear\_cls=Linear) Implementation of the GPTQ attention layer.
**Parameters:**
* quantization\_config ([QuantizationConfig](../../../graph/quantization.md#max.graph.quantization.QuantizationConfig)) * rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)])
### `wqkv` {#max.nn.legacy.attention.attention_with_rope.GPTQAttentionWithRope.wqkv} > property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) The concatenation of q, k, and v weight vectors (packed + scales). ## `TensorParallelAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.TensorParallelAttentionWithRope} > class max.nn.legacy.attention.attention\_with\_rope.TensorParallelAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None, use\_qk\_norm=False, rms\_norm\_eps=1e-06) Tensor-parallel wrapper that delegates sharding to the base module.
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices (Sequence\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) * stacked\_qkv ([bool](https://docs.python.org/3/library/functions.html#bool)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * float8\_config ([Float8Config](../float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * clip\_qkv ([float](https://docs.python.org/3/library/functions.html#float) | None) * use\_qk\_norm ([bool](https://docs.python.org/3/library/functions.html#bool)) * rms\_norm\_eps ([float](https://docs.python.org/3/library/functions.html#float))
## `distribute_value()` {#max.nn.legacy.attention.attention_with_rope.distribute_value} > max.nn.legacy.attention.attention\_with\_rope.distribute\_value(v, devices)
**Parameters:**
* v ([TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]
--- ## attention (Attention) Legacy attention mechanisms for graph-based neural networks. ## Modules * [`attention_with_rope`](/max/api/python/nn/legacy/attention/attention_with_rope): Attention with rotary position embeddings. * [`interfaces`](/max/api/python/nn/legacy/attention/interfaces): Attention interface definitions. * [`mask_config`](/max/api/python/nn/legacy/attention/mask_config): Attention mask configuration utilities. * [`multi_latent_attention`](/max/api/python/nn/legacy/attention/multi_latent_attention): Multi-latent attention mechanism. * [`multihead_attention`](/max/api/python/nn/legacy/attention/multihead_attention): Multi-head attention implementation. * [`ragged_attention`](/max/api/python/nn/legacy/attention/ragged_attention): Attention for variable-length sequences. --- ## interfaces (Attention) General interface for Attention. ## `DistributedAttentionImpl` {#max.nn.legacy.attention.interfaces.DistributedAttentionImpl} > class max.nn.legacy.attention.interfaces.DistributedAttentionImpl A generalized Distributed attention interface. --- ## mask_config Mask configuration for attention. 
## `AttentionMaskVariant` {#max.nn.legacy.attention.mask_config.AttentionMaskVariant} > class max.nn.legacy.attention.mask\_config.AttentionMaskVariant(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.AttentionMaskVariant.CAUSAL_MASK} > CAUSAL\_MASK = 'causal' ### `CHUNKED_CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.AttentionMaskVariant.CHUNKED_CAUSAL_MASK} > CHUNKED\_CAUSAL\_MASK = 'chunked\_causal' ### `NULL_MASK` {#max.nn.legacy.attention.mask_config.AttentionMaskVariant.NULL_MASK} > NULL\_MASK = 'null' ### `SLIDING_WINDOW_CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.AttentionMaskVariant.SLIDING_WINDOW_CAUSAL_MASK} > SLIDING\_WINDOW\_CAUSAL\_MASK = 'sliding\_window\_causal' ### `TENSOR_MASK` {#max.nn.legacy.attention.mask_config.AttentionMaskVariant.TENSOR_MASK} > TENSOR\_MASK = 'tensor\_mask' ## `MHAMaskConfig` {#max.nn.legacy.attention.mask_config.MHAMaskConfig} > class max.nn.legacy.attention.mask\_config.MHAMaskConfig(attention\_mask\_variant: 'AttentionMaskVariant', positional\_encoding\_variant: 'PositionalEncodingVariant')
**Parameters:**
* attention\_mask\_variant ([AttentionMaskVariant](#max.nn.legacy.attention.mask_config.AttentionMaskVariant)) * positional\_encoding\_variant ([PositionalEncodingVariant](#max.nn.legacy.attention.mask_config.PositionalEncodingVariant))
### `attention_mask_variant` {#max.nn.legacy.attention.mask_config.MHAMaskConfig.attention_mask_variant} > attention\_mask\_variant: [AttentionMaskVariant](#max.nn.legacy.attention.mask_config.AttentionMaskVariant) ### `positional_encoding_variant` {#max.nn.legacy.attention.mask_config.MHAMaskConfig.positional_encoding_variant} > positional\_encoding\_variant: [PositionalEncodingVariant](#max.nn.legacy.attention.mask_config.PositionalEncodingVariant) ## `MHAMaskVariant` {#max.nn.legacy.attention.mask_config.MHAMaskVariant} > class max.nn.legacy.attention.mask\_config.MHAMaskVariant(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `CAUSAL_ALIBI_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.CAUSAL_ALIBI_MASK} > CAUSAL\_ALIBI\_MASK = '1' ### `CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.CAUSAL_MASK} > CAUSAL\_MASK = '0' ### `CHUNKED_CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.CHUNKED_CAUSAL_MASK} > CHUNKED\_CAUSAL\_MASK = '3' ### `NULL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.NULL_MASK} > NULL\_MASK = '2' ### `SLIDING_WINDOW_CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.SLIDING_WINDOW_CAUSAL_MASK} > SLIDING\_WINDOW\_CAUSAL\_MASK = '4' ## `PositionalEncodingVariant` {#max.nn.legacy.attention.mask_config.PositionalEncodingVariant} > class max.nn.legacy.attention.mask\_config.PositionalEncodingVariant(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `ALIBI_POS` {#max.nn.legacy.attention.mask_config.PositionalEncodingVariant.ALIBI_POS} > ALIBI\_POS = 'alibi\_pos' ### `NO_POS` {#max.nn.legacy.attention.mask_config.PositionalEncodingVariant.NO_POS} > NO\_POS = 'no\_pos' --- ## multi_latent_attention An opaque KV Cache optimized attention mechanism with Rope. 
## `DataParallelLatentAttentionWithRope` {#max.nn.legacy.attention.multi_latent_attention.DataParallelLatentAttentionWithRope} > class max.nn.legacy.attention.multi\_latent\_attention.DataParallelLatentAttentionWithRope(\*\*kwargs) Data-parallel implementation of Latent Attention with RoPE. This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs\_cis, input\_row\_offsets). No collective ops are required; KV-cache remains local to each device. **Notes:** * signal\_buffers is accepted for interface parity with the distributed implementation but is not used here. * Assumes the caller has already distributed xs, kv\_collections, freqs\_cis, and input\_row\_offsets so that index i corresponds to device i, with input\_row\_offsets\[i] rebased to start at 0. ### `create_mla_prefill_metadata()` {#max.nn.legacy.attention.multi_latent_attention.DataParallelLatentAttentionWithRope.create_mla_prefill_metadata} > create\_mla\_prefill\_metadata(input\_row\_offsets\_, kv\_collections)
**Parameters:**
* input\_row\_offsets\_ ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]) * kv\_collections ([list](https://docs.python.org/3/library/stdtypes.html#list)\[PagedCacheValues])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[MLAPrefillMetadata](#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata)]
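The notes above assume the caller has rebased each device's `input_row_offsets` to start at 0. A hypothetical pure-Python helper (not part of the API) illustrating that precondition for ragged inputs:

```python
def split_row_offsets(row_offsets: list[int], splits: list[int]) -> list[list[int]]:
    """Split global ragged row offsets into per-device offsets rebased to 0.

    `row_offsets` has one entry per sequence boundary (len = num_seqs + 1);
    `splits[i]` is the number of sequences assigned to device i.
    """
    out, start = [], 0
    for n in splits:
        chunk = row_offsets[start : start + n + 1]
        out.append([o - chunk[0] for o in chunk])  # rebase to start at 0
        start += n
    return out
```

For example, global offsets `[0, 3, 5, 9, 12]` split as two sequences per device become `[0, 3, 5]` on device 0 and `[0, 4, 7]` on device 1.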
## `LatentAttentionWithRope` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope} > class max.nn.legacy.attention.multi\_latent\_attention.LatentAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, dtype, devices=None, linear\_cls=\, o\_proj\_dtype=None, o\_proj\_float8\_config=None, scale=None, q\_lora\_rank=None, kv\_lora\_rank=512, qk\_nope\_head\_dim=128, qk\_rope\_head\_dim=64, v\_head\_dim=128, buffer\_size=16384, graph\_mode=None) Implementation of Latent Attention with Rope.
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) – The rope layer to borrow the freqs\_cis value from. * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads. * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads. * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states. * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the head dim, and data type. * dtype ([DType](../../../dtype.md#max.dtype.DType)) – DType of the weights, currently only bfloat16 is supported. * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) – Device to place the weights and run the computation. If multiple are provided, the first device is used. * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) – Linear class to use for the outputs dense layer. * o\_proj\_dtype ([DType](../../../dtype.md#max.dtype.DType) | None) – Optional dtype override for the output projection. * o\_proj\_float8\_config ([Float8Config](../float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) – Optional float8 config for the output projection. * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) – Value used to scale the results of the attention output. * q\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int) | None) – Optional LoRA rank for Q projection. * kv\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int)) – LoRA rank for KV projections. 
* qk\_nope\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Head dimension for non-positional encoding part. * qk\_rope\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Head dimension for rope part. * v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Head dimension for value. * buffer\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Buffer size for storing the temporary results during prefill, in units of tokens. * graph\_mode ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Pipeline role to use for the attention layer. Should be “prefill”, “decode”, or “auto”.
### `create_mla_prefill_metadata()` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.create_mla_prefill_metadata} > create\_mla\_prefill\_metadata(input\_row\_offsets, kv\_collection)
**Parameters:**
* input\_row\_offsets ([TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues)
**Return type:**
[MLAPrefillMetadata](#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata)
### `rope` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.rope} > rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding) ### `shard()` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.shard} > shard(devices) Creates sharded views of this Module across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded LatentAttentionWithRope instances, one for each device.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[LatentAttentionWithRope](#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope)]
### `sharding_strategy` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the Module sharding strategy. ### `w_uk_uv` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.w_uk_uv} > property w\_uk\_uv: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)] The concatenation of q, k, and v weight vectors. ## `MLAPrefillMetadata` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata} > class max.nn.legacy.attention.multi\_latent\_attention.MLAPrefillMetadata(buffer\_row\_offsets, cache\_offsets, buffer\_lengths) Dataclass to hold MLA prefill metadata.
**Parameters:**
* buffer\_row\_offsets ([TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)) * cache\_offsets ([TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)) * buffer\_lengths ([TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue))
### `buffer_lengths` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata.buffer_lengths} > buffer\_lengths: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) ### `buffer_row_offsets` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata.buffer_row_offsets} > buffer\_row\_offsets: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) ### `cache_offsets` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata.cache_offsets} > cache\_offsets: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) ## `TensorParallelLatentAttentionWithRope` {#max.nn.legacy.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope} > class max.nn.legacy.attention.multi\_latent\_attention.TensorParallelLatentAttentionWithRope(\*\*kwargs) Distributed tensor parallel implementation of the Latent Attention with Rope. Note that using tensor parallelism for MLA will cause the KV-cache to be duplicated across all devices, which is not efficient. ### `create_mla_prefill_metadata()` {#max.nn.legacy.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope.create_mla_prefill_metadata} > create\_mla\_prefill\_metadata(input\_row\_offsets\_, kv\_collections)
**Parameters:**
* input\_row\_offsets\_ ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]) * kv\_collections ([list](https://docs.python.org/3/library/stdtypes.html#list)\[PagedCacheValues])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[MLAPrefillMetadata](#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata)]
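The KV-cache duplication noted for the tensor-parallel variant has a direct memory consequence. As a back-of-envelope sketch (hypothetical helper, assuming requests split evenly under data parallelism):

```python
def kv_cache_bytes_per_device(total_tokens: int, bytes_per_token: int,
                              num_devices: int, mode: str) -> int:
    # Under tensor parallelism the MLA KV-cache is replicated on every
    # device; under data parallelism each device caches only its own requests.
    if mode == "tensor_parallel":
        return total_tokens * bytes_per_token  # full copy per device
    if mode == "data_parallel":
        return total_tokens * bytes_per_token // num_devices
    raise ValueError(f"unknown mode: {mode}")
```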
--- ## multihead_attention ## `MultiheadAttention` {#max.nn.legacy.attention.multihead_attention.MultiheadAttention} > class max.nn.legacy.attention.multihead\_attention.MultiheadAttention(num\_attention\_heads, hidden\_size, devices=None, dtype=float32, scale=None, qkv\_has\_bias=False, o\_proj\_has\_bias=False, stacked\_qkv=False) Multihead attention that handles both single and distributed computation.
**Parameters:**
* num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * devices (Sequence\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * qkv\_has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * o\_proj\_has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * stacked\_qkv ([bool](https://docs.python.org/3/library/functions.html#bool))
### `wqkv` {#max.nn.legacy.attention.multihead_attention.MultiheadAttention.wqkv} > property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) The concatenation of q, k, and v weight vectors. ### `wqkv_bias` {#max.nn.legacy.attention.multihead_attention.MultiheadAttention.wqkv_bias} > property wqkv\_bias: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) The concatenation of q, k, and v bias weight vectors. --- ## ragged_attention An opaque KV Cache optimized vanilla attention mechanism, with Mask Variants provided inside the Kernel. ## `RaggedAttention` {#max.nn.legacy.attention.ragged_attention.RaggedAttention} > class max.nn.legacy.attention.ragged\_attention.RaggedAttention(\*, mask\_variant, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=\, stacked\_qkv=False, scale=None, has\_bias=False, clip\_qkv=None) Layer that computes the self attention score for ragged inputs.
**Parameters:**
* mask\_variant ([MHAMaskVariant](mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../../dtype.md#max.dtype.DType)) * linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) * stacked\_qkv ([bool](https://docs.python.org/3/library/functions.html#bool)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * clip\_qkv ([float](https://docs.python.org/3/library/functions.html#float) | None)
### `wqkv` {#max.nn.legacy.attention.ragged_attention.RaggedAttention.wqkv} > property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) The concatenation of q, k, and v weight vectors. --- ## clamp ## `clamp()` {#max.nn.legacy.clamp.clamp} > max.nn.legacy.clamp.clamp(x, min=None, max=None) Clamps values in `x` to `[min, max]`.
**Parameters:**
* x ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor to clamp. * min ([float](https://docs.python.org/3/library/functions.html#float) | None) – Minimum value. If `None`, no lower bound is applied. * max ([float](https://docs.python.org/3/library/functions.html#float) | None) – Maximum value. If `None`, no upper bound is applied.
**Returns:**
Clamped tensor.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
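A scalar sketch of the elementwise semantics (the real op works on `TensorValue`s; this hypothetical helper only mirrors the `None`-bound behavior documented above):

```python
def clamp_scalar(x, min=None, max=None):
    # Mirrors the documented signature: a bound of None is simply skipped,
    # so clamp can act as a pure lower bound, pure upper bound, or both.
    if min is not None and x < min:
        x = min
    if max is not None and x > max:
        x = max
    return x
```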
--- ## comm ## `Allreduce` {#max.nn.legacy.comm.Allreduce} > class max.nn.legacy.comm.Allreduce(num\_accelerators) Layer to perform allreduce operation with automatic implementation selection. Automatically chooses between peer-to-peer optimized allreduce and naive device-to-device transfer based on accelerator connectivity.
**Parameters:**
num\_accelerators ([int](https://docs.python.org/3/library/functions.html#int)) – Number of accelerators participating in the allreduce operation
### `devices` {#max.nn.legacy.comm.Allreduce.devices} > devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Accelerator](../../../driver.md#max.driver.Accelerator)] List of accelerators involved in the allreduce operation. ## `Signals` {#max.nn.legacy.comm.Signals} > class max.nn.legacy.comm.Signals(devices) Signal buffers used for peer-to-peer communication in allreduce. Device code uses these buffers by enabling peer-to-peer access. Then thread blocks use the buffers to implement barriers for synchronization, and to hold intermediate communication results.
**Parameters:**
devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)])
### `NUM_BYTES` {#max.nn.legacy.comm.Signals.NUM_BYTES} > NUM\_BYTES = 537919488 The size of the signal buffers used for communication in allreduce. ### `buffers()` {#max.nn.legacy.comm.Signals.buffers} > buffers() Allocates and returns buffers used for communication in allreduce. Synchronizes so that buffers are ready for use when this method returns.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../../../driver.md#max.driver.Buffer)]
### `devices` {#max.nn.legacy.comm.Signals.devices} > devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)] List of graph devices that these signals communicate between. ### `input_types()` {#max.nn.legacy.comm.Signals.input_types} > input\_types() Gets graph input types corresponding to these signal buffers.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[BufferType](../../../graph/type.md#max.graph.type.BufferType)]
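Whichever strategy `Allreduce` selects, the result is the same contract: every participant ends up holding the elementwise sum of all inputs. A naive pure-Python sketch of that contract (illustrative only, not the device implementation):

```python
def naive_allreduce(shards: list[list[float]]) -> list[list[float]]:
    # Reduce: sum corresponding elements across all device shards.
    reduced = [sum(vals) for vals in zip(*shards)]
    # Broadcast: every device receives a copy of the reduced result.
    return [list(reduced) for _ in shards]
```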
--- ## conv The `conv` module provides classes for performing convolution operations in various dimensions (1D, 2D, and 3D) on tensor inputs. These convolution operations are core building blocks for neural networks, especially in computer vision and sequence processing tasks. Here’s an example demonstrating how to use a 1D convolution: ```python import max.nn as nn from max.graph import Graph, ops, Weight, DeviceRef from max.dtype import DType import numpy as np with Graph(name="conv_example") as graph: # Define dimensions batch_size = 2 seq_length = 10 in_channels = 16 out_channels = 32 kernel_size = 3 # Create input tensor [batch_size, sequence_length, channels] x_data = np.zeros((batch_size, seq_length, in_channels), dtype=np.float32) x = ops.constant(x_data, dtype=DType.float32, device=DeviceRef.CPU()) # Create weights for convolution filter_1d = Weight( name="filter_weight", dtype=DType.float32, shape=[kernel_size, in_channels, out_channels], device=DeviceRef.CPU() ) bias_1d = Weight( name="bias_weight", dtype=DType.float32, shape=[out_channels], device=DeviceRef.CPU() ) # Create and apply Conv1D layer conv1d = nn.Conv1D( filter=filter_1d, bias=bias_1d, stride=1, padding=1 ) output_1d = conv1d(x) print(f"Conv1D output shape: {output_1d.shape}") # Output: Conv1D output shape: [Dim(2), Dim(10), Dim(32)] ``` ## `Conv1D` {#max.nn.legacy.conv.Conv1D} > class max.nn.legacy.conv.Conv1D(kernel\_size, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None) A 1D convolution over an input signal composed of several input planes. **Example:** ```python conv = nn.Conv1D( kernel_size=3, in_channels=64, out_channels=128, dtype=DType.float32, stride=1, padding=0, has_bias=False, name="conv1d_weight", device=DeviceRef.GPU(), ) ```
**Parameters:**
* kernel\_size ([int](https://docs.python.org/3/library/functions.html#int)) * in\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * out\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * stride ([int](https://docs.python.org/3/library/functions.html#int)) * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([int](https://docs.python.org/3/library/functions.html#int)) * num\_groups ([int](https://docs.python.org/3/library/functions.html#int)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * permute ([bool](https://docs.python.org/3/library/functions.html#bool)) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `bias` {#max.nn.legacy.conv.Conv1D.bias} > bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.legacy.conv.Conv1D.device) if present. ### `device` {#max.nn.legacy.conv.Conv1D.device} > device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None) The device where matrix operations are performed. ### `dilation` {#max.nn.legacy.conv.Conv1D.dilation} > dilation: [int](https://docs.python.org/3/library/functions.html#int) Controls the dilation rate. ### `filter` {#max.nn.legacy.conv.Conv1D.filter} > filter: [Weight](../../graph/Weight.md#max.graph.Weight) The weight matrix stored on CPU with shape (kernel\_size, in\_channels / num\_groups, out\_channels). Model init moves the weight to [`device`](#max.nn.legacy.conv.Conv1D.device). ### `num_groups` {#max.nn.legacy.conv.Conv1D.num_groups} > num\_groups: [int](https://docs.python.org/3/library/functions.html#int) Number of blocked connections from input channels to output channels. ### `padding` {#max.nn.legacy.conv.Conv1D.padding} > padding: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the amount of padding applied to the input. If int: symmetric padding applied to both sides (pad\_left = pad\_right = padding). If tuple\[int, int]: asymmetric padding as (pad\_left, pad\_right). ### `permute` {#max.nn.legacy.conv.Conv1D.permute} > permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False bool controls whether self.filter is permuted from PyTorch order to max order. 
PyTorch order is: (out\_channels, in\_channels / num\_groups, kernel\_size) Max API order: (kernel\_size, in\_channels / num\_groups, out\_channels). ### `stride` {#max.nn.legacy.conv.Conv1D.stride} > stride: [int](https://docs.python.org/3/library/functions.html#int) Controls the stride for the cross-correlation. ## `Conv2d` {#max.nn.legacy.conv.Conv2d} > class max.nn.legacy.conv.Conv2d(kernel\_size, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None) A 2D convolution over an input signal composed of several input planes. **Example:** ```python conv = nn.Conv2d( kernel_size=3, in_channels=64, out_channels=128, dtype=DType.float32, stride=1, padding=0, has_bias=False, name="conv2d_weight", device=DeviceRef.GPU(), ) ```
**Parameters:**
* kernel\_size ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * in\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * out\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * num\_groups ([int](https://docs.python.org/3/library/functions.html#int)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * permute ([bool](https://docs.python.org/3/library/functions.html#bool)) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `bias` {#max.nn.legacy.conv.Conv2d.bias} > bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.legacy.conv.Conv2d.device) if present. ### `device` {#max.nn.legacy.conv.Conv2d.device} > device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None) The device where matrix operations are performed. ### `dilation` {#max.nn.legacy.conv.Conv2d.dilation} > dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the dilation rate. ### `filter` {#max.nn.legacy.conv.Conv2d.filter} > filter: [Weight](../../graph/Weight.md#max.graph.Weight) The weight matrix stored on CPU with shape (height, width, in\_channels / num\_groups, out\_channels). Model init moves the weight to [`device`](#max.nn.legacy.conv.Conv2d.device). ### `num_groups` {#max.nn.legacy.conv.Conv2d.num_groups} > num\_groups: [int](https://docs.python.org/3/library/functions.html#int) Number of blocked connections from input channels to output channels. ### `padding` {#max.nn.legacy.conv.Conv2d.padding} > padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the amount of padding applied before and after the input for height and width dimensions. Format: (pad\_top, pad\_bottom, pad\_left, pad\_right). 
### `permute` {#max.nn.legacy.conv.Conv2d.permute} > permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False bool controls whether self.filter is permuted from PyTorch order to max order. PyTorch order is: (out\_channels, in\_channels / num\_groups, height, width) Max API order: (height, width, in\_channels / num\_groups, out\_channels). ### `shard()` {#max.nn.legacy.conv.Conv2d.shard} > shard(devices) Creates sharded views of this Conv2d layer across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded Conv2d instances, one for each device.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Conv2d](#max.nn.legacy.conv.Conv2d)]
### `sharding_strategy` {#max.nn.legacy.conv.Conv2d.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the Conv2d sharding strategy. ### `stride` {#max.nn.legacy.conv.Conv2d.stride} > stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the stride for the cross-correlation. ## `Conv3D` {#max.nn.legacy.conv.Conv3D} > class max.nn.legacy.conv.Conv3D(depth, height, width, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None) A 3D convolution over an input signal composed of several input planes. **Example:** ```python conv = nn.Conv3D( depth=3, height=3, width=3, in_channels=64, out_channels=128, dtype=DType.float32, stride=1, padding=0, has_bias=False, name="conv3d_weight", device=DeviceRef.GPU(), ) ```
**Parameters:**
* depth ([int](https://docs.python.org/3/library/functions.html#int)) * height ([int](https://docs.python.org/3/library/functions.html#int)) * width ([int](https://docs.python.org/3/library/functions.html#int)) * in\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * out\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * num\_groups ([int](https://docs.python.org/3/library/functions.html#int)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * permute ([bool](https://docs.python.org/3/library/functions.html#bool)) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `bias` {#max.nn.legacy.conv.Conv3D.bias} > bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.legacy.conv.Conv3D.device) if present. ### `device` {#max.nn.legacy.conv.Conv3D.device} > device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None) The device where matrix operations are performed. ### `dilation` {#max.nn.legacy.conv.Conv3D.dilation} > dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the dilation rate for depth, height, and width dimensions. ### `filter` {#max.nn.legacy.conv.Conv3D.filter} > filter: [Weight](../../graph/Weight.md#max.graph.Weight) The weight matrix stored on CPU with shape (depth, height, width, in\_channels / num\_groups, out\_channels). Model init moves the weight to [`device`](#max.nn.legacy.conv.Conv3D.device). ### `num_groups` {#max.nn.legacy.conv.Conv3D.num_groups} > num\_groups: [int](https://docs.python.org/3/library/functions.html#int) Number of blocked connections from input channels to output channels. 
### `padding` {#max.nn.legacy.conv.Conv3D.padding} > padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the amount of padding applied before and after the input for depth, height, and width dimensions. Format: (pad\_front, pad\_back, pad\_top, pad\_bottom, pad\_left, pad\_right). ### `permute` {#max.nn.legacy.conv.Conv3D.permute} > permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False bool controls whether self.filter is permuted from PyTorch order to max order. PyTorch order is: (out\_channels, in\_channels / num\_groups, depth, height, width) Max API order: (depth, height, width, in\_channels / num\_groups, out\_channels). ### `stride` {#max.nn.legacy.conv.Conv3D.stride} > stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the stride for the cross-correlation. --- ## conv_transpose ## `ConvTranspose1d` {#max.nn.legacy.conv_transpose.ConvTranspose1d} > class max.nn.legacy.conv\_transpose.ConvTranspose1d(length, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, output\_padding=0, device=None, has\_bias=False, permute=False, name=None) A 1D transposed convolution operator over an input image composed of several input planes. ```python conv = nn.ConvTranspose1d( in_channels, out_channels, kernel_size, stride, padding, output_padding, has_bias=False, name="conv_transpose1d_weight", device=DeviceRef.GPU(), ) ```
**Parameters:**
* length ([int](https://docs.python.org/3/library/functions.html#int)) * in\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * out\_channels ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * output\_padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * permute ([bool](https://docs.python.org/3/library/functions.html#bool)) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
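The way `stride`, `padding`, `dilation`, and `output_padding` combine to determine the output length follows the standard transposed-convolution arithmetic (the PyTorch convention). A quick sanity check, assuming symmetric scalar padding; the helper name is ours, not part of the API:

```python
def conv_transpose1d_out_length(l_in, kernel, stride=1, padding=0,
                                dilation=1, output_padding=0):
    """Standard transposed-conv output-size arithmetic (PyTorch convention)."""
    return ((l_in - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)

# A stride-2 transposed conv roughly doubles the sequence length:
print(conv_transpose1d_out_length(16, kernel=4, stride=2, padding=1))  # 32
```

This is why `output_padding` exists: several input lengths can map to the same output length under a strided forward convolution, and the extra one-sided padding disambiguates which one the transpose should produce.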
### `bias` {#max.nn.legacy.conv_transpose.ConvTranspose1d.bias} > bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.legacy.conv_transpose.ConvTranspose1d.device) if present. ### `device` {#max.nn.legacy.conv_transpose.ConvTranspose1d.device} > device: [DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None) The device where matrix operations are performed. ### `dilation` {#max.nn.legacy.conv_transpose.ConvTranspose1d.dilation} > dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Not implemented yet. Assuming dilation = 1 for now. ### `output_padding` {#max.nn.legacy.conv_transpose.ConvTranspose1d.output_padding} > output\_padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Additional size added to one side of the output shape. Default: 0.
### `padding` {#max.nn.legacy.conv_transpose.ConvTranspose1d.padding} > padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the amount of padding applied before and after the input. ### `permute` {#max.nn.legacy.conv_transpose.ConvTranspose1d.permute} > permute: [bool](https://docs.python.org/3/library/functions.html#bool) Controls whether self.weight is permuted from PyTorch order to max order. PyTorch order is: (in\_channels, out\_channels, kernel\_length). Max API order: (kernel\_length, out\_channels, in\_channels). ### `stride` {#max.nn.legacy.conv_transpose.ConvTranspose1d.stride} > stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Controls the stride for the cross-correlation. ### `weight` {#max.nn.legacy.conv_transpose.ConvTranspose1d.weight} > weight: [Weight](../../graph/Weight.md#max.graph.Weight) The weight matrix stored on CPU with shape (kernel\_length, out\_channels, in\_channels). Model init moves the weight to [`device`](#max.nn.legacy.conv_transpose.ConvTranspose1d.device). ## `WeightNormConvTranspose1d` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d} > class max.nn.legacy.conv\_transpose.WeightNormConvTranspose1d(length, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, output\_padding=0, device=None, has\_bias=False, permute=False, name=None) A 1D transposed convolution operator over an input image composed of several input planes. This version uses weight normalization as described in [Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks](https://arxiv.org/abs/1602.07868) (Salimans & Kingma, 2016). 
Weight normalization reparameterizes weights in terms of a direction vector `v` and a magnitude scalar `g`. This can help improve optimization by decoupling the length and direction of weight vectors. For example: ```python conv = WeightNormConvTranspose1d( length=kernel_size, in_channels=in_channels, out_channels=out_channels, dtype=dtype, stride=stride, padding=padding, output_padding=output_padding, has_bias=False, device=DeviceRef.GPU(), ) ```
**Parameters:**
* length ([int](https://docs.python.org/3/library/functions.html#int)) * in_channels ([int](https://docs.python.org/3/library/functions.html#int)) * out_channels ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * output_padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | None) * has_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * permute ([bool](https://docs.python.org/3/library/functions.html#bool)) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `conv` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.conv} > conv: [ConvTranspose1d](#max.nn.legacy.conv_transpose.ConvTranspose1d) The underlying ConvTranspose1d layer. ### `device` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.device} > device: [DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None) The device where matrix operations are performed. ### `weight_g` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.weight_g} > weight_g: [Weight](../../graph/Weight.md#max.graph.Weight) The magnitude parameter g for weight normalization. ### `weight_v` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.weight_v} > weight_v: [Weight](../../graph/Weight.md#max.graph.Weight) The direction parameter v for weight normalization. --- ## embedding (Legacy) The `embedding` module provides classes for mapping integer indices (like token IDs) to dense vector representations. These embedding operations are fundamental building blocks for natural language processing, recommendation systems, and other tasks involving discrete tokens. 
* `Embedding`: Basic embedding lookup table for simple use cases * `EmbeddingV2`: Enhanced embedding with device placement control and improved memory management * `VocabParallelEmbedding`: Distributed embedding that shards the vocabulary across multiple devices for large embedding tables Here’s an example demonstrating how to use embeddings: ```python import max.nn as nn from max.graph import Graph, ops, DeviceRef from max.dtype import DType import numpy as np with Graph(name="embedding_example") as graph: # Define dimensions batch_size = 4 seq_length = 16 vocab_size = 10000 hidden_dim = 256 # Create input tensor of token indices input_data = np.random.randint(0, vocab_size, (batch_size, seq_length), dtype=np.int32) input_indices = ops.constant(input_data, dtype=DType.int32, device=DeviceRef.CPU()) # Create embedding layer embedding = nn.EmbeddingV2( vocab_size=vocab_size, hidden_dim=hidden_dim, dtype=DType.float32, device=DeviceRef.GPU(), name="token_embeddings" ) # Look up embeddings for input indices embeddings = embedding(input_indices) print(f"Embedding output shape: {embeddings.shape}") # Embedding output shape: [Dim(4), Dim(16), Dim(256)] ``` ## `Embedding` {#max.nn.legacy.embedding.Embedding} > class max.nn.legacy.embedding.Embedding(vocab\_size, hidden\_dim, dtype, device, quantization\_encoding=None, name=None) A lookup table for embedding integer indices into dense vectors. This layer maps each integer index to a dense vector of fixed size. Embedding weights are stored on the CPU but are moved to the specified device during the model init phase. Example: ```python embedding_layer = Embedding( vocab_size=1000, hidden_dim=256, dtype=DType.float32, device=DeviceRef.GPU(), name="embeddings", ) token_indices: TensorValueLike embeddings = embedding_layer(token_indices) ```
**Parameters:**
* vocab\_size ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) * quantization\_encoding ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | None) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `device` {#max.nn.legacy.embedding.Embedding.device} > device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) The device on which embedding lookup is performed. ### `weight` {#max.nn.legacy.embedding.Embedding.weight} > weight: [Weight](../../graph/Weight.md#max.graph.Weight) The embedding weight matrix stored on the CPU. Model init moves weights to the device specified in [`device`](#max.nn.legacy.embedding.Embedding.device). ## `VocabParallelEmbedding` {#max.nn.legacy.embedding.VocabParallelEmbedding} > class max.nn.legacy.embedding.VocabParallelEmbedding(vocab\_size, hidden\_dim, dtype, devices, quantization\_encoding=None, name=None) A lookup table for embedding integer indices into dense vectors. This layer works like nn.Embedding except the embedding table is sharded on the vocabulary dimension across all devices. Example: ```python embedding_layer = VocabParallelEmbedding( vocab_size=1000, hidden_dim=256, dtype=DType.float32, devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)], name="embeddings", ) # Token indices of shape: [batch, ..., num_indices]. token_indices: TensorValueLike embeddings = embedding_layer(token_indices) ```
**Parameters:**
* vocab\_size ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * quantization\_encoding ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | None) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
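The vocabulary sharding that `VocabParallelEmbedding` describes can be illustrated with a small pure-Python sketch (a hypothetical illustration, not the actual implementation): each device holds a contiguous slice of the vocabulary rows, indices outside a shard's range contribute zeros locally, and summing the per-device partial results plays the role of the cross-device all-reduce.

```python
def shard_table(table, num_devices):
    """Split an embedding table row-wise (vocab dimension) into equal shards."""
    per = len(table) // num_devices
    return [table[i * per:(i + 1) * per] for i in range(num_devices)]

def parallel_lookup(shards, indices):
    """Each shard resolves only indices in its own vocab range; summing the
    partial results emulates the all-reduce across devices."""
    per = len(shards[0])
    dim = len(shards[0][0])
    out = [[0.0] * dim for _ in indices]
    for rank, shard in enumerate(shards):
        lo = rank * per
        for row, idx in enumerate(indices):
            if lo <= idx < lo + per:  # this shard owns the row
                out[row] = [a + b for a, b in zip(out[row], shard[idx - lo])]
    return out

table = [[float(v), float(v) * 10] for v in range(8)]  # vocab=8, hidden=2
shards = shard_table(table, num_devices=2)
print(parallel_lookup(shards, [1, 6]))  # [[1.0, 10.0], [6.0, 60.0]]
```

Index 1 is served by the first shard and index 6 by the second, yet the caller sees a single full-vocabulary table.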
--- ## float8_config Float8 configuration data structures for models. ## `Float8Config` {#max.nn.legacy.float8_config.Float8Config} > class max.nn.legacy.float8\_config.Float8Config(input\_scale, weight\_scale, mlp\_in\_float8, attn\_qkv\_in\_float8, embedding\_output\_dtype=None, bias\_dtype=None, quant\_method=None, quant\_algo=None) Configures float8 quantization settings for a layer or model section.
**Parameters:**
* input\_scale ([Float8InputScaleSpec](#max.nn.legacy.float8_config.Float8InputScaleSpec)) * weight\_scale ([Float8WeightScaleSpec](#max.nn.legacy.float8_config.Float8WeightScaleSpec)) * mlp\_in\_float8 ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]) * attn\_qkv\_in\_float8 ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]) * embedding\_output\_dtype ([DType](../../dtype.md#max.dtype.DType) | None) * bias\_dtype ([DType](../../dtype.md#max.dtype.DType) | None) * quant\_method ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * quant\_algo ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `attn_qkv_in_float8` {#max.nn.legacy.float8_config.Float8Config.attn_qkv_in_float8} > attn\_qkv\_in\_float8: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] Set of layer indices with attention QKV projections in float8. QKV projections are considered to be either “all quantized” or all not quantized per layer. So either all of {q,k,v,o}\_proj are float8, or all bfloat16. ### `bias_dtype` {#max.nn.legacy.float8_config.Float8Config.bias_dtype} > bias\_dtype: [DType](../../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None The `DType` of bias weights. ### `embedding_output_dtype` {#max.nn.legacy.float8_config.Float8Config.embedding_output_dtype} > embedding\_output\_dtype: [DType](../../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None The `DType` of the output from the embedding layer. ### `input_scale` {#max.nn.legacy.float8_config.Float8Config.input_scale} > input\_scale: [Float8InputScaleSpec](#max.nn.legacy.float8_config.Float8InputScaleSpec) [`Float8InputScaleSpec`](#max.nn.legacy.float8_config.Float8InputScaleSpec) for input activation scaling. ### `is_dynamic` {#max.nn.legacy.float8_config.Float8Config.is_dynamic} > property is\_dynamic: [bool](https://docs.python.org/3/library/functions.html#bool) Returns `True` if this input scale is dynamic. ### `is_nvfp4` {#max.nn.legacy.float8_config.Float8Config.is_nvfp4} > property is\_nvfp4: [bool](https://docs.python.org/3/library/functions.html#bool) Returns `True` if this config represents modelopt NVFP4. ### `is_static` {#max.nn.legacy.float8_config.Float8Config.is_static} > property is\_static: [bool](https://docs.python.org/3/library/functions.html#bool) Returns `True` if this input scale is static. 
### `mlp_in_float8` {#max.nn.legacy.float8_config.Float8Config.mlp_in_float8} > mlp\_in\_float8: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] Set of layer indices with MLPs in float8. MLPs are considered to be either “all quantized” or all not quantized per layer. So either all of gate proj, down proj, and up proj are float8, or all bfloat16. ### `quant_algo` {#max.nn.legacy.float8_config.Float8Config.quant_algo} > quant\_algo: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None Additional differentiator within the same quant\_method, e.g. modelopt NVFP4 vs FP8. ### `quant_method` {#max.nn.legacy.float8_config.Float8Config.quant_method} > quant\_method: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None The quantization method used (e.g., “fbgemm\_fp8”). ### `quantized_scales_type()` {#max.nn.legacy.float8_config.Float8Config.quantized_scales_type} > quantized\_scales\_type(quantized\_shape, device\_ref) Returns the TensorType of the scales tensor after dynamic quantization.
**Parameters:**
* quantized\_shape ([Shape](../../graph/shape.md#max.graph.shape.Shape)) * device\_ref ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef))
**Return type:**
[TensorType](../../graph/type.md#max.graph.type.TensorType)
### `scales_granularity_mnk` {#max.nn.legacy.float8_config.Float8Config.scales_granularity_mnk} > property scales\_granularity\_mnk: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] Returns the weight and input scale granularities on M, N and K axis. ### `weight_scale` {#max.nn.legacy.float8_config.Float8Config.weight_scale} > weight\_scale: [Float8WeightScaleSpec](#max.nn.legacy.float8_config.Float8WeightScaleSpec) [`Float8WeightScaleSpec`](#max.nn.legacy.float8_config.Float8WeightScaleSpec) for weight scaling. ## `Float8InputScaleSpec` {#max.nn.legacy.float8_config.Float8InputScaleSpec} > class max.nn.legacy.float8\_config.Float8InputScaleSpec(granularity, origin, dtype, activation\_scale\_ub=None, block\_size=None) Specifies how input activations are scaled for float8 quantization.
**Parameters:**
* granularity ([Float8ScaleGranularity](#max.nn.legacy.float8_config.Float8ScaleGranularity)) * origin ([Float8ScaleOrigin](#max.nn.legacy.float8_config.Float8ScaleOrigin)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * activation\_scale\_ub ([float](https://docs.python.org/3/library/functions.html#float) | None) * block\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] | None)
### `activation_scale_ub` {#max.nn.legacy.float8_config.Float8InputScaleSpec.activation_scale_ub} > activation\_scale\_ub: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None An optional upper bound for dynamic activation scaling. ### `block_size` {#max.nn.legacy.float8_config.Float8InputScaleSpec.block_size} > block\_size: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None The `tuple[int, int]` of the block size for block-wise scaling. ### `dtype` {#max.nn.legacy.float8_config.Float8InputScaleSpec.dtype} > dtype: [DType](../../dtype.md#max.dtype.DType) The `DType` of the input scale factor(s). ### `granularity` {#max.nn.legacy.float8_config.Float8InputScaleSpec.granularity} > granularity: [Float8ScaleGranularity](#max.nn.legacy.float8_config.Float8ScaleGranularity) The [`Float8ScaleGranularity`](#max.nn.legacy.float8_config.Float8ScaleGranularity) of the input scale factor application. ### `is_block` {#max.nn.legacy.float8_config.Float8InputScaleSpec.is_block} > property is\_block: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the input scale granularity is block-wise. ### `is_colwise` {#max.nn.legacy.float8_config.Float8InputScaleSpec.is_colwise} > property is\_colwise: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the input scale granularity is column-wise. ### `is_rowwise` {#max.nn.legacy.float8_config.Float8InputScaleSpec.is_rowwise} > property is\_rowwise: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the input scale granularity is row-wise. 
### `is_tensor` {#max.nn.legacy.float8_config.Float8InputScaleSpec.is_tensor} > property is\_tensor: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the input scale granularity is per-tensor. ### `origin` {#max.nn.legacy.float8_config.Float8InputScaleSpec.origin} > origin: [Float8ScaleOrigin](#max.nn.legacy.float8_config.Float8ScaleOrigin) The [`Float8ScaleOrigin`](#max.nn.legacy.float8_config.Float8ScaleOrigin) (static or dynamic) of the input scale factor. ## `Float8ScaleGranularity` {#max.nn.legacy.float8_config.Float8ScaleGranularity} > class max.nn.legacy.float8\_config.Float8ScaleGranularity(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Specifies the granularity of the quantization scale factor. Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor. ### `BLOCK` {#max.nn.legacy.float8_config.Float8ScaleGranularity.BLOCK} > BLOCK = 'block' Per-block scaling. ### `COLWISE` {#max.nn.legacy.float8_config.Float8ScaleGranularity.COLWISE} > COLWISE = 'colwise' Per-column scaling. ### `ROWWISE` {#max.nn.legacy.float8_config.Float8ScaleGranularity.ROWWISE} > ROWWISE = 'rowwise' Per-row scaling. ### `TENSOR` {#max.nn.legacy.float8_config.Float8ScaleGranularity.TENSOR} > TENSOR = 'tensor' Per-tensor scaling. ## `Float8ScaleOrigin` {#max.nn.legacy.float8_config.Float8ScaleOrigin} > class max.nn.legacy.float8\_config.Float8ScaleOrigin(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Specifies whether the quantization scale is determined statically or dynamically. ### `DYNAMIC` {#max.nn.legacy.float8_config.Float8ScaleOrigin.DYNAMIC} > DYNAMIC = 'dynamic' Scales are computed at runtime based on the input data. ### `STATIC` {#max.nn.legacy.float8_config.Float8ScaleOrigin.STATIC} > STATIC = 'static' Scales are pre-computed and loaded with the model weights. 
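How the `DYNAMIC` origin interacts with the `TENSOR` and `ROWWISE` granularities can be sketched in plain Python. This is an illustration of the scale math only, not the MAX kernel; it assumes the float8\_e4m3fn maximum finite value of 448:

```python
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def dynamic_scales(matrix, granularity="tensor"):
    """Compute DYNAMIC-origin scale factors: one scale for the whole tensor,
    or one per row, so that dividing by the scale fits the fp8 range."""
    if granularity == "tensor":
        amax = max(abs(v) for row in matrix for v in row)
        return [amax / FP8_E4M3_MAX]
    if granularity == "rowwise":
        return [max(abs(v) for v in row) / FP8_E4M3_MAX for row in matrix]
    raise ValueError(f"unsupported granularity: {granularity}")

m = [[0.5, -448.0], [2.0, 1.0]]
print(dynamic_scales(m, "tensor"))   # one scale for everything
print(dynamic_scales(m, "rowwise"))  # small-magnitude rows keep more precision
```

Per-row scales let a row of small values use the full fp8 range instead of being dominated by a single large outlier elsewhere in the tensor, which is the trade-off the granularity enum encodes.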
## `Float8WeightScaleSpec` {#max.nn.legacy.float8_config.Float8WeightScaleSpec} > class max.nn.legacy.float8\_config.Float8WeightScaleSpec(granularity, dtype, block\_size=None) Specifies how weights are scaled for float8 quantization.
**Parameters:**
* granularity ([Float8ScaleGranularity](#max.nn.legacy.float8_config.Float8ScaleGranularity)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * block\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] | None)
### `block_size` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.block_size} > block\_size: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None The `tuple[int, int]` of the block size for block-wise scaling. ### `dtype` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.dtype} > dtype: [DType](../../dtype.md#max.dtype.DType) The `DType` of the weight scale factor(s). ### `granularity` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.granularity} > granularity: [Float8ScaleGranularity](#max.nn.legacy.float8_config.Float8ScaleGranularity) The [`Float8ScaleGranularity`](#max.nn.legacy.float8_config.Float8ScaleGranularity) of the weight scale factor application. ### `is_block` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.is_block} > property is\_block: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is block-wise. ### `is_colwise` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.is_colwise} > property is\_colwise: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is column-wise. ### `is_rowwise` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.is_rowwise} > property is\_rowwise: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is row-wise. ### `is_tensor` {#max.nn.legacy.float8_config.Float8WeightScaleSpec.is_tensor} > property is\_tensor: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is per-tensor. ## `ceildiv()` {#max.nn.legacy.float8_config.ceildiv} > max.nn.legacy.float8\_config.ceildiv(n, d)
**Parameters:**
* n ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) * d ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)])
**Return type:**
[Dim](../../graph/dim.md#max.graph.dim.Dim)
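For concrete integer dimensions, `ceildiv` behaves like ceiling division, commonly written in Python as negated floor division. A sketch of the integer case (the real helper also accepts symbolic `Dim` values and strings, which this simplification ignores):

```python
def ceildiv(n: int, d: int) -> int:
    """Ceiling division: the smallest integer >= n / d."""
    return -(-n // d)

# Useful for computing how many fixed-size blocks cover a dimension:
print(ceildiv(7, 2))    # 4
print(ceildiv(8, 2))    # 4
print(ceildiv(1, 128))  # 1
```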
## `nvfp4_packed_k()` {#max.nn.legacy.float8_config.nvfp4_packed_k} > max.nn.legacy.float8\_config.nvfp4\_packed\_k(in\_dim, float8\_config) Returns packed K dimension for NVFP4 weights, else returns in\_dim.
**Parameters:**
* in\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * float8\_config ([Float8Config](#max.nn.legacy.float8_config.Float8Config) | None)
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
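The packing arithmetic can be sketched as follows, assuming NVFP4 stores two 4-bit values per byte along the K axis so the packed dimension is halved. This is a simplified illustration: the boolean flag stands in for the real helper's `Float8Config` inspection, and the packing factor of 2 is our assumption:

```python
def nvfp4_packed_k(in_dim: int, is_nvfp4: bool) -> int:
    """Two 4-bit values pack into one byte along K (illustrative assumption);
    non-NVFP4 configs pass the dimension through unchanged."""
    return (in_dim + 1) // 2 if is_nvfp4 else in_dim

print(nvfp4_packed_k(4096, True))   # 2048
print(nvfp4_packed_k(4096, False))  # 4096
```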
--- ## hooks ## `PrintHook` {#max.nn.legacy.hooks.PrintHook} > class max.nn.legacy.hooks.PrintHook(export\_path=None, filter=None) Hook that prints/saves layer tensor inputs and outputs. This class must be initialized before the graph is built so that the print ops can be added to the graph.
**Parameters:**
* export\_path ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * filter ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | None)
### `export_path` {#max.nn.legacy.hooks.PrintHook.export_path} > property export\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) ### `name_layers()` {#max.nn.legacy.hooks.PrintHook.name_layers} > name\_layers(model) Create names for all layers in the model based on nested attributes.
**Parameters:**
model ([Layer](layer.md#max.nn.legacy.layer.Layer))
**Return type:**
None
### `print_value()` {#max.nn.legacy.hooks.PrintHook.print_value} > print\_value(name, value) Prints a value, and returns whether the print is successful.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * value ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `remove()` {#max.nn.legacy.hooks.PrintHook.remove} > remove()
**Return type:**
None
--- ## legacy Legacy graph-based neural network API. :::note Note This is the legacy API for backward compatibility. For all new models, use the eager tensor API from [nn](/max/api/python/nn). ::: The legacy API provides graph-based layer implementations for building neural networks. This API was the primary interface prior to MAX 26.1 and remains available for backward compatibility. **Using the Legacy API:** ```python from max.nn.legacy import Module, Linear, LayerNorm from max.nn.legacy.attention import AttentionWithRope ``` ## Modules * [`attention`](/max/api/python/nn/legacy/attention): Attention mechanisms for sequence modeling. * [`clamp`](/max/api/python/nn/legacy/clamp): Value clamping utilities for tensor operations. * [`comm`](/max/api/python/nn/legacy/comm): Communication primitives for distributed training. * [`conv`](/max/api/python/nn/legacy/conv): Convolutional layers for spatial processing. * [`conv_transpose`](/max/api/python/nn/legacy/conv_transpose): Transposed convolution for upsampling. * [`embedding`](/max/api/python/nn/legacy/embedding): Embedding layers with vocabulary support. * [`float8_config`](/max/api/python/nn/legacy/float8_config): Configuration for FP8 quantization. * [`hooks`](/max/api/python/nn/legacy/hooks): Extension hooks for layer customization. * [`kernels`](/max/api/python/nn/legacy/kernels): Custom kernel implementations. * [`kv_cache`](/max/api/python/nn/legacy/kv_cache): Key-value cache for efficient generation. * [`layer`](/max/api/python/nn/legacy/layer): Base classes for building graph-based layers. * [`linear`](/max/api/python/nn/legacy/linear): Linear transformation layers with optional parallelism. * [`lora`](/max/api/python/nn/legacy/lora): Low-Rank Adaptation for efficient fine-tuning. * [`moe`](/max/api/python/nn/legacy/moe): Mixture of Experts layer implementations. * [`norm`](/max/api/python/nn/legacy/norm): Normalization layers for training stability. 
* [`rotary_embedding`](/max/api/python/nn/legacy/rotary_embedding): Rotary position embeddings for sequences. * [`sampling`](/max/api/python/nn/legacy/sampling): Sampling strategies for generation. * [`sequential`](/max/api/python/nn/legacy/sequential): Container for sequential layer composition. * [`transformer`](/max/api/python/nn/legacy/transformer): Transformer building blocks and layers. --- ## kernels Helper functions for wrapping custom kv cache/attention related ops. ## `apply_penalties_to_logits()` {#max.nn.legacy.kernels.apply_penalties_to_logits} > max.nn.legacy.kernels.apply\_penalties\_to\_logits(logits\_buffer, frequency\_data, frequency\_offsets, \*, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0) Applies penalties to the logits.
**Parameters:**
* logits\_buffer ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – The buffer to apply penalties to. * frequency\_data ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – 2d tensor of shape \[unique\_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token. * frequency\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – 1d tensor of shape \[batch\_size + 1], indicating start of each sequence’s data. * frequency\_penalty (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text: tokens will receive a penalty proportional to the count of appearances. 
* presence\_penalty (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once by applying a constant penalty. * repetition\_penalty (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in the prompt and generated text at least once by dividing the logits by the repetition penalty.
**Return type:**
None
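The penalty semantics described above can be sketched in standalone NumPy. This is an illustrative reference, not the MAX kernel; the sign-aware handling of the repetition penalty follows the common sampler convention (divide positive logits, multiply negative ones) and is an assumption beyond the docstring, which only mentions division:

```python
import numpy as np

def apply_penalties_reference(logits, token_counts, frequency_penalty=0.0,
                              presence_penalty=0.0, repetition_penalty=1.0):
    """Illustrative sketch of the penalty semantics (not the MAX kernel).

    logits: [vocab_size] scores for the next token.
    token_counts: [vocab_size] count of each token's appearances so far.
    """
    logits = logits.astype(np.float64).copy()
    seen = token_counts > 0
    # Frequency penalty: proportional to the number of appearances.
    logits -= frequency_penalty * token_counts
    # Presence penalty: constant penalty for any token seen at least once.
    logits -= presence_penalty * seen
    # Repetition penalty: values > 1 discourage already-seen tokens.
    # Sign-aware convention assumed here so negative logits are also pushed down.
    logits[seen] = np.where(logits[seen] > 0,
                            logits[seen] / repetition_penalty,
                            logits[seen] * repetition_penalty)
    return logits
```

For example, a token with logit 2.0 that has appeared 3 times, under frequency 0.5, presence 0.1, and repetition 2.0, ends up at (2.0 − 1.5 − 0.1) / 2 = 0.2, while unseen tokens are untouched.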
## `batched_dynamic_scaled_fp8_matmul()` {#max.nn.legacy.kernels.batched_dynamic_scaled_fp8_matmul} > max.nn.legacy.kernels.batched\_dynamic\_scaled\_fp8\_matmul(a, b, a\_scales, b\_scales, input\_scale\_spec, weight\_scale\_spec, out\_type=bfloat16) Perform a batched blockwise scaled matmul of two tensors with scaling factors.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply (3D tensor). * b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed (3D tensor). * a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor (3D tensor). * b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor (3D tensor). * input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec)) * weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec)) * out\_type ([DType](../../dtype.md#max.dtype.DType))
**Returns:**
The result of the matmul operation.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `batched_quantize_dynamic_scaled_float8()` {#max.nn.legacy.kernels.batched_quantize_dynamic_scaled_float8} > max.nn.legacy.kernels.batched\_quantize\_dynamic\_scaled\_float8(input, input\_scale\_spec, weight\_scale\_spec, scale\_ub=1200.0, group\_size\_or\_per\_token=-1, out\_type=float8\_e4m3fn, scales\_type=bfloat16) Dynamically quantize the input tensor to fp8.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to quantize. Shape: [batch\_size, seq\_len, hidden\_size] * scale\_ub ([float](https://docs.python.org/3/library/functions.html#float)) – The upper bound of the scale factor. * group\_size\_or\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The group size for quantization. When set to -1, the quantization is column-wise. * out\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the output tensor. * scales\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the scales tensor. * input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec)) * weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec))
**Returns:**
The quantized tensor and the scales.
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
## `block_scales_interleave()` {#max.nn.legacy.kernels.block_scales_interleave} > max.nn.legacy.kernels.block\_scales\_interleave(scales, sf\_vector\_size=16, scales\_type=float8\_e4m3fn) Interleave the block scales tensor in \[M, N] layout to \[ceildiv(M, 128), ceildiv(N, sf\_vector\_size \* 4), 32, 4, 4] layout.
**Parameters:**
* scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scales tensor to interleave in \[M, N] layout. * sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The block size for the scaling factors. * scales\_type ([DType](../../dtype.md#max.dtype.DType))
**Returns:**
The interleaved scales tensor in \[ceildiv(M, 128), ceildiv(N, sf\_vector\_size \* 4), 32, 4, 4] layout.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
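The output layout stated above can be computed directly from the input shape. A small illustrative helper (not part of the API) that mirrors the docstring's formula:

```python
def interleaved_scales_shape(m: int, n: int, sf_vector_size: int = 16):
    """Shape of the interleaved scales tensor for an [M, N] input, per the
    layout [ceildiv(M, 128), ceildiv(N, sf_vector_size * 4), 32, 4, 4]."""
    ceildiv = lambda a, b: -(-a // b)
    return (ceildiv(m, 128), ceildiv(n, sf_vector_size * 4), 32, 4, 4)
```

For example, a 256×128 scales input with the default `sf_vector_size=16` maps to shape (2, 2, 32, 4, 4); note that M is effectively padded up to a multiple of 128 and N up to a multiple of sf\_vector\_size \* 4.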
## `ceildiv()` {#max.nn.legacy.kernels.ceildiv} > max.nn.legacy.kernels.ceildiv(n, d) Ceiling division.
**Parameters:**
* n ([Dim](../../graph/dim.md#max.graph.dim.Dim)) – The numerator. * d ([Dim](../../graph/dim.md#max.graph.dim.Dim)) – The denominator.
**Returns:**
The ceiling of dividing n by d.
**Return type:**
[Dim](../../graph/dim.md#max.graph.dim.Dim)
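This function operates on symbolic graph `Dim`s, but the arithmetic is ordinary ceiling division; a plain-integer sketch:

```python
def ceildiv(n: int, d: int) -> int:
    # Ceiling division on integers without going through floats:
    # floor-divide the negated numerator, then negate back.
    return -(-n // d)
```

For example, `ceildiv(7, 2)` is 4 and `ceildiv(8, 2)` is 4; this is the same quantity used elsewhere in this module for tile counts such as `ceildiv(M, 128)`.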
## `convert_weights_to_fp8_fnuz_if_needed()` {#max.nn.legacy.kernels.convert_weights_to_fp8_fnuz_if_needed} > max.nn.legacy.kernels.convert\_weights\_to\_fp8\_fnuz\_if\_needed(weight, weight\_scale) Convert weights and scales to FP8 FNUZ format if needed for AMD GPUs. This utility function checks whether FP8 FNUZ conversion is needed (currently only on AMD MI300 GPUs) and performs the conversion if required. This centralizes the conversion logic that was previously duplicated across multiple files.
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight tensor to potentially convert. * weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight scale factor.
**Returns:**
Tuple of (weight, weight\_scale) - converted if needed, original otherwise.
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
## `cross_attention_ragged()` {#max.nn.legacy.kernels.cross_attention_ragged} > max.nn.legacy.kernels.cross\_attention\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, kv\_input\_row\_offsets, q\_max\_seq\_len, scale, local\_window\_size=-1) Computes cross attention provided the !mo.opaque KV Cache. Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input. For cross attention, kv\_input\_row\_offsets represents the KV sequence length.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) * kv\_input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * q\_max\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * scale ([float](https://docs.python.org/3/library/functions.html#float)) * local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `dynamic_block_scaled_matmul_fp4()` {#max.nn.legacy.kernels.dynamic_block_scaled_matmul_fp4} > max.nn.legacy.kernels.dynamic\_block\_scaled\_matmul\_fp4(a, b, a\_scales, b\_scales, tensor\_sf, sf\_vector\_size=16, out\_type=bfloat16) Perform a matmul of two FP4 tensors with 1D-block scaled scaling factors.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply. * b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed. * a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor. * b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor. * tensor\_sf ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [float](https://docs.python.org/3/library/functions.html#float)) – Buffer-wise scaling factor equal to weight\_scale\_2 \* input\_scale (non-inverted). * sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int)) * out\_type ([DType](../../dtype.md#max.dtype.DType))
**Returns:**
The result of the matmul operation.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `dynamic_scaled_matmul()` {#max.nn.legacy.kernels.dynamic_scaled_matmul} > max.nn.legacy.kernels.dynamic\_scaled\_matmul(a, b, a\_scales, b\_scales, input\_scale\_spec, weight\_scale\_spec, out\_type=bfloat16) Perform a matmul of two tensors with scaling factors. Currently only supports channel-wise scaling for weights and per-token scaling for inputs.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply. * b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed. * a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor. * b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor. * input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec)) * weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec)) * out\_type ([DType](../../dtype.md#max.dtype.DType))
**Returns:**
The result of the matmul operation.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `flare_mla_decode_ragged()` {#max.nn.legacy.kernels.flare_mla_decode_ragged} > max.nn.legacy.kernels.flare\_mla\_decode\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, qk\_rope\_dim=64) Computes flash (self) attention provided the !mo.opaque KV Cache. Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input. Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross\_attention\_ragged.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) * scale ([float](https://docs.python.org/3/library/functions.html#float)) * qk\_rope\_dim ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `flare_mla_decompress_k_cache()` {#max.nn.legacy.kernels.flare_mla_decompress_k_cache} > max.nn.legacy.kernels.flare\_mla\_decompress\_k\_cache(kv\_params, buffer\_row\_offsets\_1d, cache\_offsets\_1d, buffer\_length, weight, kv\_collection, layer\_idx, buffer\_size) This kernel decompresses the key cache by up-projecting latent representations into the KV space using a weight matrix. The process involves:

1. Copying buffer\_length latent vectors from the key cache into a contiguous buffer (k\_latent)
2. Computing k = k\_latent @ weight.T to obtain the decompressed keys
**Returns:**
A tensor of shape \[buffer\_size, weight.shape\[0]] containing the decompressed keys. Note that only the first buffer\_length tokens are valid.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * buffer\_row\_offsets\_1d ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * cache\_offsets\_1d ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * buffer\_length ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * buffer\_size ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `flare_mla_prefill_plan()` {#max.nn.legacy.kernels.flare_mla_prefill_plan} > max.nn.legacy.kernels.flare\_mla\_prefill\_plan(kv\_params, input\_row\_offsets, kv\_collection, layer\_idx, buffer\_size, max\_chunks=16) This kernel plans how to process a batch of sequences with varying lengths using a fixed-size buffer. Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer\_size. For each chunk (iteration), it calculates:

1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. Total buffer lengths for each processing iteration
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * buffer\_size ([int](https://docs.python.org/3/library/functions.html#int)) * max\_chunks ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
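The planning step above can be illustrated with a rough standalone sketch in pure Python over plain sequence lengths. The helper name and return format are hypothetical; the actual kernel computes these offsets on device and returns tensors:

```python
def plan_chunks(seq_lens, buffer_size, max_chunks=16):
    """Plan how the concatenated tokens of a ragged batch fill fixed-size chunks.

    Returns one entry per chunk: (tokens_in_chunk, per-sequence (start, end)
    spans in the concatenated token stream, clipped to the chunk window).
    """
    # Prefix sum: global start offset of each sequence.
    starts = [0]
    for n in seq_lens:
        starts.append(starts[-1] + n)
    total = starts[-1]

    n_chunks = min(max_chunks, -(-total // buffer_size)) if total else 0
    chunks = []
    for c in range(n_chunks):
        lo, hi = c * buffer_size, min((c + 1) * buffer_size, total)
        # Per-sequence overlap with this chunk; start > end means "not in chunk".
        spans = [(max(starts[i], lo), min(starts[i + 1], hi))
                 for i in range(len(seq_lens))]
        chunks.append((hi - lo, spans))
    return chunks
```

For sequences of lengths [3, 5, 2] and a buffer of 4, the 10 total tokens split into three chunks of 4, 4, and 2 tokens, and each chunk records which slice of each sequence it covers.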
## `flare_mla_prefill_ragged()` {#max.nn.legacy.kernels.flare_mla_prefill_ragged} > max.nn.legacy.kernels.flare\_mla\_prefill\_ragged(kv\_params, input, k, v, input\_row\_offsets, buffer\_row\_offsets, cache\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, qk\_rope\_dim=64) Performs MLA prefill. In the MLA prefill, we need to decompress the KV tensors, as we store the latent representations in the KV cache. We will decompress the KV tensors into a fixed size buffer to avoid out-of-memory errors. In case the total cache length is greater than the buffer size, we will process the attention calculation in chunks. This MLA prefill kernel will return the output tensor for this iteration and the softmax info tensor for this iteration. Such tensors will be used by the next iteration of the MLA prefill kernel to continue the attention calculation.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor * k ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Key tensor * v ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Value tensor * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each batch starts and ends in input * buffer\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each batch starts and ends in the buffer * cache\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each batch starts and ends in the KV cache * kv\_collection (PagedCacheValues) – KV collection * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index tensor * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Mask variant * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scale * qk\_rope\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – QK rope dimension
**Returns:**
The output tensor for this iteration
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `flash_attention_gpu()` {#max.nn.legacy.kernels.flash_attention_gpu} > max.nn.legacy.kernels.flash\_attention\_gpu(q, k, v, mask\_variant, scale, local\_window\_size=-1, valid\_length=None) Computes flash attention using GPU-optimized kernel.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Query tensor of shape \[batch, seq\_len, num\_heads, head\_dim] * k ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Key tensor of shape \[batch, seq\_len, num\_heads, head\_dim] * v ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Value tensor of shape \[batch, seq\_len, num\_heads, head\_dim] * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – The mask variant to use for attention * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scaling factor for attention scores * local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Local window size for sliding window attention * valid\_length ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional tensor of shape \[batch] with dtype uint32. When provided, uses the padded kernel variant that respects the valid sequence lengths for each batch element.
**Returns:**
Output tensor of shape \[batch, seq\_len, num\_heads, head\_dim]
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
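For intuition, the kernel above computes (per head and batch element, ignoring the local window and valid-length options) the same result as a naive scaled-dot-product reference. This standalone NumPy version materializes the full score matrix, which is exactly what the flash formulation avoids:

```python
import numpy as np

def attention_reference(q, k, v, scale, causal=True):
    """Naive scaled dot-product attention for one head.
    q, k, v: [seq_len, head_dim]. Materializes the [seq, seq] score matrix."""
    scores = scale * (q @ k.T)
    if causal:
        seq = scores.shape[0]
        future = np.triu(np.ones((seq, seq), dtype=bool), 1)
        scores = np.where(future, -np.inf, scores)  # mask out future tokens
    scores -= scores.max(axis=-1, keepdims=True)    # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

With a causal mask, the first token can only attend to itself, so the first output row equals `v[0]`.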
## `flash_attention_padded_kv_cache()` {#max.nn.legacy.kernels.flash_attention_padded_kv_cache} > max.nn.legacy.kernels.flash\_attention\_padded\_kv\_cache(kv\_params, q, kv\_collection, layer\_idx, valid\_lengths, mask\_variant, scale, local\_window\_size=-1) Computes flash attention with padded inputs and paged KV cache.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters * q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Query tensor of shape \[batch, seq\_len, num\_heads, head\_dim] * kv\_collection (PagedCacheValues) – Paged KV cache collection * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for cache lookup * valid\_lengths ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch] with dtype uint32 indicating actual (non-padded) sequence lengths for each batch element * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – The mask variant to use for attention * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scaling factor for attention scores * local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Local window size for sliding window attention
**Returns:**
Output tensor of shape \[batch, seq\_len, num\_heads, head\_dim]
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `flash_attention_ragged()` {#max.nn.legacy.kernels.flash_attention_ragged} > max.nn.legacy.kernels.flash\_attention\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, local\_window\_size=-1, sink\_weights=None) Computes flash (self) attention provided the !mo.opaque KV Cache. Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input. Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross\_attention\_ragged.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams object containing key-value cache parameters. * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input tensor with shape \[total\_seq\_len, hidden\_dim]. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue indicating the start and end of each batch in the input tensor with shape \[batch\_size + 1]. * kv\_collection (PagedCacheValues) – PagedCacheValues object for managing key-value cache. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the layer index, expected to have dtype uint32. * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – MHAMaskVariant specifying the type of attention mask to use. * scale ([float](https://docs.python.org/3/library/functions.html#float)) – float value used to scale the attention scores. * local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int)) – int specifying the size of the local attention window, default is -1 for no local window. * sink\_weights ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional tensor of shape \[num\_heads] containing learnable sink weights for each attention head.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `flash_attention_ragged_gpu()` {#max.nn.legacy.kernels.flash_attention_ragged_gpu} > max.nn.legacy.kernels.flash\_attention\_ragged\_gpu(q, k, v, input\_row\_offsets, max\_seq\_len, mask\_variant, scale, local\_window\_size=-1) Computes flash attention for ragged inputs using GPU-optimized kernel without a KV cache.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Query tensor of shape \[total\_seq\_len, num\_heads, head\_dim] (ragged) * k ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Key tensor of shape \[total\_seq\_len, num\_heads, head\_dim] (ragged) * v ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Value tensor of shape \[total\_seq\_len, num\_heads, head\_dim] (ragged) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch\_size + 1] with dtype uint32. Indicates where each sequence starts and ends in the ragged tensors. The values should be a prefix sum (cumulative sum) of sequence lengths. * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – The mask variant to use for attention * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scaling factor for attention scores * local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Local window size for sliding window attention * max\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Returns:**
Output tensor of shape \[total\_seq\_len, num\_heads, head\_dim]
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
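The prefix-sum convention for `input_row_offsets` described above can be built like this (illustrative NumPy sketch; the graph itself expects a device tensor):

```python
import numpy as np

def make_row_offsets(seq_lens):
    """Build the [batch_size + 1] uint32 offsets array: sequence i occupies
    rows offsets[i]:offsets[i + 1] of the ragged [total_seq_len, ...] tensor."""
    offsets = np.zeros(len(seq_lens) + 1, dtype=np.uint32)
    offsets[1:] = np.cumsum(seq_lens)
    return offsets
```

For sequences of length 3, 5, and 2 packed into one ragged tensor, this yields [0, 3, 8, 10]; the final entry is total\_seq\_len.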
## `fused_qk_padded_rope()` {#max.nn.legacy.kernels.fused_qk_padded_rope} > max.nn.legacy.kernels.fused\_qk\_padded\_rope(kv\_params, input, kv\_collection, freqs\_cis, layer\_idx, valid\_lengths, interleaved=True) Computes fused query-key RoPE with padded inputs and paged KV cache. This function applies Rotary Positional Embeddings (RoPE) to both Q and K tensors, where K is stored in the paged KV cache. This is the padded equivalent of fused\_qk\_ragged\_rope.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters. * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Query tensor of shape \[batch, seq\_len, n\_heads, head\_dim]. * kv\_collection (PagedCacheValues) – Paged KV cache collection. * freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Frequency tensor of shape (max\_seq\_len \* 2, head\_dim). * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for KV cache (must be uint32 on CPU). * valid\_lengths ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch] containing the valid length for each sequence (must be uint32). RoPE is only applied to positions within these lengths. * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to use interleaved RoPE pattern.
**Returns:**
Query tensor with RoPE applied, same shape as input.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
:::note Note Unlike fused\_qk\_ragged\_rope which requires ragged inputs, this function works with padded batch inputs where sequences may have different actual lengths but are padded to a uniform shape. ::: ## `fused_qk_ragged_rope()` {#max.nn.legacy.kernels.fused_qk_ragged_rope} > max.nn.legacy.kernels.fused\_qk\_ragged\_rope(kv\_params, input, input\_row\_offsets, kv\_collection, freqs\_cis, layer\_idx, interleaved=True, position\_ids=None, mrope\_section=None) Computes fused query-key attention with rotary positional encodings and ragged inputs.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – \[batch\_size \* seq\_len, n\_heads, head\_dim] * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Ragged tensor offsets indicating where each batch starts and ends * kv\_collection (PagedCacheValues) – KV cache collection * freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – tensor of shape (max\_seq\_len \* 2, head\_dim) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for KV cache * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to use interleaved RoPE pattern * position\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional ragged 2D array of position IDs. If None, defaults to cache\_length + token\_idx for each token. When num\_sections > 1, mrope\_section must be provided to indicate each section of the head\_dim to apply RoPE to. Shape: [num\_sections, total\_seq\_len] * mrope\_section ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | None) – Optional list of integers indicating the section of the head\_dim to apply RoPE to. Must be used in conjunction with position\_ids.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input. If input is not of the same dtype as freqs\_cis, it will be cast to the dtype of freqs\_cis for the computation, and cast back to the original dtype after the computation is finished. When position\_ids and mrope\_section are provided, it replaces the default position calculation (cache\_length + token\_idx) with explicit position values. This is useful for 3D RoPE in models like Qwen2.5-VL that need custom position encoding. ## `fused_qkv_padded_matmul()` {#max.nn.legacy.kernels.fused_qkv_padded_matmul} > max.nn.legacy.kernels.fused\_qkv\_padded\_matmul(kv\_params, input, wqkv, kv\_collection, layer\_idx, valid\_lengths, n\_heads) Computes fused query, key, and value projections with padded input. This is for non-ragged (padded batch) inputs where sequences may have different actual lengths but are padded to a uniform shape.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters. * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor with shape \[batch\_size, seq\_len, hidden\_dim]. * wqkv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight tensor for Q, K, V projections. * kv\_collection (PagedCacheValues) – Paged KV cache collection. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for cache lookup (must be uint32). * valid\_lengths ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch] containing the valid length for each sequence (must be uint32). K and V are only written to cache for positions within these lengths. * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – Number of attention heads.
**Returns:**
Query projections tensor. K and V projections are written to cache.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `fused_qkv_ragged_matmul()` {#max.nn.legacy.kernels.fused_qkv_ragged_matmul} > max.nn.legacy.kernels.fused\_qkv\_ragged\_matmul(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, bias=None) Computes fused query, key, and value projections with ragged input. input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * wqkv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None)
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `fused_qkv_ragged_matmul_quantized()` {#max.nn.legacy.kernels.fused_qkv_ragged_matmul_quantized} > max.nn.legacy.kernels.fused\_qkv\_ragged\_matmul\_quantized(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, quantization\_config, perm\_idx=None, bias=None) Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization\_config must be provided. input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * wqkv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * quantization\_config ([QuantizationConfig](../../graph/quantization.md#max.graph.quantization.QuantizationConfig)) * perm\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None)
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `fused_qkv_ragged_matmul_scaled_float4()` {#max.nn.legacy.kernels.fused_qkv_ragged_matmul_scaled_float4} > max.nn.legacy.kernels.fused\_qkv\_ragged\_matmul\_scaled\_float4(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, input\_scale, weight\_scale, tensor\_sf, kv\_scales=None, sf\_vector\_size=16, \_output\_dim=None) Computes fused query, key, and value projections with scaled float4 input and weights.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams object containing key-value cache parameters. * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input tensor with shape \[M=total\_seq\_len, K=hidden\_dim]. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue indicating the start and end of each batch in the input tensor with shape \[batch\_size + 1]. * wqkv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight tensor with shape \[N=(num\_heads + 2 \* num\_kv\_heads) \* head\_dim, K=hidden\_dim]. * kv\_collection (PagedCacheValues) – PagedCacheValues object for managing key-value cache. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the layer index, expected to have dtype uint32. * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – Number of attention heads. * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input scale tensor. Shape for blockwise scaling is 5D e.g., \[2, 3, 32, 4, 4]. * weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight scale tensor. Shape for blockwise scaling is 5D e.g., \[2, 34, 32, 4, 4]. * tensor\_sf ([float](https://docs.python.org/3/library/functions.html#float) | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer-wise scaling factor equal to weight\_scale\_2 \* input\_scale (pre-quantization, non-inverted). * kv\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional KV scales used by the NVFP4 KV cache. * sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Number of elements per scale factor along K. Defaults to 16. * \_output\_dim ([int](https://docs.python.org/3/library/functions.html#int) | None) – Optional output dimension. If not provided, the output dimension will be \[n\_heads \* head\_dim].
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `fused_qkv_ragged_matmul_scaled_float8()` {#max.nn.legacy.kernels.fused_qkv_ragged_matmul_scaled_float8} > max.nn.legacy.kernels.fused\_qkv\_ragged\_matmul\_scaled\_float8(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, input\_scale, weight\_scale, bias=None, float8\_config=None, \_output\_dim=None) Computes fused query, key, and value projections with scaled float8 input and weights.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams object containing key-value cache parameters. * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input tensor with shape \[M=total\_seq\_len, K=hidden\_dim]. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue indicating the start and end of each batch in the input tensor with shape \[batch\_size + 1]. * wqkv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight tensor with shape \[N=(num\_heads + 2 \* num\_kv\_heads) \* head\_dim, K=hidden\_dim]. * kv\_collection (PagedCacheValues) – PagedCacheValues object for managing key-value cache. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the layer index, expected to have dtype uint32. * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – Number of attention heads. * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input scale tensor. Shape varies depending on the quantization config. * weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight scale tensor. Shape varies depending on the quantization config. * bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional bias vector concatenated as \[q, k, v]. * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) – Optional Float8Config object containing float8 quantization parameters. If not provided, the quantization config will be inferred from the input and weight scale shapes. * \_output\_dim ([int](https://docs.python.org/3/library/functions.html#int) | None) – Optional output dimension. If not provided, the output dimension will be \[n\_heads \* head\_dim].
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
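As a worked example of the fused weight shape `N = (num_heads + 2 * num_kv_heads) * head_dim` used by the `wqkv` parameter above, the numbers below are hypothetical and only illustrate how a fused projection output splits back into its Q, K, and V parts:

```python
# Hypothetical sizes illustrating the fused-QKV output width
# N = (num_heads + 2 * num_kv_heads) * head_dim.
num_heads, num_kv_heads, head_dim = 8, 2, 64

n = (num_heads + 2 * num_kv_heads) * head_dim  # fused output width
q_dim = num_heads * head_dim                   # columns belonging to Q
kv_dim = num_kv_heads * head_dim               # columns each for K and V

# One fused output row of width n splits into the three projections:
fused_row = list(range(n))
q_part = fused_row[:q_dim]
k_part = fused_row[q_dim:q_dim + kv_dim]
v_part = fused_row[q_dim + kv_dim:]
```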
## `grouped_dynamic_scaled_fp8_matmul()` {#max.nn.legacy.kernels.grouped_dynamic_scaled_fp8_matmul} > max.nn.legacy.kernels.grouped\_dynamic\_scaled\_fp8\_matmul(hidden\_states, weight, a\_scales, b\_scales, expert\_start\_indices, expert\_ids, expert\_usage\_stats\_host, input\_scale\_spec, weight\_scale\_spec, out\_type=bfloat16, tokens\_padded\_per\_expert=False) Grouped blockwise scaled matmul used in the MoE layer. Performs a grouped matmul of two tensors with blockwise scaling factors. hidden\_states and expert\_start\_indices together implement the ragged tensor.
**Parameters:**
* hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply. (2D tensor) * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed. (3D tensor) * a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor. (2D tensor) * b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor. (3D tensor) * expert\_start\_indices ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each group starts and ends in hidden\_states. * expert\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The id of the expert for each group in hidden\_states. * expert\_usage\_stats\_host ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The maximum number of tokens assigned to any expert, and the number of active experts. * input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec)) – The scaling granularity for the input tensor. * weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec)) – The scaling granularity for the weight tensor. * tokens\_padded\_per\_expert ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, it's guaranteed that the number of tokens for each local expert will be padded so that a\_scales is aligned to 16 bytes. This is needed by the optimized grouped matmul kernel. * out\_type ([DType](../../dtype.md#max.dtype.DType))
**Returns:**
The result of the matmul operation.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `grouped_dynamic_scaled_nvfp4_matmul()` {#max.nn.legacy.kernels.grouped_dynamic_scaled_nvfp4_matmul} > max.nn.legacy.kernels.grouped\_dynamic\_scaled\_nvfp4\_matmul(hidden\_states, weight, a\_scales, b\_scales, expert\_start\_indices, a\_scale\_offsets, expert\_ids, expert\_scales, expert\_usage\_stats\_host, out\_type=bfloat16) Performs grouped NVFP4 matmul for MoE layers. Performs a grouped matmul with NVFP4 (4-bit) quantized inputs and weights. The inputs are packed as uint8 (2 NVFP4 values per byte) with float8\_e4m3fn scaling factors. NVFP4 uses fixed 1D block scaling with 16 elements per scale factor along the K dimension. `hidden_states` and `expert_start_indices` together implement the ragged tensor representation for variable-length expert inputs.
**Parameters:**
* hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input activations with shape `[total_tokens, K/2]` where K is the unpacked hidden dimension. Dtype must be uint8 (packed NVFP4). * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The expert weights with shape `[num_experts, N, K/2]`. Dtype must be uint8 (packed NVFP4). * a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Scaling factors for inputs with shape `[num_scale_rows, K_groups, 32, 4, 4]`. Dtype must be float8\_e4m3fn. * b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Scaling factors for weights with shape `[num_experts, N_groups, K_groups, 32, 4, 4]`. Dtype must be float8\_e4m3fn. * expert\_start\_indices ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indices indicating where each expert’s tokens start in `hidden_states`. * a\_scale\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The offsets of the input scale tiles for each expert. * expert\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The expert ID for each group. * expert\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Per-expert scaling factors with shape `[num_experts]`. Dtype must be float32. Multiplied with the matmul output in the epilogue. * expert\_usage\_stats\_host ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – A tensor containing \[max\_tokens\_per\_expert, num\_active\_experts]. * out\_type ([DType](../../dtype.md#max.dtype.DType)) – Output dtype. Defaults to bfloat16. * tokens\_padded\_per\_expert – If True, tokens per expert are padded for alignment. Defaults to False.
**Returns:**
The matmul result with shape `[total_tokens, N]` and dtype `out_type`.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `grouped_matmul_ragged()` {#max.nn.legacy.kernels.grouped_matmul_ragged} > max.nn.legacy.kernels.grouped\_matmul\_ragged(hidden\_states, weight, expert\_start\_indices, expert\_ids, expert\_usage\_stats\_host) Grouped matmul used in the MoE layer. hidden\_states and expert\_start\_indices together implement the ragged tensor: expert\_start\_indices indicates where each group starts and ends in hidden\_states, expert\_ids is the id of the expert for each group in hidden\_states, and expert\_usage\_stats\_host holds the maximum number of tokens assigned to any expert and the number of active experts.
**Parameters:**
* hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * expert\_start\_indices ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * expert\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * expert\_usage\_stats\_host ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
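A plain-Python reference loop for the grouped-matmul semantics (a sketch of the behavior, not the fused kernel): each contiguous group of rows in `hidden_states` is multiplied by the weight of the expert assigned to that group. The `[K, N]` weight layout here is chosen for readability and is an assumption of the sketch:

```python
def matmul(a, b):
    """Naive [m, k] x [k, n] matmul on nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def grouped_matmul_ragged_ref(hidden_states, weights,
                              expert_start_indices, expert_ids):
    """Multiply each ragged group of rows by its assigned expert's weight."""
    out = []
    for g, expert in enumerate(expert_ids):
        start, end = expert_start_indices[g], expert_start_indices[g + 1]
        out.extend(matmul(hidden_states[start:end], weights[expert]))
    return out

hidden_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, K=2
weights = [[[1.0, 0.0], [0.0, 1.0]],                  # expert 0: identity
           [[2.0, 0.0], [0.0, 2.0]]]                  # expert 1: doubles
expert_start_indices = [0, 1, 3]  # group 0 = token 0, group 1 = tokens 1..2
expert_ids = [1, 0]               # group 0 -> expert 1, group 1 -> expert 0

result = grouped_matmul_ragged_ref(hidden_states, weights,
                                   expert_start_indices, expert_ids)
# token 0 is doubled by expert 1; tokens 1 and 2 pass through expert 0
```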
## `kv_cache_copy_pages_d2h()` {#max.nn.legacy.kernels.kv_cache_copy_pages_d2h} > max.nn.legacy.kernels.kv\_cache\_copy\_pages\_d2h(device\_kv\_collection, device\_page\_ids, host\_kv\_blocks, host\_page\_ids, layer\_idx, device\_ref) Copy KV cache pages from GPU to CPU for a single layer. Performs async GPU->CPU copy of specified pages for layer-wise KV cache offloading.
**Parameters:**
* device\_kv\_collection (PagedCacheValues) – Source KV cache on GPU. * device\_page\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Source page IDs to read from GPU. * host\_kv\_blocks ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – Destination KV cache blocks on CPU. * host\_page\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Destination page IDs to write to CPU. Must have same length as device\_page\_ids. * layer\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Which layer to copy. * device\_ref ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) – Device for the GPU context.
**Return type:**
None
## `kv_cache_get_max_seq_len()` {#max.nn.legacy.kernels.kv_cache_get_max_seq_len} > max.nn.legacy.kernels.kv\_cache\_get\_max\_seq\_len(kv\_params, kv\_collection) This kernel returns the maximum sequence length.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * kv\_collection (PagedCacheValues)
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `kv_cache_ragged_2m_iadd()` {#max.nn.legacy.kernels.kv_cache_ragged_2m_iadd} > max.nn.legacy.kernels.kv\_cache\_ragged\_2m\_iadd(kv\_params, a, kv\_collection, input\_row\_offsets, lora\_end\_idx, batch\_seq\_len, layer\_idx) In-place add to paged KV cache with interleaved K/V layout. Performs an in-place addition of new key-value projections to paged KV cache. The input tensor a uses a “2M” layout where keys and values are interleaved: rows \[0, m) contain keys and rows \[m, 2m) contain values, where m is the number of tokens.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache configuration parameters. Must have cache\_strategy set to PAGED and page\_size must be defined. * a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor with interleaved K/V data, shape (2\*m, hidden\_size) where m is the number of tokens. Rows \[0, m) are keys, rows \[m, 2m) are values. * kv\_collection (PagedCacheValues) – The paged KV cache collection containing cache blocks, cache lengths, lookup tables, and max lengths tensors. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Ragged tensor offsets indicating where each batch starts and ends. * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – End index of LoRA token portion. Marks the boundary between LoRA sequences and base model sequences in the batch. * batch\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Total sequence length in the batch. Used for indexing into the value portion of a. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The transformer layer index to update in the KV cache.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a does not have rank 2. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input\_row\_offsets does not have rank 1. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If kv\_params.cache\_strategy is not PAGED. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If kv\_params.page\_size is None.
**Return type:**
None
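The "2M" layout described above can be sketched as follows (illustrative values only):

```python
# Sketch of the "2M" interleaved layout: for m tokens, rows [0, m) of the
# input hold key projections and rows [m, 2m) hold value projections.
m = 3  # number of tokens in the batch
a = [f"k{i}" for i in range(m)] + [f"v{i}" for i in range(m)]  # shape (2*m, ...)

keys, values = a[:m], a[m:]
# keys[i] and values[i] both belong to token i, so an in-place add for
# token i touches rows i and m + i of a.
```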
## `kv_cache_ragged_radd()` {#max.nn.legacy.kernels.kv_cache_ragged_radd} > max.nn.legacy.kernels.kv\_cache\_ragged\_radd(kv\_params, a, kv\_collection, input\_row\_offsets, batch\_offset, layer\_idx) Adds a tensor to a slice of the KVCache along the batch dimension. The requests to be sliced out must be contiguous at the front of the batch; the addition applies only to the trailing requests in the batch.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The tensor to add to the KVCache. * kv\_collection (PagedCacheValues) – The KVCache collection to add to. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The offsets of the input tensor. * batch\_offset ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The batch to start applying the r-add to. * layer\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – The layer index to add to. * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams))
**Return type:**
None
## `matmul_k_cache_ragged()` {#max.nn.legacy.kernels.matmul_k_cache_ragged} > max.nn.legacy.kernels.matmul\_k\_cache\_ragged(kv\_params, hidden\_states, input\_row\_offsets, weight, kv\_collection, layer\_idx) Computes key projections with ragged input. hidden\_states and input\_row\_offsets together implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in input.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
None
## `matmul_k_cache_ragged_scaled_float8()` {#max.nn.legacy.kernels.matmul_k_cache_ragged_scaled_float8} > max.nn.legacy.kernels.matmul\_k\_cache\_ragged\_scaled\_float8(kv\_params, hidden\_states, input\_row\_offsets, weight, input\_scale, weight\_scale, kv\_collection, scales\_granularity\_mnk, layer\_idx) Computes key projections with ragged input with FP8 block scaling.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams object containing key-value cache parameters. * hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input tensor with shape \[M=total\_seq\_len, K=hidden\_dim]. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue indicating the start and end of each batch in the input tensor with shape \[batch\_size + 1]. * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight tensor with shape \[N=num\_heads, K=hidden\_dim]. * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input scale tensor with shape \[ceildiv(K / BLOCK\_SIZE\_K), ceildiv(M / BLOCK\_SIZE\_M)]. * weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight scale tensor with shape \[ceildiv(N / BLOCK\_SIZE\_N), ceildiv(K / BLOCK\_SIZE\_K)]. * kv\_collection (PagedCacheValues) – PagedCacheValues object for managing key-value cache. * scales\_granularity\_mnk ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – tuple\[int, int, int] representing the scaling (BLOCK\_SIZE\_M, BLOCK\_SIZE\_N, BLOCK\_SIZE\_K). * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the layer index, expected to have dtype uint32.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel, or when the cache strategy is not supported.
**Return type:**
None
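A worked example of the scale-tensor shapes above, using hypothetical sizes and block granularity:

```python
# Worked example (hypothetical sizes) of the blockwise-scale shapes:
# each blocked dimension is divided by its block size, rounded up.
from math import ceil

def ceildiv(x, block):
    return ceil(x / block)

M, N, K = 1000, 512, 4096                # total_seq_len, out dim, hidden_dim
BLOCK_M, BLOCK_N, BLOCK_K = 1, 128, 128  # scales_granularity_mnk

input_scale_shape = (ceildiv(K, BLOCK_K), ceildiv(M, BLOCK_M))   # (32, 1000)
weight_scale_shape = (ceildiv(N, BLOCK_N), ceildiv(K, BLOCK_K))  # (4, 32)
```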
## `matmul_kv_cache_ragged()` {#max.nn.legacy.kernels.matmul_kv_cache_ragged} > max.nn.legacy.kernels.matmul\_kv\_cache\_ragged(kv\_params, hidden\_states, input\_row\_offsets, weight, kv\_collection, layer\_idx) Computes key and value projections with ragged input. hidden\_states and input\_row\_offsets together implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in input.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
None
## `matmul_static_scaled_float8()` {#max.nn.legacy.kernels.matmul_static_scaled_float8} > max.nn.legacy.kernels.matmul\_static\_scaled\_float8(input, weight, input\_scale, weight\_scale) Computes a matmul with statically scaled float8 input and weight tensors.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
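This entry gives no parameter descriptions, so the following is a hedged sketch of the likely semantics only: with per-tensor static scales, a scaled float8 matmul is mathematically equivalent to dequantizing both operands and multiplying. The `[N, K]` weight layout (multiplied transposed) is an assumption matching the other matmuls on this page; the real kernel keeps operands in float8 and folds the scales into the accumulation:

```python
# Reference semantics sketch (plain Python, not the kernel): dequantize both
# operands with their per-tensor scales, then do an ordinary matmul against
# the transposed weight.
def scaled_matmul_ref(inp, weight, input_scale, weight_scale):
    m, k, n = len(inp), len(inp[0]), len(weight)  # inp [M, K], weight [N, K]
    return [[sum(inp[i][p] * input_scale * weight[j][p] * weight_scale
                 for p in range(k))
             for j in range(n)] for i in range(m)]

out = scaled_matmul_ref([[1.0, 2.0]],                  # input, scale 0.5
                        [[1.0, 0.0], [0.0, 1.0]],      # identity weight, scale 2.0
                        0.5, 2.0)
# combined scale is 0.5 * 2.0 = 1.0, so the identity weight returns the input
```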
## `merge_ragged_tensors()` {#max.nn.legacy.kernels.merge_ragged_tensors} > max.nn.legacy.kernels.merge\_ragged\_tensors(a, a\_row\_offsets, b, b\_row\_offsets) Merges two ragged tensors into a single ragged tensor. Both ragged tensors must have the same batch size (same number of row offsets). This function interleaves the rows from each tensor based on their row offsets.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first ragged tensor of shape \[total\_a\_rows, …]. * a\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The row offsets of the first ragged tensor, indicating where each batch starts and ends in a. * b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second ragged tensor of shape \[total\_b\_rows, …]. * b\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The row offsets of the second ragged tensor, indicating where each batch starts and ends in b.
**Returns:**
* The merged ragged tensor with shape \[total\_a\_rows + total\_b\_rows, …]. * The merged row offsets with the same shape as input row offsets.
**Return type:**
A tuple of two tensors
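The interleaving semantics can be sketched in plain Python (a reference for the behavior, not the kernel itself):

```python
# For each batch element i, the merged tensor takes a's rows for batch i
# first, then b's rows for batch i, and records the new offsets.
def merge_ragged_ref(a, a_row_offsets, b, b_row_offsets):
    merged, merged_offsets = [], [0]
    for i in range(len(a_row_offsets) - 1):
        merged.extend(a[a_row_offsets[i]:a_row_offsets[i + 1]])
        merged.extend(b[b_row_offsets[i]:b_row_offsets[i + 1]])
        merged_offsets.append(len(merged))
    return merged, merged_offsets

merged, offsets = merge_ragged_ref([1, 2, 3, 4, 5, 6], [0, 2, 6],
                                   [7, 8, 9, 10], [0, 3, 4])
# merged == [1, 2, 7, 8, 9, 3, 4, 5, 6, 10], offsets == [0, 5, 10]
```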
Example:

```python
a = [1, 2, 3, 4, 5, 6]
a_row_offsets = [0, 2, 6]
b = [7, 8, 9, 10]
b_row_offsets = [0, 3, 4]

merged_tensor, merged_row_offsets = merge_ragged_tensors(
    a, a_row_offsets, b, b_row_offsets)

# merged_tensor == [1, 2, 7, 8, 9, 3, 4, 5, 6, 10]
# merged_row_offsets == [0, 5, 10]
```

## `mla_decode_branch_fp8()` {#max.nn.legacy.kernels.mla_decode_branch_fp8} > max.nn.legacy.kernels.mla\_decode\_branch\_fp8(q, input\_row\_offsets, freqs\_cis, kv\_a\_proj\_layernorm, w\_uk, w\_uk\_scale, w\_uv, w\_uv\_scale, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim, float8\_config) This is a manually fused kernel that performs the following operations: * Apply RoPE to the query and the key cache (in-place). * Apply RMSNorm to the non-rope portion of the key cache (in-place). * Project q\_nope to kv\_latent\_dim through a fp8 batched matmul: q\_nope\_proj = q\_nope\_t @ w\_uk * Concatenate q\_nope\_proj and q\_rope: q\_full = concat(q\_nope\_proj, q\_rope, axis=2) * Perform MLA decode * Project raw\_output to v\_head\_dim through another fp8 batched matmul: output = raw\_output\_t @ w\_uv
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Combined query tensor containing both nope and rope parts. Shape: \[tot\_seq\_len, num\_heads, qk\_nope\_head\_dim + qk\_rope\_head\_dim]. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each request starts and ends in input. This is a 1D tensor of shape \[num\_batches + 1]. * freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Precomputed RoPE frequency values for rotary position embeddings. Shape: [max\_seq\_len, qk\_rope\_head\_dim]. * kv\_a\_proj\_layernorm ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RMSNorm gamma weights for normalizing the KV cache. Shape: [kv\_lora\_rank]. * w\_uk ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight matrix for projecting q\_nope to kv\_latent\_dim. Shape: \[num\_heads, kv\_latent\_dim, qk\_nope\_head\_dim]. * w\_uk\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scale for the weight matrix. Shape varies depending on the float8\_config. * w\_uv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight matrix for projecting MLA decode output to v\_head\_dim. Shape: [num\_heads, v\_head\_dim, kv\_latent\_dim]. * w\_uv\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scale for the weight matrix. Shape varies depending on the float8\_config. * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams * kv\_collection (PagedCacheValues) – Paged KV Cache object. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index. * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Mask variant. * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scale for the attention calculation. 
* epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – Small constant for numerical stability in RMSNorm. * v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension of the V heads. * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) – Float8Config for the weight matrix.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `mla_prefill_branch_fp8()` {#max.nn.legacy.kernels.mla_prefill_branch_fp8} > max.nn.legacy.kernels.mla\_prefill\_branch\_fp8(q, input\_row\_offsets, freqs\_cis, kv\_a\_proj\_layernorm, buffer\_row\_offsets, cache\_offsets, buffer\_length, kv\_b\_proj, kv\_b\_proj\_scale, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim, float8\_config) This is a manually fused kernel that performs the following operations: * Apply RoPE to the query and the key cache (in-place). * Apply RMSNorm to the non-rope portion of the key cache (in-place). * Copy the KV latent values from PagedKVCache to a contiguous buffer. * Quantize the KV latent values to fp8. * Up-project the latent KV values to full K and V through a matmul. * Split the concatenated KV into K and V. * Perform MLA prefill.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Combined query tensor containing both nope and rope parts. Shape: \[tot\_seq\_len, num\_heads, qk\_nope\_head\_dim + qk\_rope\_head\_dim]. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each request starts and ends in input. This is a 1D tensor of shape \[num\_batches + 1]. * freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Precomputed RoPE frequency values for rotary position embeddings. Shape: [max\_seq\_len, qk\_rope\_head\_dim]. * kv\_a\_proj\_layernorm ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RMSNorm gamma weights for normalizing the KV cache. Shape: [kv\_lora\_rank]. * buffer\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each request’s KV latent values should be stored in the contiguous buffer. This is a 1D tensor of shape \[num\_batches + 1]. * cache\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates the starting token position in the KV cache from which to copy KV latent values for each request. This is a 1D tensor of shape \[num\_batches + 1]. * buffer\_length ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The total number of tokens in the KV cache. Scalar. * kv\_b\_proj ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight matrix for up-projecting the KV latent values to full K and V. Shape: [num\_heads \* (qk\_nope\_head\_dim + v\_head\_dim), kv\_latent\_dim]. * kv\_b\_proj\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scale for the weight matrix. Shape varies depending on the float8\_config. * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams * kv\_collection (PagedCacheValues) – Paged KV Cache object. 
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index. * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Mask variant. * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scale for the attention calculation. * epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – Small constant for numerical stability in RMSNorm. * v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension of the V heads. * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) – Float8Config for the weight matrix.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `mla_prefill_decode_graph_bf16()` {#max.nn.legacy.kernels.mla_prefill_decode_graph_bf16} > max.nn.legacy.kernels.mla\_prefill\_decode\_graph\_bf16(q, input\_row\_offsets, freqs\_cis, kv\_norm\_gamma, buffer\_row\_offsets, cache\_offsets, buffer\_length, kv\_b\_proj, w\_uk, w\_uv, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim) BF16 mega-kernel for MLA prefill/decode. Switches between prefill and decode based on the maximum sequence length in the batch.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Combined query tensor with nope+rope parts. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Row offsets for the batch. * freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RoPE frequency tensor. * kv\_norm\_gamma ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RMSNorm gamma for KV cache. * buffer\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – One-shot prefill plan. * cache\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – One-shot prefill plan. * buffer\_length ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – One-shot prefill plan. * kv\_b\_proj ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – KV up-projection weights. * w\_uk ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Decode/output projection weights. * w\_uv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Decode/output projection weights. * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters. * kv\_collection (PagedCacheValues) – Paged KV cache values. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index (uint32). * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Attention mask variant. * scale ([float](https://docs.python.org/3/library/functions.html#float)) – Attention scale. * epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – RMSNorm epsilon. * v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Value head dimension.
**Returns:**
Output tensor of shape \[total\_seq\_len, num\_heads, v\_head\_dim].
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input ranks/dtypes or cache strategy are invalid.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `mla_prefill_decode_graph_fp8()` {#max.nn.legacy.kernels.mla_prefill_decode_graph_fp8} > max.nn.legacy.kernels.mla\_prefill\_decode\_graph\_fp8(q, input\_row\_offsets, freqs\_cis, kv\_a\_proj\_layernorm, buffer\_row\_offsets, cache\_offsets, buffer\_length, kv\_b\_proj, kv\_b\_proj\_scale, w\_uk, w\_uk\_scale, w\_uv, w\_uv\_scale, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim, float8\_config) Fused MLA prefill/decode kernel for FP8. Switches between prefill and decode based on the maximum sequence length in the batch. See mla\_prefill\_branch\_fp8 and mla\_decode\_branch\_fp8 for the dedicated paths.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Combined query tensor with nope+rope parts. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Row offsets for the batch. * freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RoPE frequency tensor. * kv\_a\_proj\_layernorm ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RMSNorm gamma for KV cache. * buffer\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – One-shot prefill plan. * cache\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – One-shot prefill plan. * buffer\_length ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – One-shot prefill plan. * kv\_b\_proj ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – KV up-projection weights and scales. * kv\_b\_proj\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – KV up-projection weights and scales. * w\_uk ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Decode projection weights/scales. * w\_uk\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Decode projection weights/scales. * w\_uv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Decode projection weights/scales. * w\_uv\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Decode projection weights/scales. * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters. * kv\_collection (PagedCacheValues) – Paged KV cache values. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index (uint32). * mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Attention mask variant. 
* scale ([float](https://docs.python.org/3/library/functions.html#float)) – Attention scale. * epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – RMSNorm epsilon. * v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Value head dimension. * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) – Float8 configuration used for scaling.
**Returns:**
Output tensor of shape \[total\_seq\_len, num\_heads, v\_head\_dim].
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input ranks/dtypes or cache strategy are invalid. * [AssertionError](https://docs.python.org/3/library/exceptions.html#AssertionError) – If float8 scale block sizes are not set.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `moe_create_indices()` {#max.nn.legacy.kernels.moe_create_indices} > max.nn.legacy.kernels.moe\_create\_indices(topk\_ids, num\_local\_experts) Creates indices for the MoE layer.
**Parameters:**
* topk\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The expert assignments for each token from the router. * num\_local\_experts ([int](https://docs.python.org/3/library/functions.html#int)) – The number of experts on this device.
**Returns:**
* token\_expert\_order: The reordered token indices, grouped by assigned expert. * expert\_start\_indices: The starting index for each expert’s token group in the reordered sequence. * restore\_token\_order: The indices to restore the original token ordering after expert computation. * expert\_ids: The IDs of the active experts selected for tokens. * expert\_usage\_stats: The maximum number of tokens assigned to any expert, and the number of active experts.
**Return type:**
A tuple of five tensors
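For intuition, here is a small NumPy sketch (not the kernel itself) of how three of the returned tensors relate to each other; the token count, expert assignments, and variable names are illustrative only:

```python
import numpy as np

# Hypothetical router output: each of 4 tokens is assigned one expert (top-1).
topk_ids = np.array([2, 0, 2, 1])
num_local_experts = 3

# Group tokens by assigned expert (stable sort keeps per-expert token order).
token_expert_order = np.argsort(topk_ids, kind="stable")
# Start offset of each expert's contiguous token group, plus a final end offset.
expert_start_indices = np.searchsorted(
    topk_ids[token_expert_order], np.arange(num_local_experts + 1)
)
# Inverse permutation: restores the original token order after expert compute.
restore_token_order = np.argsort(token_expert_order)

tokens = np.array(["t0", "t1", "t2", "t3"])
grouped = tokens[token_expert_order]  # tokens grouped by expert
assert (grouped[restore_token_order] == tokens).all()
```

Each expert `e` then processes the contiguous slice `grouped[expert_start_indices[e]:expert_start_indices[e + 1]]`.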
## `moe_router_group_limited()` {#max.nn.legacy.kernels.moe_router_group_limited} > max.nn.legacy.kernels.moe\_router\_group\_limited(expert\_scores, expert\_bias, n\_routed\_experts, n\_experts\_per\_tok, n\_groups, topk\_group, norm\_weights, routed\_scaling\_factor) Group-limited MoE router.
**Parameters:**
* expert\_scores ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scores for each expert for each token. Shape: \[num\_tokens, n\_routed\_experts]. * expert\_bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The bias for each expert. Shape: [n\_routed\_experts]. * n\_routed\_experts ([int](https://docs.python.org/3/library/functions.html#int)) – The total number of experts. Must be divisible by n\_groups. * n\_experts\_per\_tok ([int](https://docs.python.org/3/library/functions.html#int)) – The number of experts to be selected per token. * n\_groups ([int](https://docs.python.org/3/library/functions.html#int)) – The total number of expert groups. Must evenly divide n\_routed\_experts. * topk\_group ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum number of expert groups that a token will be routed to. * norm\_weights ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to normalize the selected expert weights. * routed\_scaling\_factor ([float](https://docs.python.org/3/library/functions.html#float)) – The scaling factor for the routed expert weights.
**Returns:**
* expert\_indices: The indices of the routed experts for each token. Shape: [num\_tokens, n\_experts\_per\_tok]. * expert\_weights: The weights of the routed experts for each token. Shape: [num\_tokens, n\_experts\_per\_tok].
**Return type:**
A tuple of two tensors
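The flow of group-limited routing can be sketched in NumPy. This is illustrative only: the exact rule for scoring a group (here, the best biased expert score in the group) is an assumption and may differ from the kernel's implementation:

```python
import numpy as np

# Hypothetical sizes: 8 routed experts in 4 groups, route to 2 experts per token.
n_routed_experts, n_groups, topk_group, n_experts_per_tok = 8, 4, 2, 2
experts_per_group = n_routed_experts // n_groups
routed_scaling_factor = 1.0

rng = np.random.default_rng(0)
scores = rng.random((3, n_routed_experts))  # [num_tokens, n_routed_experts]
bias = np.zeros(n_routed_experts)

# 1. Score each group by its best biased expert score (one plausible choice).
biased = scores + bias
group_scores = biased.reshape(-1, n_groups, experts_per_group).max(axis=-1)

# 2. Keep only the topk_group best groups; mask out experts in other groups.
keep = np.argsort(group_scores, axis=-1)[:, -topk_group:]
mask = np.zeros_like(group_scores, dtype=bool)
np.put_along_axis(mask, keep, True, axis=-1)
expert_mask = np.repeat(mask, experts_per_group, axis=-1)
masked = np.where(expert_mask, biased, -np.inf)

# 3. Top-k experts among the surviving groups; weights from unbiased scores.
expert_indices = np.argsort(masked, axis=-1)[:, -n_experts_per_tok:]
expert_weights = np.take_along_axis(scores, expert_indices, axis=-1)
expert_weights /= expert_weights.sum(axis=-1, keepdims=True)  # norm_weights=True
expert_weights *= routed_scaling_factor
```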
## `needs_fp8_fnuz_conversion()` {#max.nn.legacy.kernels.needs_fp8_fnuz_conversion} > max.nn.legacy.kernels.needs\_fp8\_fnuz\_conversion() Check if we need to convert FP8 E4M3FN to FNUZ for AMD GPUs.
**Returns:**
True if running on AMD GPU with CDNA3 architecture, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
## `normalize_e4m3fn_to_e4m3fnuz()` {#max.nn.legacy.kernels.normalize_e4m3fn_to_e4m3fnuz} > max.nn.legacy.kernels.normalize\_e4m3fn\_to\_e4m3fnuz(weight, weight\_scale) Convert E4M3FN weights to E4M3FNUZ format for AMD GPUs. This conversion is necessary because AMD GPUs use the E4M3FNUZ format while NVIDIA GPUs use E4M3FN. The key differences are: 1. The bit pattern 10000000 (-128) represents zero in E4M3FN but NaN in E4M3FNUZ 2. For the same bit representation, E4M3FNUZ values are half of E4M3FN values
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight tensor in E4M3FN format. * weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight scale factor.
**Returns:**
Tuple of (converted\_weight, adjusted\_weight\_scale).
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
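A minimal NumPy sketch of the two adjustments described above, assuming the weight is given as a raw uint8 view of its E4M3FN bits (the helper name and dtype choices are illustrative, not the actual kernel):

```python
import numpy as np

def normalize_fn_to_fnuz_sketch(weight_bits: np.ndarray, weight_scale: float):
    """Illustrative sketch (not the real kernel) of the FN -> FNUZ adjustment."""
    out = weight_bits.copy()
    # Bit pattern 0b10000000 encodes zero in E4M3FN but NaN in E4M3FNUZ,
    # so remap it to positive zero before reinterpreting the bits as FNUZ.
    out[out == 0x80] = 0x00
    # For identical bits, an FNUZ value is half the FN value; doubling the
    # dequantization scale keeps (weight * scale) unchanged.
    return out, weight_scale * 2.0

bits = np.array([0x80, 0x3C], dtype=np.uint8)
new_bits, new_scale = normalize_fn_to_fnuz_sketch(bits, 0.5)
assert new_bits[0] == 0 and new_scale == 1.0
```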
## `quantize_dynamic_block_scaled_fp4()` {#max.nn.legacy.kernels.quantize_dynamic_block_scaled_fp4} > max.nn.legacy.kernels.quantize\_dynamic\_block\_scaled\_fp4(input, tensor\_sf, sf\_vector\_size=16, scales\_type=float8\_e4m3fn, out\_type=uint8) Dynamically quantize the input tensor to fp4-e2m1fn.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to quantize. Shape: [seq\_len, hidden\_size] * tensor\_sf ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [float](https://docs.python.org/3/library/functions.html#float)) – The tensor-wise scale factor (inverted as per quantization kernel requirement). * sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The block size for the scaling factors. * out\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the output tensor. * scales\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the scales tensor.
**Returns:**
The quantized tensor in \[seq\_len, hidden\_size // 2] layout and the scales in \[ceildiv(seq\_len, 128), ceildiv(hidden\_size, sf\_vector\_size \* 4), 32, 4, 4] layout.
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
## `quantize_dynamic_scaled_float8()` {#max.nn.legacy.kernels.quantize_dynamic_scaled_float8} > max.nn.legacy.kernels.quantize\_dynamic\_scaled\_float8(input, input\_scale\_spec, weight\_scale\_spec, scale\_ub=1200.0, group\_size\_or\_per\_token=-1, out\_type=float8\_e4m3fn, scales\_type=bfloat16) Dynamically quantize the input tensor to fp8.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to quantize. * scale\_ub ([float](https://docs.python.org/3/library/functions.html#float)) – The upper bound of the scale factor. * group\_size\_or\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The group size for quantization. When set to -1, the quantization is column-wise. * out\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the output tensor. * scales\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the scales tensor. * input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec)) * weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec))
**Returns:**
The quantized tensor and the scales.
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
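The per-token (row-wise) case can be sketched as follows. This is a simplified illustration, not the kernel: NumPy has no fp8 dtype, so the fp8 cast is emulated by clipping, and the reading of scale\_ub as a clamp on the per-row maximum is an assumption:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value in float8_e4m3fn

def quantize_per_token_sketch(x: np.ndarray, scale_ub: float = 1200.0):
    """Illustrative per-token dynamic fp8 quantization sketch."""
    # One scale per row so each row fills the fp8 range; clamp to scale_ub.
    amax = np.clip(np.abs(x).max(axis=-1, keepdims=True), 1e-12, scale_ub)
    scales = amax / FP8_E4M3_MAX
    # Emulate the fp8 cast by clipping to the representable range.
    q = np.clip(x / scales, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scales

x = np.array([[1.0, -2.0, 4.0]])
q, s = quantize_per_token_sketch(x)
assert np.allclose(q * s, x)  # exact here since no fp8 rounding is emulated
```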
## `quantize_static_scaled_float8()` {#max.nn.legacy.kernels.quantize_static_scaled_float8} > max.nn.legacy.kernels.quantize\_static\_scaled\_float8(x, scale, scale\_is\_inverted=True, out\_type=float8\_e4m3fn)
**Parameters:**
* x ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * scale\_is\_inverted ([bool](https://docs.python.org/3/library/functions.html#bool)) * out\_type ([DType](../../dtype.md#max.dtype.DType))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `rms_norm_key_cache()` {#max.nn.legacy.kernels.rms_norm_key_cache} > max.nn.legacy.kernels.rms\_norm\_key\_cache(kv\_params, kv\_collection, gamma, epsilon, layer\_idx, total\_seq\_len, input\_row\_offsets, weight\_offset, rms\_norm\_cols=None, multiply\_before\_cast=True, per\_head\_norm=True) This function applies RMSNorm to the \_new\_ entries in the KVCache. When per\_head\_norm=True (default), RMSNorm is applied separately to each head. In this mode, gamma should have size \[head\_dim] and normalization occurs across the head\_dim dimensions within each head. When per\_head\_norm=False, RMSNorm is applied per token across all heads. In this mode, gamma should have size \[n\_kv\_heads \* head\_dim] and normalization occurs across all dimensions for each token. The size of the gamma tensor determines how many dimensions will be normalized. If gamma’s size doesn’t match the expected size based on the per\_head\_norm setting, rms\_norm\_cols must be explicitly specified to confirm the intention to normalize only a subset of dimensions. Currently, the KVCacheT class itself isn’t aware of the new cache entries until the cache length is incremented, which happens after the model’s forward pass, so input\_row\_offsets is used for this bookkeeping.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * kv\_collection (PagedCacheValues) * gamma ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * epsilon ([float](https://docs.python.org/3/library/functions.html#float) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * total\_seq\_len ([Dim](../../graph/dim.md#max.graph.dim.Dim)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight\_offset ([float](https://docs.python.org/3/library/functions.html#float) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) * rms\_norm\_cols ([int](https://docs.python.org/3/library/functions.html#int) | None) * multiply\_before\_cast ([bool](https://docs.python.org/3/library/functions.html#bool)) * per\_head\_norm ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
None
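The difference between the two normalization modes can be sketched in NumPy (illustrative only; the real kernel operates in place on the cache's new K entries):

```python
import numpy as np

def rms_norm(x, gamma, epsilon=1e-6, weight_offset=0.0):
    # Normalize over the trailing axis, then scale by (gamma + weight_offset).
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + epsilon)
    return x / rms * (gamma + weight_offset)

n_kv_heads, head_dim = 2, 4
# Stand-in for the new K cache entries of 3 tokens.
k_new = np.random.default_rng(0).random((3, n_kv_heads, head_dim))

# per_head_norm=True: gamma has size [head_dim]; each head normalized alone.
out_per_head = rms_norm(k_new, np.ones(head_dim))

# per_head_norm=False: gamma has size [n_kv_heads * head_dim]; normalize the
# whole token vector across all heads at once.
flat = k_new.reshape(3, n_kv_heads * head_dim)
out_per_token = rms_norm(flat, np.ones(n_kv_heads * head_dim)).reshape(k_new.shape)
```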
## `scatter_nd_skip_oob_indices()` {#max.nn.legacy.kernels.scatter_nd_skip_oob_indices} > max.nn.legacy.kernels.scatter\_nd\_skip\_oob\_indices(input, updates, indices) Creates a new symbolic tensor where the updates are scattered into input at specified indices. This differs from scatter\_nd in that it handles oob indices by skipping the update for that index. Oob indices are those which fall outside of the range \[-dim, dim).
**Parameters:**
* input (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to. * updates (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input. 
* indices (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – A tensor of indices specifying where to write updates. Shape should be \[num\_updates, rank] for full indexing or \[num\_updates, k] for partial indexing where k < rank.
**Returns:**
A new symbolic tensor representing the result of the scatter\_nd operation.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
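The out-of-bounds handling can be illustrated with a 1-D NumPy sketch (the helper name is hypothetical; the real op works symbolically on graph tensors):

```python
import numpy as np

def scatter_nd_skip_oob_sketch(input, updates, indices):
    """1-D sketch: write updates at indices, silently skipping OOB ones.

    In-bounds means the index falls in [-dim, dim); negative indices wrap.
    """
    out = input.copy()
    dim = out.shape[0]
    for upd, idx in zip(updates, indices):
        if -dim <= idx < dim:
            out[idx] = upd
    return out

x = np.zeros(4)
out = scatter_nd_skip_oob_sketch(x, np.array([1.0, 2.0, 3.0]), np.array([1, -1, 9]))
# Index 9 is out of bounds and skipped; -1 wraps to the last element.
assert (out == np.array([0.0, 1.0, 0.0, 2.0])).all()
```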
## `scatter_set_constant()` {#max.nn.legacy.kernels.scatter_set_constant} > max.nn.legacy.kernels.scatter\_set\_constant(data, indices, fill\_val) Scatters values into a tensor at specified indices.
**Parameters:**
* data ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue) | HasBufferValue) * indices (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) * fill\_val ([float](https://docs.python.org/3/library/functions.html#float))
**Return type:**
None
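A NumPy sketch of the fill behavior, assuming each row of indices is a full \[row, col] coordinate (the helper name is illustrative; the real op mutates a BufferValue in the graph):

```python
import numpy as np

def scatter_set_constant_sketch(data, indices, fill_val):
    # Write fill_val at each [row, col] index pair, mutating data in place.
    data[tuple(indices.T)] = fill_val

x = np.zeros((2, 3))
scatter_set_constant_sketch(x, np.array([[0, 1], [1, 2]]), 7.0)
assert x[0, 1] == 7.0 and x[1, 2] == 7.0
```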
## `sgmv_kernel()` {#max.nn.legacy.kernels.sgmv_kernel} > max.nn.legacy.kernels.sgmv\_kernel(input, lora, lora\_ids, lora\_ranks, input\_row\_offsets, max\_lora\_seq\_len, lora\_end\_idx=None, bias=None) Performs the SGMV kernel for LoRA. This kernel is LoRA-agnostic, meaning it can perform either the LoRA A or the LoRA B matmul.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor. * lora ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA tensor. * lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – IDs of the LoRAs used for each sequence. * lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The ranks of the LoRAs in the batch. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The sequence offsets that use LoRA. * max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length of any given LoRA in the batch. * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – The LoRA bias.
## `sgmv_lora_kernel()` {#max.nn.legacy.kernels.sgmv_lora_kernel} > max.nn.legacy.kernels.sgmv\_lora\_kernel(input, lora\_a, lora\_b, lora\_ids, lora\_ranks, grouped\_row\_offsets, lora\_end\_idx, max\_lora\_seq\_len, bias=None) Computes the SGMV LoRA kernel for some number of LoRAs A and B given the input: out = Wx + xAB. SGMV can be explained as two independent kernels, where v is initialized to zeros and y is the output tensor: a shrink kernel, which projects the high-dimensional input to a low-rank tensor (SGMV-shrink: v += xA), and an expand kernel, which projects the low-rank tensor back to a high-dimensional tensor (SGMV-expand: y += vB).
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor * lora\_a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA tensor for A * lora\_b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA tensor for B * lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Ids of the LoRAs used for each sequence * lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The ranks of the LoRAs in the batch * grouped\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The grouped sequence offsets that use LoRA * max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length of any given LoRA in the batch * bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – The LoRA bias * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
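The shrink/expand math over grouped sequences can be sketched in NumPy. The weight layouts and loop structure here are illustrative assumptions, not the kernel's actual memory layout:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, rank = 8, 2
# Two sequences of 3 and 2 tokens, each using a different LoRA adapter.
x = rng.random((5, hidden))
grouped_row_offsets = np.array([0, 3, 5])
lora_ids = np.array([0, 1])
A = rng.random((2, rank, hidden))   # per-adapter shrink weights
B = rng.random((2, hidden, rank))   # per-adapter expand weights

y = np.zeros_like(x)
for g in range(len(lora_ids)):
    s, e = grouped_row_offsets[g], grouped_row_offsets[g + 1]
    a, b = A[lora_ids[g]], B[lora_ids[g]]
    v = x[s:e] @ a.T        # SGMV-shrink: v += xA
    y[s:e] += v @ b.T       # SGMV-expand: y += vB
```

The base projection Wx is then added to `y` to form the final output.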
## `sgmv_lora_qkv_shrink()` {#max.nn.legacy.kernels.sgmv_lora_qkv_shrink} > max.nn.legacy.kernels.sgmv\_lora\_qkv\_shrink(input, lora\_a, lora\_ids, lora\_grouped\_offsets, lora\_end\_idx, max\_lora\_seq\_len, max\_rank) LoRA shrink grouped matmul with planar Q/K/V output. Performs the LoRA ‘shrink’ operation for routed tokens using SGMV (segmented grouped matrix-vector multiplication). Computes \[M, K] @ \[G, 3\*rank, K]^T per active LoRA adapter, then permutes the flat \[M, 3\*rank] result into a planar layout \[3, M, rank] representing separate Q, K, V projections.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Routed activation matrix with shape (M, K), where M is the total number of tokens and K is the hidden dimension. * lora\_a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Shrink weights for all LoRA adapters, shape (G, 3\*rank, K) where G is the number of adapters and rank is the LoRA rank. * lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Expert/adapter indices for each active group, shape (num\_active,). Values in range \[0, G). May use -1 to indicate inactive slots. * lora\_grouped\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Inclusive prefix sums of tokens per active adapter, shape (num\_active + 1,). Defines per-adapter \[start, end) ranges in input. Must be non-decreasing with offsets\[0] == 0. * max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – Upper bound on tokens for any active adapter. Used for kernel tuning and memory allocation. * max\_rank ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum LoRA rank, determines output shape. * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Returns:**
Output tensor with planar Q/K/V layout, shape (3, M, max\_rank).
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
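The final permute from the flat \[M, 3\*rank] shrink result to the planar \[3, M, rank] layout can be sketched as follows; whether the flat layout interleaves as \[q|k|v] per row is an assumption for illustration:

```python
import numpy as np

M, rank = 5, 2
flat = np.arange(M * 3 * rank).reshape(M, 3 * rank)  # [M, 3*rank] shrink output
# Split each row into its Q, K, V chunks, then move that axis to the front,
# giving three planes of shape [M, rank].
planar = flat.reshape(M, 3, rank).transpose(1, 0, 2)
assert planar.shape == (3, M, rank)
```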
## `sgmv_qkv_lora_kernel()` {#max.nn.legacy.kernels.sgmv_qkv_lora_kernel} > max.nn.legacy.kernels.sgmv\_qkv\_lora\_kernel(input, lora\_a, lora\_b\_q, lora\_b\_kv, lora\_ids, lora\_ranks, input\_row\_offsets, lora\_grouped\_offsets, lora\_end\_idx, batch\_seq\_len, lora\_ids\_kv, lora\_grouped\_offsets\_kv, kv\_collection, kv\_params, layer\_idx, max\_lora\_seq\_len, max\_rank, bias=None) Computes the SGMV QKV LoRA kernel for Q, K, V projections with LoRA.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor. * lora\_a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA A tensor. * lora\_b\_q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA B tensor for Q projection. * lora\_b\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA B tensor for K and V projections (stacked). * lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – IDs of the LoRAs used for each sequence. * lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The ranks of the LoRAs in the batch. * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The sequence offsets that use LoRA. * lora\_grouped\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Grouped offsets for LoRA sequences. * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – End index of LoRA tokens in the batch. * batch\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Total sequence length of the batch. * lora\_ids\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – LoRA IDs for KV projections (with offset for V portion). * lora\_grouped\_offsets\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Grouped offsets for KV LoRA sequences. * kv\_collection (PagedCacheValues) – The KV cache. * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – The KV params. * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The layer index to retrieve the KV cache. * max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length of any given LoRA in the batch. 
* max\_rank ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum rank for the LoRAs. * bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional LoRA bias.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `sleep()` {#max.nn.legacy.kernels.sleep} > max.nn.legacy.kernels.sleep(duration\_sec, device\_ref) Sleep for the given duration in seconds. This kernel is supported on CPUs and GPUs. However, the timing may be completely inaccurate on AMD GPUs due to limitations of the current time.sleep(…) implementation.
**Parameters:**
* duration\_sec ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – The duration to sleep in seconds. * device\_ref ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef))
**Return type:**
None
## `sliced_add()` {#max.nn.legacy.kernels.sliced_add} > max.nn.legacy.kernels.sliced\_add(x, y, lora\_end\_idx) Adds tensors x and y element-wise for rows < lora\_end\_idx, otherwise copies x. This is used for LoRA where only some sequences have LoRA applied. For rows in \[0, lora\_end\_idx): c = x + y For rows in \[lora\_end\_idx, batch\_seq\_len): c = x
**Parameters:**
* x ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – First input tensor. * y ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Second input tensor. * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – End index of LoRA token portion (rows to apply add).
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
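The row-partitioned behavior is simple enough to state as a NumPy sketch (the helper name is illustrative):

```python
import numpy as np

def sliced_add_sketch(x, y, lora_end_idx):
    # Rows [0, lora_end_idx) get x + y; remaining rows pass x through unchanged.
    out = x.copy()
    out[:lora_end_idx] += y[:lora_end_idx]
    return out

x = np.ones((4, 2))
y = np.full((4, 2), 10.0)
out = sliced_add_sketch(x, y, lora_end_idx=2)
assert (out[:2] == 11.0).all() and (out[2:] == 1.0).all()
```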
## `spatial_merge()` {#max.nn.legacy.kernels.spatial_merge} > max.nn.legacy.kernels.spatial\_merge(input, grid\_thw, hidden\_size, merge\_size) Performs spatial merge operation on ragged input tensors. This operation merges spatial dimensions of input patches according to the grid dimensions specified in grid\_thw.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor of shape \[total\_patches\_in\_grid, hidden\_size] * grid\_thw ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Grid dimensions tensor of shape \[batch\_size, 3] containing \[t, h, w] for each batch item, where: * t: temporal/frame dimension * h: height dimension * w: width dimension * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Hidden dimension size * merge\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Size of spatial merge blocks (typically 2)
**Returns:**
Output tensor of shape \[total\_patches\_in\_grid, hidden\_size]
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
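One plausible reading of the merge, consistent with the stated output shape, is a row reordering that makes each merge\_size × merge\_size block of neighboring patches contiguous. This NumPy sketch is an interpretation for intuition only; the actual kernel may combine features differently:

```python
import numpy as np

# One image: t=1, h=4, w=4 patch grid, hidden=3, merge_size=2.
t, h, w, hidden, m = 1, 4, 4, 3, 2
x = np.arange(t * h * w * hidden, dtype=np.float32).reshape(t * h * w, hidden)

# Reorder rows so the m*m patches of each spatial block become contiguous,
# keeping the [total_patches, hidden] shape (a downstream reshape can then
# concatenate each block's rows into one merged feature).
idx = np.arange(t * h * w).reshape(t, h // m, m, w // m, m)
idx = idx.transpose(0, 1, 3, 2, 4).reshape(-1)
out = x[idx]
assert out.shape == x.shape
```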
## `swish_glu()` {#max.nn.legacy.kernels.swish_glu} > max.nn.legacy.kernels.swish\_glu(a, b0, b1) Computes swish(a @ b0.t()) \* (a @ b1.t())
**Parameters:**
* a (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) * b0 (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) * b1 (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | 
[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `topk_fused_sampling()` {#max.nn.legacy.kernels.topk_fused_sampling} > max.nn.legacy.kernels.topk\_fused\_sampling(logits, top\_k, \*, temperature=1.0, max\_k=None, min\_top\_p=None, top\_p=1.0, seed=0) Performs top-k sampling with temperature scaling.
**Parameters:**
* logits ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input logits tensor of shape \[batch\_size, vocab\_size]. * top\_k (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – Number of top tokens to consider for sampling. Can be a scalar (which will be expanded to batch\_size) or a tensor of shape \[batch\_size]. * temperature (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – Temperature for scaling logits before sampling. 
* max\_k (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) – Maximum value of k across the batch. Required when top\_k is a tensor. * top\_p (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – Top-p (nucleus) sampling threshold. Can be a scalar or tensor. 
* seed (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – Seed for the random number generator. Can be a scalar or tensor. * min\_top\_p (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None)
**Returns:**
Sampled tokens tensor of shape \[batch\_size, 1].
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input validation fails.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
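Since the fused kernel runs on-device, its sampling semantics can be hard to picture from the signature alone. The following NumPy sketch is illustrative only (the name `reference_topk_sample` is not part of the MAX API, and the real kernel also applies top-p filtering): scale logits by temperature, keep the k largest entries, softmax over just those, and draw one token per row.

```python
import numpy as np

def reference_topk_sample(logits, top_k, temperature=1.0, rng=None):
    """Sample one token per row from the top_k highest logits."""
    rng = np.random.default_rng(0) if rng is None else rng
    out = []
    for row in logits:
        scaled = row / temperature                 # temperature scaling
        top = np.argsort(scaled)[-top_k:]          # indices of the k largest logits
        probs = np.exp(scaled[top] - scaled[top].max())
        probs /= probs.sum()                       # softmax over the top-k only
        out.append(rng.choice(top, p=probs))
    return np.array(out).reshape(-1, 1)            # shape [batch_size, 1]

logits = np.array([[0.1, 5.0, 0.2, 0.3]])
token = reference_topk_sample(logits, top_k=1)     # greedy when k == 1
```

With `top_k=1` this reduces to argmax; larger `k` trades determinism for diversity.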
## `unfused_qkv_ragged_matmul_gguf_quantized()` {#max.nn.legacy.kernels.unfused_qkv_ragged_matmul_gguf_quantized} > max.nn.legacy.kernels.unfused\_qkv\_ragged\_matmul\_gguf\_quantized(kv\_params, input, input\_row\_offsets, n\_heads, q\_weight, k\_weight, v\_weight, quantization\_encoding\_q, quantization\_encoding\_k, quantization\_encoding\_v, kv\_collection, layer\_idx) Computes query, key, and value projections with ragged input and quantized weight matrices. A quantization\_config must be provided. input and input\_row\_offsets are used together to implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in input.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * q\_weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * k\_weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * v\_weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * quantization\_encoding\_q ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding)) * quantization\_encoding\_k ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding)) * quantization\_encoding\_v ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding)) * kv\_collection (PagedCacheValues) * layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
## `update_frequency_data()` {#max.nn.legacy.kernels.update_frequency_data} > max.nn.legacy.kernels.update\_frequency\_data(frequency\_data, frequency\_offsets, tokens) Updates the frequency data.
**Parameters:**
* frequency\_data ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – 2d tensor of shape \[unique\_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token. * frequency\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – 1d tensor of shape \[batch\_size + 1], indicating start of each sequence’s data. * tokens ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The tokens to update the frequency data with.
**Return type:**
None
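A plain-Python sketch of the update described above may help. The real op mutates a device `BufferValue` in place; the function below is illustrative, and the convention that `-1` marks an unused slot is an assumption for the example, not documented behavior.

```python
def update_frequency_data(frequency_data, frequency_offsets, tokens):
    for seq, token in enumerate(tokens):
        start, end = frequency_offsets[seq], frequency_offsets[seq + 1]
        for i in range(start, end):          # this sequence's slice of the table
            if frequency_data[i][0] == token:
                frequency_data[i][1] += 1    # token already tracked: bump its count
                break
            if frequency_data[i][0] == -1:   # assumed sentinel for an empty slot
                frequency_data[i] = [token, 1]
                break

# One sequence with room for 3 unique tokens; token 7 appears again.
table = [[7, 2], [-1, 0], [-1, 0]]
update_frequency_data(table, [0, 3], [7])
```

Frequency tables like this feed repetition and presence penalties during sampling.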
--- ## cache_params ## `KVCacheParamInterface` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface} > class max.nn.legacy.kv\_cache.cache\_params.KVCacheParamInterface(\*args, \*\*kwargs) Interface for KV cache parameters. ### `bytes_per_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.bytes_per_block} > property bytes\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) Number of bytes per cache block. ### `cache_strategy` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.cache_strategy} > cache\_strategy: [KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy) ### `data_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.data_parallel_degree} > data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) ### `get_symbolic_inputs()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.get_symbolic_inputs} > get\_symbolic\_inputs() Returns the symbolic inputs for the KV cache.
**Return type:**
InputSymbolInterface
### `n_devices` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.n_devices} > n\_devices: [int](https://docs.python.org/3/library/functions.html#int) ### `page_size` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.page_size} > page\_size: [int](https://docs.python.org/3/library/functions.html#int) ## `KVCacheParams` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams} > class max.nn.legacy.kv\_cache.cache\_params.KVCacheParams(dtype, n\_kv\_heads, head\_dim, num\_layers, devices, enable\_prefix\_caching=False, enable\_kvcache\_swapping\_to\_host=False, host\_kvcache\_swap\_space\_gb=None, cache\_strategy=KVCacheStrategy.PAGED, page\_size=128, is\_mla=False, data\_parallel\_degree=1, n\_kv\_heads\_per\_device=0, kvcache\_quant\_config=None) Configuration parameters for key-value cache management in transformer models. This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, memory management, and cache strategy.
**Parameters:**
* dtype ([DType](../../../dtype.md#max.dtype.DType)) * n\_kv\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * num\_layers ([int](https://docs.python.org/3/library/functions.html#int)) * devices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)]) * enable\_prefix\_caching ([bool](https://docs.python.org/3/library/functions.html#bool)) * enable\_kvcache\_swapping\_to\_host ([bool](https://docs.python.org/3/library/functions.html#bool)) * host\_kvcache\_swap\_space\_gb ([float](https://docs.python.org/3/library/functions.html#float) | None) * cache\_strategy ([KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)) * page\_size ([int](https://docs.python.org/3/library/functions.html#int)) * is\_mla ([bool](https://docs.python.org/3/library/functions.html#bool)) * data\_parallel\_degree ([int](https://docs.python.org/3/library/functions.html#int)) * n\_kv\_heads\_per\_device ([int](https://docs.python.org/3/library/functions.html#int)) * kvcache\_quant\_config ([KVCacheQuantizationConfig](#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig) | None)
### `bytes_per_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.bytes_per_block} > property bytes\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) Returns the number of bytes per cache block. When TP>1, each block is sharded across the devices in the tensor parallel group. This method returns the total memory needed to store a block across these devices. Includes memory needed for scales if quantization is enabled.
**Returns:**
The number of bytes per cache block.
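The per-block arithmetic this property encapsulates can be sketched as follows. The function below is a back-of-the-envelope approximation, not the real implementation: it ignores the quantization scales and tensor-parallel sharding that `bytes_per_block` accounts for, and the example configuration is illustrative.

```python
def approx_bytes_per_block(page_size, n_kv_heads, head_dim, num_layers,
                           dtype_bytes):
    # 2x for the separate key and value tensors
    return 2 * page_size * n_kv_heads * head_dim * num_layers * dtype_bytes

# e.g. a Llama-3-8B-like config in bfloat16 (2 bytes per element):
block_bytes = approx_bytes_per_block(page_size=128, n_kv_heads=8, head_dim=128,
                                     num_layers=32, dtype_bytes=2)
# 16 MiB per 128-token block
```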
### `cache_strategy` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.cache_strategy} > cache\_strategy: [KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy) = 'paged' Strategy to use for managing the KV cache. ### `compute_num_host_blocks()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.compute_num_host_blocks} > compute\_num\_host\_blocks() Computes the number of blocks that can be allocated to the host.
**Returns:**
The number of blocks that can be allocated to the host.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `copy_as_dp_1()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.copy_as_dp_1} > copy\_as\_dp\_1() Creates a copy of the KVCacheParams with data parallelism disabled. This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data\_parallel\_degree=1). The number of devices is divided by the current data parallel degree.
**Returns:**
A new KVCacheParams instance with data\_parallel\_degree set to 1.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If n\_devices is not evenly divisible by data\_parallel\_degree.
**Return type:**
[KVCacheParams](#max.nn.legacy.kv_cache.cache_params.KVCacheParams)
### `data_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.data_parallel_degree} > data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) = 1 Degree of data parallelism. Must be 1 or equal to n\_devices (DP+TP not yet supported). ### `devices` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.devices} > devices: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)] Devices to use for the KV cache. ### `dtype` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.dtype} > dtype: [DType](../../../dtype.md#max.dtype.DType) Data type for storing key and value tensors in the cache. ### `dtype_shorthand` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.dtype_shorthand} > property dtype\_shorthand: [str](https://docs.python.org/3/library/stdtypes.html#str) Returns a shorthand textual representation of the data type.
**Returns:**
“bf16” for bfloat16 dtype, “f32” otherwise.
### `enable_kvcache_swapping_to_host` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.enable_kvcache_swapping_to_host} > enable\_kvcache\_swapping\_to\_host: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether to enable swapping of KV cache blocks to host memory when device memory is full. ### `enable_prefix_caching` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.enable_prefix_caching} > enable\_prefix\_caching: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether to enable prefix caching for efficient reuse of common prompt prefixes. ### `get_symbolic_inputs()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.get_symbolic_inputs} > get\_symbolic\_inputs() Computes the symbolic inputs for the KV cache. This method returns a list of PagedCacheInputSymbols for each replica. This is used when constructing the model graph.
**Returns:**
The symbolic inputs for the KV cache.
**Return type:**
PagedCacheInputSymbolsByReplica
### `head_dim` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.head_dim} > head\_dim: [int](https://docs.python.org/3/library/functions.html#int) Dimensionality of each attention head. ### `host_kvcache_swap_space_gb` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.host_kvcache_swap_space_gb} > host\_kvcache\_swap\_space\_gb: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled. ### `is_mla` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.is_mla} > is\_mla: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether the model uses Multi-Latent Attention (MLA) architecture. ### `kvcache_quant_config` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.kvcache_quant_config} > kvcache\_quant\_config: [KVCacheQuantizationConfig](#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig) | [None](https://docs.python.org/3/library/constants.html#None) = None KVCache quantization config. Currently only FP8 quantization supported. ### `n_devices` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.n_devices} > property n\_devices: [int](https://docs.python.org/3/library/functions.html#int) Returns the number of devices.
**Returns:**
The number of devices.
### `n_kv_heads` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.n_kv_heads} > n\_kv\_heads: [int](https://docs.python.org/3/library/functions.html#int) Total number of key-value attention heads across all devices. ### `n_kv_heads_per_device` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.n_kv_heads_per_device} > n\_kv\_heads\_per\_device: [int](https://docs.python.org/3/library/functions.html#int) = 0 Number of KV heads allocated to each device. Computed automatically in \_\_post\_init\_\_. ### `num_layers` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.num_layers} > num\_layers: [int](https://docs.python.org/3/library/functions.html#int) Number of layers in the model. ### `page_size` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.page_size} > page\_size: [int](https://docs.python.org/3/library/functions.html#int) = 128 Number of tokens per page (block) when using the paged cache strategy. This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration. Current constraints: the page size must be a multiple of 128 and at least 128. Required when `cache_strategy` is `KVCacheStrategy.PAGED`. ### `quantized_kv_cache` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.quantized_kv_cache} > property quantized\_kv\_cache: [bool](https://docs.python.org/3/library/functions.html#bool) ### `shape_per_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.shape_per_block} > property shape\_per\_block: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Returns the shape of each cache block.
**Returns:**
The shape of the cache block.
### `shape_per_scale_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.shape_per_scale_block} > property shape\_per\_scale\_block: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Returns the shape of each scale block used for KVCache quantization.
**Returns:**
The shape of the KVCache quantization scales block.
### `tensor_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.tensor_parallel_degree} > property tensor\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) Returns the tensor parallel degree.
**Returns:**
The tensor parallel degree.
## `KVCacheQuantizationConfig` {#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig} > class max.nn.legacy.kv\_cache.cache\_params.KVCacheQuantizationConfig(scale\_dtype=float32, quantization\_granularity=128) Configuration for KVCache quantization. Currently only FP8 Quantization is supported.
**Parameters:**
* scale\_dtype ([DType](../../../dtype.md#max.dtype.DType)) * quantization\_granularity ([int](https://docs.python.org/3/library/functions.html#int))
### `quantization_granularity` {#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig.quantization_granularity} > quantization\_granularity: [int](https://docs.python.org/3/library/functions.html#int) = 128 Block size used for KVCache quantization along the head dimension (e.g. 128). ### `scale_dtype` {#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig.scale_dtype} > scale\_dtype: [DType](../../../dtype.md#max.dtype.DType) = float32 Data type of quantization scales, if quantization is enabled. ## `KVCacheStrategy` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy} > class max.nn.legacy.kv\_cache.cache\_params.KVCacheStrategy(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Enumeration of supported KV cache strategies for attention mechanisms. This enum defines the different strategies for managing key-value caches in transformer models during inference. ### `MODEL_DEFAULT` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.MODEL_DEFAULT} > MODEL\_DEFAULT = 'model\_default' Use the model’s default caching strategy. ### `PAGED` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.PAGED} > PAGED = 'paged' Use paged attention for efficient memory management. ### `kernel_substring()` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.kernel_substring} > kernel\_substring() Returns the common substring included in the kernel name for this caching strategy.
**Returns:**
The string representation of the cache strategy value.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `uses_opaque()` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.uses_opaque} > uses\_opaque() Determines if this cache strategy uses opaque cache implementations.
**Returns:**
True if the strategy uses opaque caching, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
## `MultiKVCacheParams` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams} > class max.nn.legacy.kv\_cache.cache\_params.MultiKVCacheParams(params, cache\_strategy, page\_size, data\_parallel\_degree, n\_devices) Aggregates multiple KV cache parameter sets. This class implements KVCacheParamInterface by aggregating multiple KVCacheParamInterface instances. Useful for models with multiple distinct KV caches (e.g., different cache configurations for different layers).
**Parameters:**
* params ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)]) * cache\_strategy ([KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)) * page\_size ([int](https://docs.python.org/3/library/functions.html#int)) * data\_parallel\_degree ([int](https://docs.python.org/3/library/functions.html#int)) * n\_devices ([int](https://docs.python.org/3/library/functions.html#int))
### `bytes_per_block` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.bytes_per_block} > property bytes\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) Total bytes per block across all KV caches. Since all caches allocate memory for the same sequence, the total memory cost per block is the sum across all param sets. ### `cache_strategy` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.cache_strategy} > cache\_strategy: [KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy) ### `data_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.data_parallel_degree} > data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) ### `from_params()` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.from_params} > classmethod from\_params(\*params)
**Parameters:**
params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Return type:**
[MultiKVCacheParams](#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams)
### `get_symbolic_inputs()` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.get_symbolic_inputs} > get\_symbolic\_inputs() Returns the symbolic inputs for the KV cache.
**Return type:**
MultiKVCacheInputSymbols
### `n_devices` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.n_devices} > n\_devices: [int](https://docs.python.org/3/library/functions.html#int) ### `page_size` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.page_size} > page\_size: [int](https://docs.python.org/3/library/functions.html#int) ### `params` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.params} > params: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)] List of KV cache parameter sets to aggregate. ## `compute_max_seq_len_fitting_in_cache()` {#max.nn.legacy.kv_cache.cache_params.compute_max_seq_len_fitting_in_cache} > max.nn.legacy.kv\_cache.cache\_params.compute\_max\_seq\_len\_fitting\_in\_cache(params, available\_cache\_memory) Computes the maximum sequence length that can fit in the available memory.
**Parameters:**
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – The amount of cache memory available across all devices. * params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Returns:**
The maximum sequence length that can fit in the available cache memory.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
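The fitting computation amounts to counting how many whole blocks fit in the budget, times tokens per block. The standalone sketch below is a simplification under stated assumptions: the real function derives the block size and page size from `params`, and the 1 GiB / 16 MiB figures are illustrative.

```python
def max_seq_len_fitting(available_cache_memory, bytes_per_block, page_size):
    num_blocks = available_cache_memory // bytes_per_block  # whole blocks only
    return num_blocks * page_size  # tokens a single sequence can occupy

# 1 GiB budget with 16 MiB blocks of 128 tokens each -> 64 blocks -> 8192 tokens
max_len = max_seq_len_fitting(1 << 30, 16 << 20, 128)
```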
## `compute_num_device_blocks()` {#max.nn.legacy.kv_cache.cache_params.compute_num_device_blocks} > max.nn.legacy.kv\_cache.cache\_params.compute\_num\_device\_blocks(params, available\_cache\_memory, max\_batch\_size, max\_seq\_len) Computes the number of blocks that can be allocated based on the available cache memory. The number of blocks returned is for a single replica. Each replica will have the same number of blocks.
**Parameters:**
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – The amount of cache memory available across all devices. * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int) | None) – The maximum batch size, or None. * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int) | None) – The maximum sequence length, or None. * params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Returns:**
The number of blocks that can be allocated for a single replica.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
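The block budget per replica is memory-bound, but when both `max_batch_size` and `max_seq_len` are given it can also be capped by what the workload could ever reference. This illustrative sketch (simplified relative to the real function, which works through `KVCacheParamInterface`) shows that interplay:

```python
def approx_num_device_blocks(available_cache_memory, bytes_per_block,
                             page_size, max_batch_size=None, max_seq_len=None):
    num_blocks = available_cache_memory // bytes_per_block  # memory-bound count
    if max_batch_size is not None and max_seq_len is not None:
        # never allocate more blocks than the workload can reference
        blocks_per_seq = -(-max_seq_len // page_size)  # ceiling division
        num_blocks = min(num_blocks, max_batch_size * blocks_per_seq)
    return num_blocks

# 64 blocks fit in memory, but 4 sequences x 8 blocks each only need 32
n = approx_num_device_blocks(1 << 30, 16 << 20, 128,
                             max_batch_size=4, max_seq_len=1024)
```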
## `estimated_memory_size()` {#max.nn.legacy.kv_cache.cache_params.estimated_memory_size} > max.nn.legacy.kv\_cache.cache\_params.estimated\_memory\_size(params, available\_cache\_memory, max\_batch\_size, max\_seq\_len) Computes the estimated memory size of the KV cache used by all replicas.
**Parameters:**
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – The amount of cache memory available across all devices. * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum batch size. * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length. * params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Returns:**
The estimated memory usage of the KV cache in bytes.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
--- ## kv_cache Legacy key-value cache management for efficient attention computation. ## Modules * [`cache_params`](/max/api/python/nn/legacy/kv_cache/cache_params): Configuration parameters for KV cache. * [`manager`](/max/api/python/nn/legacy/kv_cache/manager): KV cache manager implementation. --- ## manager --- ## layer ## `Layer` {#max.nn.legacy.layer.Layer} > class max.nn.legacy.layer.Layer :::caution Deprecated Deprecated since version 25.2. ::: Base class for neural network components. Use [`Module`](#max.nn.legacy.layer.Module) instead. Provides functionality for adding hooks to the call function of each layer to support testing, debugging, or profiling. ## `LayerList` {#max.nn.legacy.layer.LayerList} > class max.nn.legacy.layer.LayerList(layers) Stores a list of layers. Can be used as a regular Python list.
**Parameters:**
layers ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Layer](#max.nn.legacy.layer.Layer)])
### `append()` {#max.nn.legacy.layer.LayerList.append} > append(layer)
**Parameters:**
layer ([Layer](#max.nn.legacy.layer.Layer))
**Return type:**
None
### `extend()` {#max.nn.legacy.layer.LayerList.extend} > extend(layer)
**Parameters:**
layer ([Layer](#max.nn.legacy.layer.Layer))
**Return type:**
None
### `insert()` {#max.nn.legacy.layer.LayerList.insert} > insert(i, layer)
**Parameters:**
* i ([int](https://docs.python.org/3/library/functions.html#int)) * layer ([Layer](#max.nn.legacy.layer.Layer))
**Return type:**
None
### `sublayers` {#max.nn.legacy.layer.LayerList.sublayers} > property sublayers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.legacy.layer.Module)] ## `Module` {#max.nn.legacy.layer.Module} > class max.nn.legacy.layer.Module Base class for model components with weight management. Provides functionality to create custom layers and construct networks with automatic weight tracking. The following example uses the [`Module`](#max.nn.legacy.layer.Module) class to create custom layers and build a neural network:

```python
from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dims, out_dims), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight.T

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Buffer([5, 10]), ...}
```

Constructing a graph without [`Module`](#max.nn.legacy.layer.Module) can result in name collisions among the weights (in this example, there would be three weights named `"weight"`). With [`Module`](#max.nn.legacy.layer.Module), you can use [`state_dict()`](#max.nn.legacy.layer.Module.state_dict) or [`load_state_dict()`](#max.nn.legacy.layer.Module.load_state_dict) to initialize or set the weight values, and finalize the weight names to be unique within the model. ### `build_subgraph()` {#max.nn.legacy.layer.Module.build_subgraph} > build\_subgraph(name, input\_types, weight\_prefix='') Builds a subgraph for this module. This method creates a subgraph that encapsulates the module’s logic, handling input types, weights, and creating a graph with the module’s computation.
Once the subgraph is built, it can be called using the `ops.call` op.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name of the subgraph to create. * input\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Type](../../graph/type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Type](../../graph/type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – A list of input types for the subgraph. Each element can be either a single `Type` or a list of `Type` objects. * weight\_prefix ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Optional prefix for weight names in the subgraph. If provided, weights with names starting with this prefix will have their names modified by removing the prefix and will be marked as placeholders.
**Returns:**
The created subgraph containing the module’s computation.
**Return type:**
`Graph`
:::note Note * Placeholder weights will require the `prefix` attribute of `ops.call` to be set. ::: ### `layer_weights` {#max.nn.legacy.layer.Module.layer_weights} > property layer\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Weight](../../graph/Weight.md#max.graph.Weight)] ### `load_state_dict()` {#max.nn.legacy.layer.Module.load_state_dict} > load\_state\_dict(state\_dict, \*, override\_quantization\_encoding=False, weight\_alignment=None, strict=True) Sets the values of all weights in this model.
**Parameters:**
* state\_dict ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../../driver.md#max.driver.DLPackArray) | [WeightData](../../graph/weights.md#max.graph.weights.WeightData)]) – A map from weight name to a numpy array or [`max.driver.Buffer`](../../driver.md#max.driver.Buffer). * override\_quantization\_encoding ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to override the weight quantization based on the loaded value. * weight\_alignment ([int](https://docs.python.org/3/library/functions.html#int) | None) – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state\_dict must be aligned by the default dtype alignment. * strict ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, raises an error if any weights required by the Module are missing from state\_dict, or if any keys in state\_dict were not used by the Module. If False, both missing and unexpected keys are tolerated and reported only via return values/logging by callers.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If strict is True and any required weight is missing from state\_dict, or if state\_dict contains keys not used by the Module.
**Return type:**
None
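The strict-mode validation described above boils down to comparing the key sets of the module's weights and the provided `state_dict`. A plain-dict sketch of that check (`check_strict` is a hypothetical helper for illustration, not part of the MAX API; the real method also copies values into `Weight` objects):

```python
def check_strict(required, state_dict):
    """Raise if state_dict's keys don't exactly match the required weight names."""
    missing = set(required) - set(state_dict)      # weights the Module needs
    unexpected = set(state_dict) - set(required)   # keys the Module never uses
    if missing or unexpected:
        raise ValueError(
            f"missing={sorted(missing)} unexpected={sorted(unexpected)}")

# Passes silently: every required weight is present and nothing extra.
check_strict(["up.weight", "down.weight"],
             {"up.weight": 0, "down.weight": 0})
```

With `strict=False`, the analogous check would merely record `missing` and `unexpected` instead of raising.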
### `raw_state_dict()` {#max.nn.legacy.layer.Module.raw_state_dict} > raw\_state\_dict() Returns all weights objects in the model. Unlike [`state_dict`](#max.nn.legacy.layer.Module.state_dict), this returns [`max.graph.Weight`](../../graph/Weight.md#max.graph.Weight) objects instead of the assigned values. Some parameters inside the `Weight` can be configured before a graph is built. Do not change these attributes after building a graph: * [`align`](../../graph/Weight.md#max.graph.Weight.align) * [`dtype`](../../graph/Weight.md#max.graph.Weight.dtype) * [`quantization_encoding`](../../graph/Weight.md#max.graph.Weight.quantization_encoding) * [`shape`](../../graph/Weight.md#max.graph.Weight.shape)
**Returns:**
Map from weight name to the [`max.graph.Weight`](../../graph/Weight.md#max.graph.Weight) object.
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Weight](../../graph/Weight.md#max.graph.Weight)]
### `set_shared_weight()` {#max.nn.legacy.layer.Module.set_shared_weight} > set\_shared\_weight(name, weight)
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * weight ([Weight](../../graph/Weight.md#max.graph.Weight))
**Return type:**
None
### `state_dict()` {#max.nn.legacy.layer.Module.state_dict} > state\_dict(auto\_initialize=True) Returns values of all weights in the model. The values returned are the same as the values set in [`load_state_dict`](#max.nn.legacy.layer.Module.load_state_dict). If [`load_state_dict`](#max.nn.legacy.layer.Module.load_state_dict) has not been called and none of the weights have values, then they are initialized to zero.
**Parameters:**
auto\_initialize ([bool](https://docs.python.org/3/library/functions.html#bool)) – Determines whether to initialize weights to zero if the weight value has not been loaded. If this is False, a ValueError is raised if an uninitialized weight is found.
**Returns:**
Map from weight name to the weight value (can be numpy array or [`max.driver.Buffer`](../../driver.md#max.driver.Buffer)).
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../../driver.md#max.driver.DLPackArray)]
### `sublayers` {#max.nn.legacy.layer.Module.sublayers} > property sublayers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.legacy.layer.Module)] ## `Shardable` {#max.nn.legacy.layer.Shardable} > class max.nn.legacy.layer.Shardable(\*args, \*\*kwargs) Protocol for objects that support sharding across multiple devices. This protocol defines the interface that all shardable components (like Linear layers and Weight objects) must implement to participate in distributed computation. ### `shard()` {#max.nn.legacy.layer.Shardable.shard} > shard(devices) Creates sharded views of this object, one for each of the given devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – The devices where the shards should reside.
**Returns:**
A sequence of sharded instances of this object.
**Return type:**
[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Self](https://docs.python.org/3/library/typing.html#typing.Self)]
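The `Shardable` protocol above can be illustrated with a toy implementation in plain Python (the `ToyWeight` class and string device names are hypothetical, used only to show the shape of the interface):

```python
from typing import Protocol, Sequence


class Shardable(Protocol):
    """Minimal sketch of the protocol: shard() returns one view per device."""

    def shard(self, devices: Sequence[str]) -> Sequence["Shardable"]: ...


class ToyWeight:
    """Toy weight that splits its values evenly across devices."""

    def __init__(self, values, device="cpu"):
        self.values = values
        self.device = device

    def shard(self, devices):
        n = len(devices)
        step = len(self.values) // n
        return [
            ToyWeight(self.values[i * step:(i + 1) * step], dev)
            for i, dev in enumerate(devices)
        ]


shards = ToyWeight(list(range(8))).shard(["gpu:0", "gpu:1"])
```

Real implementers such as `Linear` and `Weight` follow the same pattern: one call returns a sequence of per-device shards.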
### `sharding_strategy` {#max.nn.legacy.layer.Shardable.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Gets the weight sharding strategy. ## `add_layer_hook()` {#max.nn.legacy.layer.add_layer_hook} > max.nn.legacy.layer.add\_layer\_hook(fn) Adds a hook to call a function after each layer’s `__call__`. The function will be passed four inputs: * layer * input\_args * input\_kwargs * outputs The function can either return None or new outputs that will replace the layer returned outputs. Note that input and outputs contain graph Values, which show limited information (like [`shape`](../../graph/TensorValue.md#max.graph.TensorValue.shape) and [`dtype`](../../graph/TensorValue.md#max.graph.TensorValue.dtype)). You can still see the computed values if you include the Value in the `graph.ops.output` op, or call `graph.ops.print`. Example of printing debug inputs: ```python def print_info(layer, args, kwargs, outputs): print("Layer:", type(layer).__name__) print("Input args:", args) print("Input kwargs:", kwargs) print("Outputs:", outputs) return outputs add_layer_hook(print_info) ```
**Parameters:**
fn ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[\[[Layer](#max.nn.legacy.layer.Layer), [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)], [Any](https://docs.python.org/3/library/typing.html#typing.Any)], [Any](https://docs.python.org/3/library/typing.html#typing.Any)])
**Return type:**
None
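The hook dispatch described above (run each hook after a layer call, and let a non-`None` return value replace the outputs) can be sketched in plain Python; this is an illustrative toy, not the MAX implementation:

```python
# Global hook registry, as in add_layer_hook / clear_hooks above.
_hooks = []


def add_layer_hook(fn):
    _hooks.append(fn)


def clear_hooks():
    _hooks.clear()


def call_layer(layer, *args, **kwargs):
    outputs = layer(*args, **kwargs)
    for hook in _hooks:
        result = hook(layer, args, kwargs, outputs)
        if result is not None:
            outputs = result  # hook replaced the layer's outputs
    return outputs


# A hook that doubles whatever the layer returned.
add_layer_hook(lambda layer, args, kwargs, out: out * 2)
doubled = call_layer(lambda x: x + 1, 3)  # (3 + 1) * 2 == 8
```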
## `clear_hooks()` {#max.nn.legacy.layer.clear_hooks} > max.nn.legacy.layer.clear\_hooks() Remove all hooks.
**Return type:**
None
## `recursive_named_layers()` {#max.nn.legacy.layer.recursive_named_layers} > max.nn.legacy.layer.recursive\_named\_layers(parent, prefix='') Recursively walks through the layers and generates names.
**Parameters:**
* parent ([Module](#max.nn.legacy.layer.Module)) * prefix ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
[Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.legacy.layer.Module)]]
--- ## linear (Legacy) Multi-layer Perceptron. ## `ColumnParallelLinear` {#max.nn.legacy.linear.ColumnParallelLinear} > class max.nn.legacy.linear.ColumnParallelLinear(in\_dim, out\_dim, dtype, devices, tied\_weight=None, \*\*kwargs) A [`Linear`](#max.nn.legacy.linear.Linear) layer where the weight and bias are sharded onto multiple devices. This layer first computes $y = xW_i^T + b_i$ for each device i in \[0,…, num\_devices]: ```default +-----+ +-----+ T +-----+ +-----+ | | | W_0 | | b_0 | | y_0 | GPU0 | | +-----+ +-----+ +-----+ | | | W_1 | | b_1 | | y_1 | GPU1 | x | @ +-----+ + +-----+ = +-----+ | | | W_2 | | b_2 | | y_2 | GPU2 | | +-----+ +-----+ +-----+ | | | W_3 | | b_3 | | y_3 | GPU3 +-----+ +-----+ +-----+ +-----+ ``` The values are then collected using an Allgather op, producing the same output tensor $y = xW^T + b$ on each device: ```default GPU0 GPU1 GPU2 GPU3 GPU0 GPU1 GPU2 GPU3 +-----+-----+-----+-----+ +-----+-----+-----+-----+ | y_0 | - | - | - | | y_0 | y_0 | y_0 | y_0 | +-----+-----+-----+-----+ +-----+-----+-----+-----+ | - | y_1 | - | - | | y_1 | y_1 | y_1 | y_1 | +-----+-----+-----+-----+ -- Allgather --> +-----+-----+-----+-----+ | - | - | y_2 | - | | y_2 | y_2 | y_2 | y_2 | +-----+-----+-----+-----+ +-----+-----+-----+-----+ | - | - | - | y_3 | | y_3 | y_3 | y_3 | y_3 | +-----+-----+-----+-----+ +-----+-----+-----+-----+ ``` Example usage: ```python from max.dtype import DType from max.graph import DeviceRef from max.nn import ColumnParallelLinear num_devices = 4 distributed_linear = ColumnParallelLinear( in_dim, out_dim, DType.float32, devices=[DeviceRef.GPU(i) for i in range(num_devices)], ) ```
**Parameters:**
* in\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * devices (Sequence\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * tied\_weight ([Weight](../../graph/Weight.md#max.graph.Weight) | None)
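The column-parallel computation above can be checked with a small plain-Python sketch: shard $W$ and $b$ by output rows, compute each $y_i = xW_i^T + b_i$, then concatenate the partial outputs (the allgather step) and compare with the unsharded result. The helper names and toy sizes are assumptions for illustration:

```python
def matmul_t(x, w):
    # x: (m, k), w: (n, k) -> x @ w.T with shape (m, n)
    return [[sum(a * b for a, b in zip(row, wrow)) for wrow in w] for row in x]


def linear(x, w, b):
    return [[v + bi for v, bi in zip(row, b)] for row in matmul_t(x, w)]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # (out_dim=4, in_dim=2)
b = [0.5, -0.5, 0.0, 1.0]

num_devices = 2
rows = len(w) // num_devices
partials = [
    linear(x, w[i * rows:(i + 1) * rows], b[i * rows:(i + 1) * rows])
    for i in range(num_devices)
]
# "Allgather": concatenate the per-device outputs along the feature axis.
gathered = [sum((p[r] for p in partials), []) for r in range(len(x))]
assert gathered == linear(x, w, b)
```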
## `DistributedGemmConfig` {#max.nn.legacy.linear.DistributedGemmConfig} > class max.nn.legacy.linear.DistributedGemmConfig(enable\_matmul\_allreduce) Configuration for distributed General Matrix Multiply (GEMM) operations, controlling how they are executed.
**Parameters:**
enable\_matmul\_allreduce ([bool](https://docs.python.org/3/library/functions.html#bool))
### `enable_matmul_allreduce` {#max.nn.legacy.linear.DistributedGemmConfig.enable_matmul_allreduce} > enable\_matmul\_allreduce: [bool](https://docs.python.org/3/library/functions.html#bool) If `True`, use the matmul + all\_reduce kernel. ### `generate()` {#max.nn.legacy.linear.DistributedGemmConfig.generate} > static generate() Returns the default [`DistributedGemmConfig`](#max.nn.legacy.linear.DistributedGemmConfig).
**Returns:**
A [`DistributedGemmConfig`](#max.nn.legacy.linear.DistributedGemmConfig) instance with default settings.
**Return type:**
[DistributedGemmConfig](#max.nn.legacy.linear.DistributedGemmConfig) | None
## `GPTQLinear` {#max.nn.legacy.linear.GPTQLinear} > class max.nn.legacy.linear.GPTQLinear(in\_dim, out\_dim, dtype, device, has\_bias=False, quantization\_encoding=None, quantization\_config=None, float8\_config=None) A [`Linear`](#max.nn.legacy.linear.Linear) layer for GPTQ encoding.
**Parameters:**
* in\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * quantization\_encoding ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | None) * quantization\_config ([QuantizationConfig](../../graph/quantization.md#max.graph.quantization.QuantizationConfig) | None) * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None)
## `Linear` {#max.nn.legacy.linear.Linear} > class max.nn.legacy.linear.Linear(in\_dim, out\_dim, dtype, device, has\_bias=False, quantization\_encoding=None, float8\_config=None, name=None, clip\_weight=None, is\_sharding=False) Applies a linear transformation to incoming data: $y = xW^T + b$. This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device. Example: ```python linear_layer = Linear( in_dim=256, out_dim=128, dtype=DType.float32, device=DeviceRef.GPU(), name="linear", has_bias=True ) input_tensor: TensorValue output = linear_layer(input_tensor) ```
**Parameters:**
* in\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * quantization\_encoding ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | None) * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * clip\_weight ([float](https://docs.python.org/3/library/functions.html#float) | None) * is\_sharding ([bool](https://docs.python.org/3/library/functions.html#bool))
### `bias` {#max.nn.legacy.linear.Linear.bias} > bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional bias vector stored on CPU with shape (out\_dim,). Model init moves the bias to the target device if present. ### `device` {#max.nn.legacy.linear.Linear.device} > device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) The device where matrix operations are performed. ### `input_scale` {#max.nn.legacy.linear.Linear.input_scale} > input\_scale: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional input scale stored on CPU with shape (). Model init moves the input\_scale to the target device if present. ### `shard()` {#max.nn.legacy.linear.Linear.shard} > shard(devices) Creates sharded views of this Linear layer across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of `DeviceRef` devices to place the shards on.
**Returns:**
List of sharded [`Linear`](#max.nn.legacy.linear.Linear) instances, one for each device.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Linear](#max.nn.legacy.linear.Linear)]
### `sharding_strategy` {#max.nn.legacy.linear.Linear.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the weight sharding strategy. ### `weight` {#max.nn.legacy.linear.Linear.weight} > weight: [Weight](../../graph/Weight.md#max.graph.Weight) The weight matrix stored on CPU with shape (out\_dim, in\_dim). Model init transposes the weight and moves it to the target device. ### `weight_scale` {#max.nn.legacy.linear.Linear.weight_scale} > weight\_scale: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight\_scale to the target device if present. ## `MLP` {#max.nn.legacy.linear.MLP} > class max.nn.legacy.linear.MLP(dtype, quantization\_encoding, hidden\_dim, feed\_forward\_length, devices, linear\_cls=Linear, has\_bias=False, activation\_function='silu', float8\_config=None, dist\_gemm\_config=None, is\_sharding=False) Simple multi-layer perceptron composed of three [`Linear`](#max.nn.legacy.linear.Linear) layers. Defaults to the SiLU activation function.
**Parameters:**
* dtype ([DType](../../dtype.md#max.dtype.DType)) * quantization\_encoding ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | None) * hidden\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * feed\_forward\_length ([int](https://docs.python.org/3/library/functions.html#int)) * devices (Sequence\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * linear\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](#max.nn.legacy.linear.Linear)]) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * activation\_function ([str](https://docs.python.org/3/library/stdtypes.html#str)) * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * dist\_gemm\_config ([DistributedGemmConfig](#max.nn.legacy.linear.DistributedGemmConfig) | None) * is\_sharding ([bool](https://docs.python.org/3/library/functions.html#bool))
### `shard()` {#max.nn.legacy.linear.MLP.shard} > shard(devices) Creates sharded views of this MLP across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded MLP instances, one for each device.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[MLP](#max.nn.legacy.linear.MLP)]
### `sharding_strategy` {#max.nn.legacy.linear.MLP.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the MLP sharding strategy. --- ## lora ## `AttentionWithRopeAndLoRA` {#max.nn.legacy.lora.AttentionWithRopeAndLoRA} > class max.nn.legacy.lora.AttentionWithRopeAndLoRA(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, max\_lora\_rank, max\_num\_loras, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None)
**Parameters:**
* rope ([RotaryEmbedding](rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) * kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * max\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int)) * max\_num\_loras ([int](https://docs.python.org/3/library/functions.html#int)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)] | None) * dtype ([DType](../../dtype.md#max.dtype.DType)) * linear\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../Linear.md#max.nn.Linear)]) * stacked\_qkv ([bool](https://docs.python.org/3/library/functions.html#bool)) * scale ([float](https://docs.python.org/3/library/functions.html#float) | None) * has\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * clip\_qkv ([float](https://docs.python.org/3/library/functions.html#float) | None)
### `rope` {#max.nn.legacy.lora.AttentionWithRopeAndLoRA.rope} > rope: [RotaryEmbedding](rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding) ## `LinearLoRA` {#max.nn.legacy.lora.LinearLoRA} > class max.nn.legacy.lora.LinearLoRA(in\_dim, out\_dim, max\_num\_loras, max\_lora\_rank, dtype, device, has\_lora\_bias=False, name=None, quantization\_encoding=None)
**Parameters:**
* in\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * out\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * max\_num\_loras ([int](https://docs.python.org/3/library/functions.html#int)) * max\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * device ([DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)) * has\_lora\_bias ([bool](https://docs.python.org/3/library/functions.html#bool)) * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * quantization\_encoding ([QuantizationEncoding](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | None)
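The low-rank update a LoRA linear applies is $y = xW^T + s \cdot (xA^T)B^T$, where $A$ has shape (rank, in_dim) and $B$ has shape (out_dim, rank). A plain-Python sketch (toy sizes and helper names are assumptions):

```python
def matmul_t(x, w):
    return [[sum(a * b for a, b in zip(row, wrow)) for wrow in w] for row in x]


def lora_linear(x, w, a, b, scaling=1.0):
    base = matmul_t(x, w)                      # frozen base projection
    delta = matmul_t(matmul_t(x, a), b)        # low-rank adapter path
    return [
        [bv + scaling * dv for bv, dv in zip(br, dr)]
        for br, dr in zip(base, delta)
    ]


x = [[1.0, 2.0]]
w = [[1.0, 0.0], [0.0, 1.0]]   # (out_dim=2, in_dim=2)
a = [[0.1, 0.2]]               # rank 1, shape (rank, in_dim)
b_zero = [[0.0], [0.0]]        # B initialized to zero: adapter is a no-op
assert lora_linear(x, w, a, b_zero) == matmul_t(x, w)
```

Zero-initializing $B$ is the common convention that makes a freshly added adapter start as an identity update.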
### `set_lora_batch_info()` {#max.nn.legacy.lora.LinearLoRA.set_lora_batch_info} > set\_lora\_batch\_info(lora\_ids, lora\_ranks, lora\_grouped\_offsets, num\_active\_loras, lora\_end\_idx, batch\_seq\_len, lora\_ids\_kv, lora\_grouped\_offsets\_kv)
**Parameters:**
* lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_grouped\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * num\_active\_loras ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * batch\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_ids\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_grouped\_offsets\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
None
## `SupportsLoRA` {#max.nn.legacy.lora.SupportsLoRA} > class max.nn.legacy.lora.SupportsLoRA(\*args, \*\*kwargs) Base class for supporting LoRA functionality in Modules. ### `set_lora_batch_info()` {#max.nn.legacy.lora.SupportsLoRA.set_lora_batch_info} > set\_lora\_batch\_info(lora\_ids, lora\_ranks, lora\_grouped\_offsets, num\_active\_loras, lora\_end\_idx, batch\_seq\_len, lora\_ids\_kv, lora\_grouped\_offsets\_kv)
**Parameters:**
* lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_grouped\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * num\_active\_loras ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * batch\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_ids\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * lora\_grouped\_offsets\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Return type:**
None
--- ## moe Mixture of Experts (MoE) module. ## `Fp8Strategy` {#max.nn.legacy.moe.Fp8Strategy} > class max.nn.legacy.moe.Fp8Strategy(config, dtype) FP8 quantization for MoE.
**Parameters:**
* config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) * dtype ([DType](../../dtype.md#max.dtype.DType))
### `fused_silu_quantize()` {#max.nn.legacy.moe.Fp8Strategy.fused_silu_quantize} > fused\_silu\_quantize(gate\_up\_projs, input\_scales=None, expert\_inputs=()) Applies fused SiLU gate and returns quantized activations.
**Parameters:**
* gate\_up\_projs ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * expert\_inputs ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), ...])
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
### `grouped_matmul()` {#max.nn.legacy.moe.Fp8Strategy.grouped_matmul} > grouped\_matmul(weight, weight\_scales, expert\_scales=None, tokens\_padded\_per\_expert=False, expert\_inputs=()) Runs grouped FP8 matmul for the routed experts.
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * expert\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * tokens\_padded\_per\_expert ([bool](https://docs.python.org/3/library/functions.html#bool)) * expert\_inputs ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), ...])
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `prepare_weight_scales()` {#max.nn.legacy.moe.Fp8Strategy.prepare_weight_scales} > prepare\_weight\_scales(gate\_up, down, device) Passes FP8 weight scales through without reformatting.
**Parameters:**
* gate\_up ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * down ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
### `quantize()` {#max.nn.legacy.moe.Fp8Strategy.quantize} > quantize(tensor, group\_size, input\_scale=None) Quantizes activations to FP8 and returns (quantized, scales).
**Parameters:**
* tensor ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * group\_size ([int](https://docs.python.org/3/library/functions.html#int)) * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None)
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
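The quantize/dequantize round trip above can be illustrated with symmetric integer scaling as a stand-in for FP8 (illustrative only; the real op targets FP8 formats and per-group scales, and these helper names are assumptions):

```python
def quantize(values, qmax=127):
    # Per-tensor symmetric scale: map the largest magnitude to qmax.
    scale = max(abs(v) for v in values) / qmax or 1.0
    q = [round(v / scale) for v in values]
    return q, scale


def dequantize(q, scale):
    return [v * scale for v in q]


q, scale = quantize([0.5, -1.0, 0.25])
restored = dequantize(q, scale)  # close to the original values
```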
## `MoE` {#max.nn.legacy.moe.MoE} > class max.nn.legacy.moe.MoE(devices, hidden\_dim, num\_experts, num\_experts\_per\_token, moe\_dim, gate\_cls=MoEGate, mlp\_cls=MLP, has\_shared\_experts=False, shared\_experts\_dim=0, ep\_size=1, dtype=bfloat16, apply\_router\_weight\_first=False, ep\_batch\_manager=None, float8\_config=None, is\_sharding=False) Implementation of Mixture of Experts (MoE).
**Parameters:**
* devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * hidden\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * num\_experts ([int](https://docs.python.org/3/library/functions.html#int)) * num\_experts\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) * moe\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * gate\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [MoEGate](#max.nn.legacy.moe.MoEGate)]) * mlp\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [MLP](linear.md#max.nn.legacy.linear.MLP)]) * has\_shared\_experts ([bool](https://docs.python.org/3/library/functions.html#bool)) * shared\_experts\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * ep\_size ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * apply\_router\_weight\_first ([bool](https://docs.python.org/3/library/functions.html#bool)) * ep\_batch\_manager (EPBatchManager | None) * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * is\_sharding ([bool](https://docs.python.org/3/library/functions.html#bool))
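The routing step at the heart of a MoE layer (softmax over router logits, pick the top `num_experts_per_token` experts, mix their outputs by normalized router weights) can be sketched in plain Python; the expert stand-ins and function names here are assumptions for illustration:

```python
import math


def top_k_route(logits, k):
    # Softmax over router logits, then keep the k highest-probability experts
    # and renormalize their weights to sum to 1.
    exp = [math.exp(v) for v in logits]
    total = sum(exp)
    probs = [v / total for v in exp]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


# Stand-in "experts": trivial functions instead of MLPs.
experts = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]
route = top_k_route([2.0, 1.0, -1.0], k=2)  # experts 0 and 1 are selected
y = sum(w * experts[i](3.0) for i, w in route)
```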
### `down_proj` {#max.nn.legacy.moe.MoE.down_proj} > property down\_proj: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ### `ep_batch_manager` {#max.nn.legacy.moe.MoE.ep_batch_manager} > property ep\_batch\_manager: EPBatchManager Get the expert parallel batch manager. ### `experts` {#max.nn.legacy.moe.MoE.experts} > experts: [LayerList](layer.md#max.nn.legacy.layer.LayerList) The list of experts. ### `gate_up_proj` {#max.nn.legacy.moe.MoE.gate_up_proj} > property gate\_up\_proj: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ### `shard()` {#max.nn.legacy.moe.MoE.shard} > shard(devices) Create sharded views of this MoE module across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded MoE instances, one for each device.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Self](https://docs.python.org/3/library/typing.html#typing.Self)]
### `shard_devices` {#max.nn.legacy.moe.MoE.shard_devices} > shard\_devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)] = [] The list of devices the MoE layer was sharded to. ### `shard_index` {#max.nn.legacy.moe.MoE.shard_index} > shard\_index: [int](https://docs.python.org/3/library/functions.html#int) = 0 The index of the current shard (if the MoE layer was sharded). ### `sharding_strategy` {#max.nn.legacy.moe.MoE.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the sharding strategy for the module. ## `MoEGate` {#max.nn.legacy.moe.MoEGate} > class max.nn.legacy.moe.MoEGate(devices, hidden\_dim, num\_experts, num\_experts\_per\_token, dtype, is\_sharding=False, linear\_cls=Linear) Gate module for MoE.
**Parameters:**
* devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * hidden\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * num\_experts ([int](https://docs.python.org/3/library/functions.html#int)) * num\_experts\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * is\_sharding ([bool](https://docs.python.org/3/library/functions.html#bool)) * linear\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../Linear.md#max.nn.Linear)])
### `shard()` {#max.nn.legacy.moe.MoEGate.shard} > shard(devices) Create sharded views of this MoEGate module across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded MoEGate instances, one for each device.
**Return type:**
[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[MoEGate](#max.nn.legacy.moe.MoEGate)]
### `sharding_strategy` {#max.nn.legacy.moe.MoEGate.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the sharding strategy for the module. ## `MoEQuantized` {#max.nn.legacy.moe.MoEQuantized} > class max.nn.legacy.moe.MoEQuantized(devices, hidden\_dim, num\_experts, num\_experts\_per\_token, moe\_dim, gate\_cls=MoEGate, mlp\_cls=MLP, has\_shared\_experts=False, shared\_experts\_dim=0, ep\_size=1, dtype=bfloat16, apply\_router\_weight\_first=False, ep\_batch\_manager=None, float8\_config=None, is\_sharding=False) Mixture of Experts with FP8 or NVFP4 quantization.
**Parameters:**
* devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * hidden\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * num\_experts ([int](https://docs.python.org/3/library/functions.html#int)) * num\_experts\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) * moe\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * gate\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [MoEGate](#max.nn.legacy.moe.MoEGate)]) * mlp\_cls ([Callable](../../graph/ops.md#max.graph.ops.Callable)\[..., [MLP](linear.md#max.nn.legacy.linear.MLP)]) * has\_shared\_experts ([bool](https://docs.python.org/3/library/functions.html#bool)) * shared\_experts\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * ep\_size ([int](https://docs.python.org/3/library/functions.html#int)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * apply\_router\_weight\_first ([bool](https://docs.python.org/3/library/functions.html#bool)) * ep\_batch\_manager (EPBatchManager | None) * float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) * is\_sharding ([bool](https://docs.python.org/3/library/functions.html#bool))
### `down_proj_scales` {#max.nn.legacy.moe.MoEQuantized.down_proj_scales} > property down\_proj\_scales: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) Returns stacked down-projection weight scales. ### `gate_up_proj_scales` {#max.nn.legacy.moe.MoEQuantized.gate_up_proj_scales} > property gate\_up\_proj\_scales: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) Returns stacked gate/up weight scales for grouped matmul. ## `Nvfp4Scales` {#max.nn.legacy.moe.Nvfp4Scales} > class max.nn.legacy.moe.Nvfp4Scales(gate\_up\_input, down\_input, gate\_up\_expert, down\_expert) Bundled scales for NVFP4 quantization.
**Parameters:**
* gate\_up\_input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * down\_input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * gate\_up\_expert ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * down\_expert ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
### `down_expert` {#max.nn.legacy.moe.Nvfp4Scales.down_expert} > down\_expert: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ### `down_input` {#max.nn.legacy.moe.Nvfp4Scales.down_input} > down\_input: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ### `gate_up_expert` {#max.nn.legacy.moe.Nvfp4Scales.gate_up_expert} > gate\_up\_expert: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ### `gate_up_input` {#max.nn.legacy.moe.Nvfp4Scales.gate_up_input} > gate\_up\_input: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ## `Nvfp4Strategy` {#max.nn.legacy.moe.Nvfp4Strategy} > class max.nn.legacy.moe.Nvfp4Strategy(config, dtype) NVFP4 quantization for MoE.
**Parameters:**
* config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) * dtype ([DType](../../dtype.md#max.dtype.DType))
### `fused_silu_quantize()` {#max.nn.legacy.moe.Nvfp4Strategy.fused_silu_quantize} > fused\_silu\_quantize(gate\_up\_projs, input\_scales=None, expert\_inputs=()) Applies SiLU gate then NVFP4 quantizes the result.
**Parameters:**
* gate\_up\_projs ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * expert\_inputs ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), ...])
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
### `grouped_matmul()` {#max.nn.legacy.moe.Nvfp4Strategy.grouped_matmul} > grouped\_matmul(weight, weight\_scales, expert\_scales=None, tokens\_padded\_per\_expert=False, expert\_inputs=()) Runs grouped NVFP4 matmul with per-expert scales.
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * expert\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * tokens\_padded\_per\_expert ([bool](https://docs.python.org/3/library/functions.html#bool)) * expert\_inputs ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), ...])
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `prepare_weight_scales()` {#max.nn.legacy.moe.Nvfp4Strategy.prepare_weight_scales} > prepare\_weight\_scales(gate\_up, down, device) Interleaves NVFP4 block scales for kernel layout.
**Parameters:**
* gate\_up ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * down ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
### `quantize()` {#max.nn.legacy.moe.Nvfp4Strategy.quantize} > quantize(tensor, group\_size, input\_scale=None) Quantizes activations to NVFP4 and returns (quantized, scales).
**Parameters:**
* tensor ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * group\_size ([int](https://docs.python.org/3/library/functions.html#int)) * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None)
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
## `QuantStrategy` {#max.nn.legacy.moe.QuantStrategy} > class max.nn.legacy.moe.QuantStrategy(\*args, \*\*kwargs) Quantization strategy for MoE layers. ### `fused_silu_quantize()` {#max.nn.legacy.moe.QuantStrategy.fused_silu_quantize} > fused\_silu\_quantize(gate\_up\_projs, input\_scales=None, expert\_inputs=()) Applies gating and quantizes activations for the down proj.
**Parameters:**
* gate\_up\_projs ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * input\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * expert\_inputs ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), ...])
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
### `grouped_matmul()` {#max.nn.legacy.moe.QuantStrategy.grouped_matmul} > grouped\_matmul(weight, weight\_scales, expert\_scales=None, tokens\_padded\_per\_expert=False, expert\_inputs=()) Runs grouped matmul for routed experts.
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * weight\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * expert\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) * tokens\_padded\_per\_expert ([bool](https://docs.python.org/3/library/functions.html#bool)) * expert\_inputs ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), ...])
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `prepare_weight_scales()` {#max.nn.legacy.moe.QuantStrategy.prepare_weight_scales} > prepare\_weight\_scales(gate\_up, down, device) Prepares weight scales for kernel consumption.
**Parameters:**
* gate\_up ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * down ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
### `quantize()` {#max.nn.legacy.moe.QuantStrategy.quantize} > quantize(tensor, group\_size, input\_scale=None) Quantizes activations and returns (quantized, scales).
**Parameters:**
* tensor ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * group\_size ([int](https://docs.python.org/3/library/functions.html#int)) * input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None)
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue), [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)]
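The (quantized, scales) contract of `quantize()` can be illustrated with a generic symmetric per-group scheme. This is a sketch of the shape contract only, not the actual MAX kernel; the `qmax` value (FP8 e4m3 maximum) and the max-abs scale formula are assumptions.

```python
import numpy as np

def quantize(tensor, group_size, input_scale=None, qmax=448.0):
    # Illustrative symmetric per-group quantization: split the last axis
    # into groups of `group_size`, compute one scale per group so that the
    # group's max |value| maps to qmax, and return (quantized, scales).
    g = tensor.reshape(*tensor.shape[:-1], -1, group_size)
    scales = np.abs(g).max(axis=-1, keepdims=True) / qmax
    scales = np.maximum(scales, 1e-12)  # avoid divide-by-zero on all-zero groups
    if input_scale is not None:  # optional global pre-scale (assumption)
        scales = scales * input_scale
    q = np.clip(g / scales, -qmax, qmax)
    return q.reshape(tensor.shape), scales.squeeze(-1)

x = np.array([[1.0, -4.0, 2.0, 8.0]])
q, s = quantize(x, group_size=2)  # one scale per group of two values
```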
## `silu_gate()` {#max.nn.legacy.moe.silu_gate} > max.nn.legacy.moe.silu\_gate(gate\_up\_projs, moe\_dim) Applies SiLU-gated activation: silu(gate) \* up.
**Parameters:**
* gate\_up\_projs ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) * moe\_dim ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
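The silu(gate) \* up computation can be sketched in NumPy as follows. The assumption that the fused projection stores the gate half before the up half, split at `moe_dim`, follows the common convention and is not taken from this kernel's source.

```python
import numpy as np

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def silu_gate(gate_up_projs, moe_dim):
    # The fused projection holds [gate | up] along the last axis (assumed);
    # split at moe_dim and combine as silu(gate) * up.
    gate = gate_up_projs[..., :moe_dim]
    up = gate_up_projs[..., moe_dim:]
    return silu(gate) * up

x = np.array([[1.0, -2.0, 3.0, 0.5]])  # moe_dim = 2: gate=[1,-2], up=[3,0.5]
out = silu_gate(x, 2)
```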
--- ## norm ## `ConstantLayerNorm` {#max.nn.legacy.norm.ConstantLayerNorm} > class max.nn.legacy.norm.ConstantLayerNorm(dims, device, dtype, eps=1e-05) Layer normalization block with constant gamma and beta values.
**Parameters:**
* dims ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), ...]) * device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) * dtype ([DType](../../dtype.md#max.dtype.DType)) * eps ([float](https://docs.python.org/3/library/functions.html#float))
### `beta` {#max.nn.legacy.norm.ConstantLayerNorm.beta} > beta: npt.NDArray\[np.floating\[Any]] ### `device` {#max.nn.legacy.norm.ConstantLayerNorm.device} > device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) ### `dtype` {#max.nn.legacy.norm.ConstantLayerNorm.dtype} > dtype: [DType](../../dtype.md#max.dtype.DType) ### `eps` {#max.nn.legacy.norm.ConstantLayerNorm.eps} > eps: [float](https://docs.python.org/3/library/functions.html#float) = 1e-05 ### `gamma` {#max.nn.legacy.norm.ConstantLayerNorm.gamma} > gamma: npt.NDArray\[np.floating\[Any]] ## `GroupNorm` {#max.nn.legacy.norm.GroupNorm} > class max.nn.legacy.norm.GroupNorm(num\_groups, num\_channels, eps=1e-05, affine=True, device=gpu:0) Group normalization block. Divides channels into groups and computes normalization stats per group. Follows the implementation pattern from PyTorch’s group\_norm.
**Parameters:**
* num\_groups ([int](https://docs.python.org/3/library/functions.html#int)) – Number of groups to separate the channels into * num\_channels ([int](https://docs.python.org/3/library/functions.html#int)) – Number of input channels * eps ([float](https://docs.python.org/3/library/functions.html#float)) – Small constant added to denominator for numerical stability * affine ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, apply learnable affine transform parameters * device ([DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef))
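The per-group statistics described above can be sketched in NumPy (a hypothetical reference following PyTorch's group\_norm semantics, not the MAX kernel):

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5, gamma=None, beta=None):
    # x: (batch, channels, *spatial). Reshape channels into contiguous
    # groups and normalize with per-group mean/variance.
    n, c = x.shape[0], x.shape[1]
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(axis=-1, keepdims=True)
    var = g.var(axis=-1, keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    out = g.reshape(x.shape)
    if gamma is not None and beta is not None:
        # Optional per-channel learnable affine transform (affine=True).
        shape = (1, c) + (1,) * (x.ndim - 2)
        out = out * gamma.reshape(shape) + beta.reshape(shape)
    return out

x = np.random.default_rng(0).normal(size=(2, 4, 3))
y = group_norm(x, num_groups=2)  # stats computed over 2 channels x 3 elements
```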
## `LayerNorm` {#max.nn.legacy.norm.LayerNorm} > class max.nn.legacy.norm.LayerNorm(dims, devices, dtype, eps=1e-05, use\_bias=True) Layer normalization block.
**Parameters:**
* dims ([int](https://docs.python.org/3/library/functions.html#int)) * devices (Sequence\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)]) * dtype ([DType](../../dtype.md#max.dtype.DType)) * eps ([float](https://docs.python.org/3/library/functions.html#float)) * use\_bias ([bool](https://docs.python.org/3/library/functions.html#bool))
### `shard()` {#max.nn.legacy.norm.LayerNorm.shard} > shard(devices) Creates sharded views of this LayerNorm across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded LayerNorm instances, one for each device.
**Return type:**
[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[LayerNorm](#max.nn.legacy.norm.LayerNorm)]
### `sharding_strategy` {#max.nn.legacy.norm.LayerNorm.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the LayerNorm sharding strategy. ## `RMSNorm` {#max.nn.legacy.norm.RMSNorm} > class max.nn.legacy.norm.RMSNorm(dim, dtype, eps=1e-06, weight\_offset=0.0, multiply\_before\_cast=True) Computes the Root Mean Square normalization on inputs.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) – Size of last dimension of the expected input. * eps ([float](https://docs.python.org/3/library/functions.html#float)) – Value added to denominator for numerical stability. * weight\_offset ([float](https://docs.python.org/3/library/functions.html#float)) – Constant offset added to the learned weights at runtime. For Gemma-style RMSNorm, this should be set to 1.0. * multiply\_before\_cast ([bool](https://docs.python.org/3/library/functions.html#bool)) – True if we multiply the inputs by the learned weights before casting to the input type (Gemma3-style). False if we cast the inputs to the input type first, then multiply by the learned weights (Llama-style). * dtype ([DType](../../dtype.md#max.dtype.DType))
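The `weight_offset` and `multiply_before_cast` behaviors can be sketched in NumPy (a hypothetical reference implementation; the float32 accumulation dtype is an assumption):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6, weight_offset=0.0,
             multiply_before_cast=True, out_dtype=np.float16):
    # Normalize by the root-mean-square of the last dimension, in float32.
    x32 = x.astype(np.float32)
    scale = 1.0 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    normed = x32 * scale
    # weight_offset=1.0 recovers Gemma-style (1 + w) weights at runtime.
    w = weight.astype(np.float32) + weight_offset
    if multiply_before_cast:
        # Gemma3-style: apply weights in float32, then cast down.
        return (normed * w).astype(out_dtype)
    # Llama-style: cast first, then apply weights in the output dtype.
    return normed.astype(out_dtype) * w.astype(out_dtype)

x = np.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm(x, np.ones(4), out_dtype=np.float32)
```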
### `shard()` {#max.nn.legacy.norm.RMSNorm.shard} > shard(devices) Creates sharded views of this RMSNorm across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded RMSNorm instances, one for each device.
**Return type:**
[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[RMSNorm](#max.nn.legacy.norm.RMSNorm)]
### `sharding_strategy` {#max.nn.legacy.norm.RMSNorm.sharding_strategy} > property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None) Get the RMSNorm sharding strategy. --- ## rotary_embedding The rope embedding used within the model. ## `DeepseekYarnRopeScalingParams` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams} > class max.nn.legacy.rotary\_embedding.DeepseekYarnRopeScalingParams(scaling\_factor: [float](https://docs.python.org/3/library/functions.html#float), original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int), beta\_fast: [int](https://docs.python.org/3/library/functions.html#int), beta\_slow: [int](https://docs.python.org/3/library/functions.html#int), mscale: [float](https://docs.python.org/3/library/functions.html#float), mscale\_all\_dim: [float](https://docs.python.org/3/library/functions.html#float))
**Parameters:**
* scaling\_factor ([float](https://docs.python.org/3/library/functions.html#float)) * original\_max\_position\_embeddings ([int](https://docs.python.org/3/library/functions.html#int)) * beta\_fast ([int](https://docs.python.org/3/library/functions.html#int)) * beta\_slow ([int](https://docs.python.org/3/library/functions.html#int)) * mscale ([float](https://docs.python.org/3/library/functions.html#float)) * mscale\_all\_dim ([float](https://docs.python.org/3/library/functions.html#float))
### `beta_fast` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.beta_fast} > beta\_fast: [int](https://docs.python.org/3/library/functions.html#int) Fast interpolation rate. ### `beta_slow` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.beta_slow} > beta\_slow: [int](https://docs.python.org/3/library/functions.html#int) Slow interpolation rate. ### `mscale` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.mscale} > mscale: [float](https://docs.python.org/3/library/functions.html#float) Scaling factor for middle frequencies. ### `mscale_all_dim` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.mscale_all_dim} > mscale\_all\_dim: [float](https://docs.python.org/3/library/functions.html#float) Scaling factor applied to all dimensions. ### `original_max_position_embeddings` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.original_max_position_embeddings} > original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int) Original maximum sequence length during training. ### `scaling_factor` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.scaling_factor} > scaling\_factor: [float](https://docs.python.org/3/library/functions.html#float) Scaling factor for frequency interpolation. ## `DeepseekYarnRotaryEmbedding` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding} > class max.nn.legacy.rotary\_embedding.DeepseekYarnRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None) Deepseek’s YaRN (Yet another RoPE eNhancement) Rotary Position Embedding layer. Unlike Llama3RotaryEmbedding, the dim argument here is the rope dimension of the model, not the hidden dimension.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * theta ([float](https://docs.python.org/3/library/functions.html#float)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * \_freqs\_cis (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) * scaling\_params ([DeepseekYarnRopeScalingParams](#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams) | None)
### `compute_scale()` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding.compute_scale} > compute\_scale(user\_scale=None)
**Parameters:**
user\_scale ([float](https://docs.python.org/3/library/functions.html#float) | None)
**Return type:**
[float](https://docs.python.org/3/library/functions.html#float)
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding.freqs_cis_base} > freqs\_cis\_base() Computes the frequency tensor for complex exponentials (cis) for a given seq\_len. The tensor is scaled by the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See "RoFormer: Enhanced Transformer with Rotary Position Embedding" (arxiv.org/pdf/2104.09864).
**Returns:**
The frequency tensor for complex exponentials with shape (max\_seq\_len, rope\_dim // 2, 2)
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `scaling_params` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding.scaling_params} > scaling\_params: [DeepseekYarnRopeScalingParams](#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None ## `DynamicRotaryEmbedding` {#max.nn.legacy.rotary_embedding.DynamicRotaryEmbedding} > class max.nn.legacy.rotary\_embedding.DynamicRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True) RotaryEmbedding with dynamic scaling support for long-context inference. Dynamically updates the inv\_freq and corresponding freqs\_cis buffer if the current sequence length exceeds the original max, or resets to the original high-precision version for short sequences.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * theta ([float](https://docs.python.org/3/library/functions.html#float)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * \_freqs\_cis (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool))
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.DynamicRotaryEmbedding.freqs_cis_base} > freqs\_cis\_base() Computes freqs\_cis dynamically using the current self.inv\_freq.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `maybe_update_freqs()` {#max.nn.legacy.rotary_embedding.DynamicRotaryEmbedding.maybe_update_freqs} > maybe\_update\_freqs(position\_ids) Update freqs\_cis if the sequence exceeds max\_seq\_len\_cached, or revert to the original version if back below the threshold.
**Parameters:**
position\_ids (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray))
**Return type:**
None
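The update/revert policy of `maybe_update_freqs()` can be sketched as below. The NTK-style base rescaling shown is the common recipe used by other dynamic-RoPE implementations and is an assumption here, not taken from this layer's source.

```python
class DynamicRopeSketch:
    """Sketch of the caching policy: grow when positions exceed the cached
    maximum, revert to the original high-precision base for short sequences."""

    def __init__(self, dim, theta, max_seq_len):
        self.dim, self.theta = dim, theta
        self.original_max = max_seq_len
        self.max_seq_len_cached = max_seq_len
        self.base = theta  # current (possibly rescaled) rope base

    def maybe_update_freqs(self, seq_len):
        if seq_len > self.max_seq_len_cached:
            # NTK-style: scale the base so existing positions keep
            # similar rotation angles (assumed formula).
            factor = seq_len / self.original_max
            self.base = self.theta * factor ** (self.dim / (self.dim - 2))
            self.max_seq_len_cached = seq_len
        elif seq_len <= self.original_max and self.max_seq_len_cached > self.original_max:
            # Back below the threshold: restore the trained base.
            self.base = self.theta
            self.max_seq_len_cached = self.original_max

r = DynamicRopeSketch(dim=64, theta=10000.0, max_seq_len=2048)
r.maybe_update_freqs(4096)  # long sequence: rescale and grow the cache
r.maybe_update_freqs(128)   # short sequence: revert to the original base
```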
## `LinearScalingParams` {#max.nn.legacy.rotary_embedding.LinearScalingParams} > class max.nn.legacy.rotary\_embedding.LinearScalingParams(factor: [float](https://docs.python.org/3/library/functions.html#float))
**Parameters:**
factor ([float](https://docs.python.org/3/library/functions.html#float))
### `factor` {#max.nn.legacy.rotary_embedding.LinearScalingParams.factor} > factor: [float](https://docs.python.org/3/library/functions.html#float) Main scaling factor for the frequency components of the rope. ## `Llama3RopeScalingParams` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams} > class max.nn.legacy.rotary\_embedding.Llama3RopeScalingParams(factor: [float](https://docs.python.org/3/library/functions.html#float), low\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float), high\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float), orig\_max\_position: [int](https://docs.python.org/3/library/functions.html#int))
**Parameters:**
* factor ([float](https://docs.python.org/3/library/functions.html#float)) * low\_freq\_factor ([float](https://docs.python.org/3/library/functions.html#float)) * high\_freq\_factor ([float](https://docs.python.org/3/library/functions.html#float)) * orig\_max\_position ([int](https://docs.python.org/3/library/functions.html#int))
### `factor` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.factor} > factor: [float](https://docs.python.org/3/library/functions.html#float) Main scaling factor for the frequency components of the rope. ### `high_freq_factor` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.high_freq_factor} > high\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float) Factor to scale the high frequency components of the rope. ### `low_freq_factor` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.low_freq_factor} > low\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float) Factor to scale the low frequency components of the rope. ### `orig_max_position` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.orig_max_position} > orig\_max\_position: [int](https://docs.python.org/3/library/functions.html#int) The original maximum position length supported by the model. ## `Llama3RotaryEmbedding` {#max.nn.legacy.rotary_embedding.Llama3RotaryEmbedding} > class max.nn.legacy.rotary\_embedding.Llama3RotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None) RotaryEmbedding for Llama3 that takes rope scaling into account.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * theta ([float](https://docs.python.org/3/library/functions.html#float)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * \_freqs\_cis (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) * scaling\_params ([Llama3RopeScalingParams](#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams) | None)
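The scaling rule these parameters drive can be sketched with the widely published Llama 3 frequency-scaling recipe (it is an assumption that this layer follows the same formula): low frequencies are divided by `factor`, high frequencies are kept, and the band between `low_freq_factor` and `high_freq_factor` is smoothly interpolated.

```python
import math

def llama3_scale_inv_freq(inv_freq, factor, low_freq_factor,
                          high_freq_factor, orig_max_position):
    # Each inverse frequency corresponds to a wavelength; the scaling
    # applied depends on which wavelength band it falls into.
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    scaled = []
    for f in inv_freq:
        wavelen = 2 * math.pi / f
        if wavelen < high_freq_wavelen:      # high frequency: keep as-is
            scaled.append(f)
        elif wavelen > low_freq_wavelen:     # low frequency: scale down
            scaled.append(f / factor)
        else:                                # mid band: interpolate
            smooth = (orig_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            scaled.append((1 - smooth) * f / factor + smooth * f)
    return scaled

out = llama3_scale_inv_freq([1.0, 0.0001], 8.0, 1.0, 4.0, 8192)
```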
### `scaling_params` {#max.nn.legacy.rotary_embedding.Llama3RotaryEmbedding.scaling_params} > scaling\_params: [Llama3RopeScalingParams](#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None Scaling parameters to enable llama to function with a longer context length. ## `LongRoPERotaryEmbedding` {#max.nn.legacy.rotary_embedding.LongRoPERotaryEmbedding} > class max.nn.legacy.rotary\_embedding.LongRoPERotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None) Rotary position embedding with LongRoPE scaling for Phi-3.5 models.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * theta ([float](https://docs.python.org/3/library/functions.html#float)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * \_freqs\_cis (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) * scaling\_params ([LongRoPEScalingParams](#max.nn.legacy.rotary_embedding.LongRoPEScalingParams) | None)
### `compute_scale()` {#max.nn.legacy.rotary_embedding.LongRoPERotaryEmbedding.compute_scale} > compute\_scale(user\_scale=None) Compute attention scale with LongRoPE adjustment.
**Parameters:**
user\_scale ([float](https://docs.python.org/3/library/functions.html#float) | None)
**Return type:**
[float](https://docs.python.org/3/library/functions.html#float)
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.LongRoPERotaryEmbedding.freqs_cis_base} > freqs\_cis\_base() Computes the frequency tensor for complex exponentials (cis) with LongRoPE scaling. Creates a “stitched” table where: * Positions 0 to original\_max\_position use short\_factor * Positions from original\_max\_position onwards use long\_factor
**Returns:**
The frequency tensor for complex exponentials with shape (max\_seq\_len \* 2, head\_dim // 2, 2)
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
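The "stitched" per-position choice of scaling factors can be sketched as follows (hypothetical helper; the per-frequency-pair indexing of the factor list is an assumption):

```python
def longrope_inv_freq(head_dim, theta, position, short_factor, long_factor,
                      original_max_position):
    # Positions below the original training length rescale each frequency
    # with short_factor; longer positions switch to long_factor.
    ext = short_factor if position < original_max_position else long_factor
    return [1.0 / (ext[i // 2] * theta ** (i / head_dim))
            for i in range(0, head_dim, 2)]

# Short position: factors of 1.0 leave the base frequencies untouched.
short = longrope_inv_freq(4, 10000.0, 0, [1.0, 1.0], [2.0, 2.0], 8)
# Long position: factors of 2.0 halve every inverse frequency.
long = longrope_inv_freq(4, 10000.0, 100, [1.0, 1.0], [2.0, 2.0], 8)
```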
## `LongRoPEScalingParams` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams} > class max.nn.legacy.rotary\_embedding.LongRoPEScalingParams(short\_factor, long\_factor, original\_max\_position, max\_position\_embeddings) Parameters for LongRoPE scaling as used in Phi-3.5 models.
**Parameters:**
* short\_factor ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)]) * long\_factor ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)]) * original\_max\_position ([int](https://docs.python.org/3/library/functions.html#int)) * max\_position\_embeddings ([int](https://docs.python.org/3/library/functions.html#int))
### `long_factor` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.long_factor} > long\_factor: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)] Scaling factors for long sequences (can be much larger). ### `max_position_embeddings` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.max_position_embeddings} > max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int) Current max position embeddings after scaling. ### `original_max_position` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.original_max_position} > original\_max\_position: [int](https://docs.python.org/3/library/functions.html#int) Original max position embeddings the model was trained with. ### `short_factor` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.short_factor} > short\_factor: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)] Scaling factors for short sequences (typically close to 1.0). ## `RotaryEmbedding` {#max.nn.legacy.rotary_embedding.RotaryEmbedding} > class max.nn.legacy.rotary\_embedding.RotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True) RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * theta ([float](https://docs.python.org/3/library/functions.html#float)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) * \_freqs\_cis (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool))
### `compute_scale()` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.compute_scale} > compute\_scale(user\_scale=None)
**Parameters:**
user\_scale ([float](https://docs.python.org/3/library/functions.html#float) | None)
**Return type:**
[float](https://docs.python.org/3/library/functions.html#float)
### `dim` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.dim} > dim: [int](https://docs.python.org/3/library/functions.html#int) ### `freqs_cis` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.freqs_cis} > property freqs\_cis: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) ### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.freqs_cis_base} > freqs\_cis\_base() Computes the frequency tensor for complex exponentials (cis) for a given seq\_len. The tensor is scaled by the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See "RoFormer: Enhanced Transformer with Rotary Position Embedding" (arxiv.org/pdf/2104.09864).
**Returns:**
The frequency tensor for complex exponentials with shape (max\_seq\_len \* 2, head\_dim // 2, 2)
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
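A NumPy sketch of the documented table, matching the stated (max\_seq\_len \* 2, head\_dim // 2, 2) shape; the interleaved (cos, sin) layout in the last axis is an assumption:

```python
import numpy as np

def freqs_cis_base(head_dim, theta, max_seq_len):
    # Standard RoPE table: one rotation angle per (position, frequency)
    # pair, stored as (cos, sin) in the trailing axis.
    inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)
    t = np.arange(max_seq_len * 2)     # the layer builds a 2x-length table
    angles = np.outer(t, inv_freq)     # (2 * max_seq_len, head_dim // 2)
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)

table = freqs_cis_base(head_dim=8, theta=10000.0, max_seq_len=16)
```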
### `head_dim` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.head_dim} > head\_dim: [int](https://docs.python.org/3/library/functions.html#int) head\_dim = dim // n\_heads if not specified in the config. ### `interleaved` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.interleaved} > interleaved: [bool](https://docs.python.org/3/library/functions.html#bool) = True ### `max_seq_len` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.max_seq_len} > max\_seq\_len: [int](https://docs.python.org/3/library/functions.html#int) The maximum sequence length for model’s input. ### `n_heads` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.n_heads} > n\_heads: [int](https://docs.python.org/3/library/functions.html#int) ### `theta` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.theta} > theta: [float](https://docs.python.org/3/library/functions.html#float) Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings. ## `YarnRotaryEmbedding` {#max.nn.legacy.rotary_embedding.YarnRotaryEmbedding} > class max.nn.legacy.rotary\_embedding.YarnRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None) Generic YaRN (Yet another RoPE eNhancement) Rotary Position Embedding layer. This implementation provides YARN scaling for models that require it, with configurable parameters for beta\_fast, beta\_slow, and scaling factor.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * theta ([float](https://docs.python.org/3/library/functions.html#float)) * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) * head\_dim ([int](https://docs.python.org/3/library/functions.html#int) | None) * \_freqs\_cis (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray) | None) * interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) * scaling\_params ([YarnScalingParams](#max.nn.legacy.rotary_embedding.YarnScalingParams) | None)
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.YarnRotaryEmbedding.freqs_cis_base} > freqs\_cis\_base() Computes the frequency tensor for complex exponentials (cis) with YARN scaling applied.
**Return type:**
[TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
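As a rough illustration of what YaRN-style scaling does to the base frequencies, here is a sketch of the NTK-by-parts interpolation from the YaRN paper (arxiv.org/abs/2309.00071). This is an assumption about the general recipe, not MAX’s implementation; the parameter names merely mirror `YarnScalingParams`:

```python
import numpy as np

def yarn_inv_freq_sketch(head_dim, theta, factor, beta_fast, beta_slow,
                         original_max_position_embeddings):
    """Illustrative NTK-by-parts frequency interpolation in the spirit of YaRN."""
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))

    # Dimension index whose wavelength fits `num_rotations` full turns in the
    # original context window.
    def correction_dim(num_rotations):
        return (head_dim * np.log(original_max_position_embeddings /
                                  (num_rotations * 2 * np.pi))) / (2 * np.log(theta))

    low = np.floor(correction_dim(beta_fast))
    high = np.ceil(correction_dim(beta_slow))
    # Ramp is 0 for high-frequency dims (kept as-is) and 1 for low-frequency
    # dims (interpolated, i.e. divided by the scaling factor).
    dims = np.arange(head_dim // 2)
    ramp = np.clip((dims - low) / max(high - low, 1e-3), 0.0, 1.0)
    return inv_freq * (1.0 - ramp) + (inv_freq / factor) * ramp

scaled = yarn_inv_freq_sketch(128, 10000.0, factor=8.0, beta_fast=32.0,
                              beta_slow=1.0, original_max_position_embeddings=4096)
```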
### `scaling_params` {#max.nn.legacy.rotary_embedding.YarnRotaryEmbedding.scaling_params} > scaling\_params: [YarnScalingParams](#max.nn.legacy.rotary_embedding.YarnScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None ## `YarnScalingParams` {#max.nn.legacy.rotary_embedding.YarnScalingParams} > class max.nn.legacy.rotary\_embedding.YarnScalingParams(factor: [float](https://docs.python.org/3/library/functions.html#float), beta\_fast: [float](https://docs.python.org/3/library/functions.html#float), beta\_slow: [float](https://docs.python.org/3/library/functions.html#float), original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int), truncate: [bool](https://docs.python.org/3/library/functions.html#bool))
**Parameters:**
* factor ([float](https://docs.python.org/3/library/functions.html#float)) * beta\_fast ([float](https://docs.python.org/3/library/functions.html#float)) * beta\_slow ([float](https://docs.python.org/3/library/functions.html#float)) * original\_max\_position\_embeddings ([int](https://docs.python.org/3/library/functions.html#int)) * truncate ([bool](https://docs.python.org/3/library/functions.html#bool))
### `beta_fast` {#max.nn.legacy.rotary_embedding.YarnScalingParams.beta_fast} > beta\_fast: [float](https://docs.python.org/3/library/functions.html#float) YaRN parameter for fast frequencies. ### `beta_slow` {#max.nn.legacy.rotary_embedding.YarnScalingParams.beta_slow} > beta\_slow: [float](https://docs.python.org/3/library/functions.html#float) YaRN parameter for slow frequencies. ### `factor` {#max.nn.legacy.rotary_embedding.YarnScalingParams.factor} > factor: [float](https://docs.python.org/3/library/functions.html#float) Main scaling factor for the frequency components of RoPE. ### `original_max_position_embeddings` {#max.nn.legacy.rotary_embedding.YarnScalingParams.original_max_position_embeddings} > original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int) The original maximum position length supported by the model. ### `truncate` {#max.nn.legacy.rotary_embedding.YarnScalingParams.truncate} > truncate: [bool](https://docs.python.org/3/library/functions.html#bool) Whether to truncate the frequencies. --- ## sampling Sampling custom ops. ## `MinPSampler` {#max.nn.legacy.sampling.MinPSampler} > class max.nn.legacy.sampling.MinPSampler(dtype, shape, temperature=1) A min\_p sampler.
**Parameters:**
* dtype ([DType](../../dtype.md#max.dtype.DType)) * shape ([Shape](../../graph/shape.md#max.graph.shape.Shape)) * temperature ([float](https://docs.python.org/3/library/functions.html#float))
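Min-p sampling keeps only tokens whose probability is at least `min_p` times the probability of the most likely token, so the cutoff adapts to the model’s confidence. A plain-NumPy sketch of that filtering rule (an illustration of the technique, not the MAX kernel):

```python
import numpy as np

def min_p_sample_sketch(logits, min_p=0.1, temperature=1.0, rng=None):
    """Illustrative min-p sampling: filter by a dynamic threshold, then sample."""
    rng = rng or np.random.default_rng(0)
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))  # softmax, numerically stable
    probs /= probs.sum()
    # Dynamic threshold: scales with the top token's probability.
    keep = probs >= min_p * probs.max()
    probs = np.where(keep, probs, 0.0)
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)

# The two low-probability tokens fall below 0.5 * p_max and are filtered out.
token = min_p_sample_sketch(np.array([4.0, 3.9, -2.0, -5.0]), min_p=0.5)
```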
### `dtype` {#max.nn.legacy.sampling.MinPSampler.dtype} > dtype: [DType](../../dtype.md#max.dtype.DType) ### `min_p` {#max.nn.legacy.sampling.MinPSampler.min_p} > min\_p: [float](https://docs.python.org/3/library/functions.html#float) ### `shape` {#max.nn.legacy.sampling.MinPSampler.shape} > shape: [Shape](../../graph/shape.md#max.graph.shape.Shape) ### `temperature` {#max.nn.legacy.sampling.MinPSampler.temperature} > temperature: [float](https://docs.python.org/3/library/functions.html#float) ## `RejectionSampler` {#max.nn.legacy.sampling.RejectionSampler} > class max.nn.legacy.sampling.RejectionSampler(device, top\_k=1, top\_p=1, temperature=1.0, seed=0, eps=1e-05) A simple rejection sampler.
**Parameters:**
* device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) * top\_k ([int](https://docs.python.org/3/library/functions.html#int)) * top\_p ([float](https://docs.python.org/3/library/functions.html#float)) * temperature ([float](https://docs.python.org/3/library/functions.html#float)) * seed ([int](https://docs.python.org/3/library/functions.html#int)) * eps ([float](https://docs.python.org/3/library/functions.html#float))
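The class docstring is terse; as a rough illustration (an assumption about the technique, not MAX’s exact algorithm), the standard rejection rule used in speculative decoding accepts a draft token with probability `min(1, p_target / p_draft)` and otherwise resamples from the normalized residual distribution:

```python
import numpy as np

def rejection_step_sketch(draft_token, p_draft, p_target, rng=None):
    """One illustrative speculative-decoding rejection step."""
    rng = rng or np.random.default_rng(0)
    # Accept the draft token with probability min(1, p_target / p_draft).
    accept_prob = min(1.0, p_target[draft_token] / max(p_draft[draft_token], 1e-10))
    if rng.random() < accept_prob:
        return draft_token, True
    # On rejection, resample from the residual distribution max(p_t - p_d, 0).
    residual = np.maximum(p_target - p_draft, 0.0)
    residual /= residual.sum()
    return rng.choice(len(residual), p=residual), False

p_draft = np.array([0.7, 0.2, 0.1])
p_target = np.array([0.5, 0.4, 0.1])
token, accepted = rejection_step_sketch(0, p_draft, p_target)
```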
## `RejectionSamplerWithResiduals` {#max.nn.legacy.sampling.RejectionSamplerWithResiduals} > class max.nn.legacy.sampling.RejectionSamplerWithResiduals(device, top\_k=1, temperature=1.0, eps=1e-10, seed=0, debug=False) A simple rejection sampler.
**Parameters:**
* device ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) * top\_k ([int](https://docs.python.org/3/library/functions.html#int)) * temperature ([float](https://docs.python.org/3/library/functions.html#float)) * eps ([float](https://docs.python.org/3/library/functions.html#float)) * seed ([int](https://docs.python.org/3/library/functions.html#int)) * debug ([bool](https://docs.python.org/3/library/functions.html#bool))
--- ## sequential A general sequential layer; each layer is executed with the outputs of the previous layer. ## `Sequential` {#max.nn.legacy.sequential.Sequential} > class max.nn.legacy.sequential.Sequential(layers) A sequential stack of layers where each layer is called with the outputs of the previous layer.
**Parameters:**
layers ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Layer](layer.md#max.nn.legacy.layer.Layer)])
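The behavior described above can be sketched in plain Python (a hypothetical stand-in, not the MAX class): the input is folded through each layer in order.

```python
from typing import Any, Callable, Sequence

def sequential_sketch(layers: Sequence[Callable[[Any], Any]], x: Any) -> Any:
    """Illustrative Sequential: pipe the input through each layer in order."""
    for layer in layers:
        x = layer(x)
    return x

out = sequential_sketch([lambda x: x + 1, lambda x: x * 2], 3)
print(out)  # 8
```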
--- ## distributed_transformer ## `DistributedTransformer` {#max.nn.legacy.transformer.distributed_transformer.DistributedTransformer} > class max.nn.legacy.transformer.distributed\_transformer.DistributedTransformer(dim, n\_heads, layers, norm, output, embedding, kv\_params, devices, rope, return\_logits=ReturnLogits.LAST\_TOKEN, use\_subgraphs=False, subgraph\_layer\_groups=None, logits\_scaling=1.0) Transformer model consisting of DistributedTransformerBlock layers.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * layers ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DistributedTransformerBlock](#max.nn.legacy.transformer.distributed_transformer.DistributedTransformerBlock)]) * norm ([ShardableCallable](#max.nn.legacy.transformer.distributed_transformer.ShardableCallable)) * output ([ColumnParallelLinear](../linear.md#max.nn.legacy.linear.ColumnParallelLinear)) * embedding ([VocabParallelEmbedding](../embedding.md#max.nn.legacy.embedding.VocabParallelEmbedding)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)]) * rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * return\_logits ([ReturnLogits](transformer.md#max.nn.legacy.transformer.transformer.ReturnLogits)) * use\_subgraphs ([bool](https://docs.python.org/3/library/functions.html#bool)) * subgraph\_layer\_groups ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]] | None) * logits\_scaling ([float](https://docs.python.org/3/library/functions.html#float))
## `DistributedTransformerBlock` {#max.nn.legacy.transformer.distributed_transformer.DistributedTransformerBlock} > class max.nn.legacy.transformer.distributed\_transformer.DistributedTransformerBlock(attention, mlp, attention\_norm, mlp\_norm, devices, distributed\_gemm\_config=None) Stack of Attention, FeedForward, and RMSNorm layers.
**Parameters:**
* attention ([Module](../layer.md#max.nn.legacy.layer.Module)) * mlp ([ShardableCallable](#max.nn.legacy.transformer.distributed_transformer.ShardableCallable)) * attention\_norm ([ShardableCallable](#max.nn.legacy.transformer.distributed_transformer.ShardableCallable)) * mlp\_norm ([ShardableCallable](#max.nn.legacy.transformer.distributed_transformer.ShardableCallable)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)]) * distributed\_gemm\_config ([DistributedGemmConfig](../linear.md#max.nn.legacy.linear.DistributedGemmConfig) | None)
## `ShardableCallable` {#max.nn.legacy.transformer.distributed_transformer.ShardableCallable} > class max.nn.legacy.transformer.distributed\_transformer.ShardableCallable(\*args, \*\*kwargs) ## `distribute_value()` {#max.nn.legacy.transformer.distributed_transformer.distribute_value} > max.nn.legacy.transformer.distributed\_transformer.distribute\_value(v, devices)
**Parameters:**
* v ([TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]
## `forward_sharded_layers()` {#max.nn.legacy.transformer.distributed_transformer.forward_sharded_layers} > max.nn.legacy.transformer.distributed\_transformer.forward\_sharded\_layers(layers, xs) Forward pass through sharded layers.
**Parameters:**
* layers ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](../../../graph/ops.md#max.graph.ops.Callable)\[\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)], [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]]) – Sequence of callable layers that return TensorValue * xs ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]) – Input tensors, one per layer
**Returns:**
List of output tensors from each layer
**Raises:**
[AssertionError](https://docs.python.org/3/library/exceptions.html#AssertionError) – If the number of layers and input tensors don’t match
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]
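Per the signature above, this is a one-to-one mapping of per-device layers over per-device input shards. A minimal sketch of that contract (hypothetical helper, not the MAX function):

```python
from typing import Any, Callable, Sequence

def forward_sharded_layers_sketch(
    layers: Sequence[Callable[[Any], Any]], xs: Sequence[Any]
) -> list:
    """Illustrative: apply each per-device layer to its own input shard."""
    # Documented behavior: raises AssertionError when counts don't match.
    assert len(layers) == len(xs), "number of layers and inputs must match"
    return [layer(x) for layer, x in zip(layers, xs)]

outs = forward_sharded_layers_sketch([lambda x: x * 2, lambda x: x + 1], [10, 20])
print(outs)  # [20, 21]
```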
## `take()` {#max.nn.legacy.transformer.distributed_transformer.take} > max.nn.legacy.transformer.distributed\_transformer.take(it, n) Return the next n items from it as a list.
**Parameters:**
* it ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](../../../graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * n ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Value](../../../graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
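The documented behavior of `take()` is equivalent to wrapping `itertools.islice` in a list; a plain-Python sketch:

```python
from itertools import islice
from typing import Iterable, TypeVar

T = TypeVar("T")

def take(it: Iterable[T], n: int) -> list[T]:
    """Return the next n items from the iterable as a list (sketch of the
    documented helper)."""
    return list(islice(it, n))

stream = iter(range(10))
print(take(stream, 3))  # [0, 1, 2]
print(take(stream, 3))  # [3, 4, 5] -- successive calls consume the iterator
```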
--- ## transformer (Transformer) Legacy transformer building blocks for graph-based neural networks. ## Modules * [`distributed_transformer`](/max/api/python/nn/legacy/transformer/distributed_transformer): Distributed transformer implementation. * [`transformer`](/max/api/python/nn/legacy/transformer/transformer): Transformer block implementation. --- ## transformer (3) ## `ReturnHiddenStates` {#max.nn.legacy.transformer.transformer.ReturnHiddenStates} > class max.nn.legacy.transformer.transformer.ReturnHiddenStates(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `ALL` {#max.nn.legacy.transformer.transformer.ReturnHiddenStates.ALL} > ALL = 'all' ### `ALL_NORMALIZED` {#max.nn.legacy.transformer.transformer.ReturnHiddenStates.ALL_NORMALIZED} > ALL\_NORMALIZED = 'all\_normalized' ### `LAST` {#max.nn.legacy.transformer.transformer.ReturnHiddenStates.LAST} > LAST = 'last' ### `LAST_NORMALIZED` {#max.nn.legacy.transformer.transformer.ReturnHiddenStates.LAST_NORMALIZED} > LAST\_NORMALIZED = 'last\_normalized' ### `NONE` {#max.nn.legacy.transformer.transformer.ReturnHiddenStates.NONE} > NONE = 'none' ## `ReturnLogits` {#max.nn.legacy.transformer.transformer.ReturnLogits} > class max.nn.legacy.transformer.transformer.ReturnLogits(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `ALL` {#max.nn.legacy.transformer.transformer.ReturnLogits.ALL} > ALL = 'all' ### `LAST_TOKEN` {#max.nn.legacy.transformer.transformer.ReturnLogits.LAST_TOKEN} > LAST\_TOKEN = 'last\_token' ### `VARIABLE` {#max.nn.legacy.transformer.transformer.ReturnLogits.VARIABLE} > VARIABLE = 'variable' ## `Transformer` {#max.nn.legacy.transformer.transformer.Transformer} > class max.nn.legacy.transformer.transformer.Transformer(dim, n\_heads, layers, norm, output, embedding, kv\_params, rope, return\_logits=ReturnLogits.LAST\_TOKEN, return\_hidden\_states=ReturnHiddenStates.NONE, embedding\_multiplier=1.0, 
logits\_scaling=1.0) Transformer model consisting of TransformerBlock layers.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) * layers ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Block]) * norm ([Layer](../layer.md#max.nn.legacy.layer.Layer)) * output ([Linear](../../Linear.md#max.nn.Linear)) * embedding ([Embedding](../../Embedding.md#max.nn.Embedding)) * kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) * rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) * return\_logits ([ReturnLogits](#max.nn.legacy.transformer.transformer.ReturnLogits)) * return\_hidden\_states ([ReturnHiddenStates](#max.nn.legacy.transformer.transformer.ReturnHiddenStates)) * embedding\_multiplier ([float](https://docs.python.org/3/library/functions.html#float)) * logits\_scaling ([float](https://docs.python.org/3/library/functions.html#float))
## `TransformerBlock` {#max.nn.legacy.transformer.transformer.TransformerBlock} > class max.nn.legacy.transformer.transformer.TransformerBlock(attention, mlp, attention\_norm, mlp\_norm, residual\_multiplier=1.0) Stack of Attention, FeedForward, and RMSNorm layers.
**Parameters:**
* attention ([Module](../layer.md#max.nn.legacy.layer.Module)) * mlp ([Layer](../layer.md#max.nn.legacy.layer.Layer)) * attention\_norm ([Layer](../layer.md#max.nn.legacy.layer.Layer)) * mlp\_norm ([Layer](../layer.md#max.nn.legacy.layer.Layer)) * residual\_multiplier ([float](https://docs.python.org/3/library/functions.html#float))
--- ## module Base classes and decorators for building neural network modules in MAX. ## `Module` {#max.nn.module.Module} > class max.nn.module.Module The core unit of composition for modeling in MAX. Informally, a `Module` is a container class. It can contain other `Module` instances, tensors (the `Module`’s “local parameters”) or other arbitrary Python data. A `Module` also has a `forward()` method which defines how the `Module` computes its output. In the simplest case this is a function from one tensor to another tensor. Users call the module using `__call__()` which internally invokes `forward()`. Formally modules form a tree, and subtrees of modules can be manipulated directly. A `Module` may also be thought of as a closure, where the parameters form the data of the closure and `forward()` is the application of the closure. Users who do not use a Python type checker, or use lax settings for their type checker, may inherit from `Module` without parameters. Users who use a type checker with stricter settings (including MAX internal code) should specify explicit types for full type checking: ```default class Linear(Module[[Tensor], Tensor]): def forward(self, x: Tensor) -> Tensor: return x @ self.weight.T + self.bias ``` **Terminology:** * A “child” of a `Module` is a sub-`Module` stored directly on that `Module`. * A “descendant” of a `Module` is one of its children, or one of their descendants. * A “parameter” is a tensor storing data on the `Module` or one of its descendants. * The “qualified path” of a descendant is a period-separated string of the names of the child module attributes which lead to that descendant module, for instance `child.sub.last`. * The “qualified path” of a parameter is the qualified path of the descendant directly holding that parameter, followed by a final path component for the attribute name of the tensor. For instance `weight` for a local parameter, or `child.sub.last.weight` for a descendant’s parameter. 
```python from max.tensor import Tensor from max.nn import Module, module_dataclass @module_dataclass class Linear(Module): weight: Tensor bias: Tensor | int = 0 def forward(self, x: Tensor) -> Tensor: return x @ self.weight.T + self.bias linear = Linear(Tensor.zeros([5, 4])) print(linear) print(linear(Tensor.constant([1, 2, 3, 4]))) ``` ### `apply_to_local_parameters()` {#max.nn.module.Module.apply_to_local_parameters} > apply\_to\_local\_parameters(f) Applies a transformation to each local parameter tensor on the `Module`. The transformation is applied in-place, updating the module’s values. It will not be applied to descendants’ parameters. For example: ```python from max.driver import Accelerator from max.nn import Linear model = Linear(2, 3) model.apply_to_local_parameters(lambda _, t: t.to(Accelerator())) ```
**Parameters:**
f ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [Tensor](../tensor.md#max.tensor.Tensor)]) – The transformation to apply to each local parameter. The transformation takes two arguments, a name and a tensor: * The name is the attribute name of the parameter on the module. * The tensor is the current value of that parameter. The return value of this function is the new value that will replace the value at that name.
**Return type:**
None
### `apply_to_parameters()` {#max.nn.module.Module.apply_to_parameters} > apply\_to\_parameters(f) Applies a transformation to all parameters in the module hierarchy. This method traverses the module tree and applies the transformation function to each parameter in-place, updating both the current module’s parameters and all nested sub-module parameters. The transformation receives the parameter’s qualified name (dot-separated path) and current tensor value. Transfer all parameters to accelerator: ```python from max.driver import Accelerator from max.tensor import Tensor from max.nn import Module, module_dataclass, Linear @module_dataclass class MLP(Module): fc1: Linear fc2: Linear def forward(self, x: Tensor) -> Tensor: return self.fc2(self.fc1(x)) model = MLP( fc1=Linear(10, 20), fc2=Linear(20, 5) ) model.apply_to_parameters(lambda name, t: t.to(Accelerator())) ```
**Parameters:**
f ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [Tensor](../tensor.md#max.tensor.Tensor)]) – Transformation function taking `(name, tensor)` and returning the transformed tensor. Parameters: * `name` ([`str`](https://docs.python.org/3/library/stdtypes.html#str)): Qualified dot-separated path of the parameter (e.g., `"fc1.weight"`, `"encoder.layer2.bias"`) * `tensor` (`Tensor`): Current value of the parameter Returns the new tensor value to replace the parameter.
**Return type:**
None
### `children` {#max.nn.module.Module.children} > property children: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] Iterates over the direct child modules of the `Module`.
**Yields:**
`(name, module)` pairs, where `name` is the attribute name of the child on the module.
### `compile()` {#max.nn.module.Module.compile} > compile(\*input\_types, weights=None) Compiles the module to an optimized executable through graph tracing. This method performs symbolic tracing of the module’s `forward` method to construct a MAX `Graph`, which is then compiled and optimized for efficient execution on CPU, GPU, or other accelerators. The compilation process: 1. Creates symbolic `Tensor` instances based on provided type specifications 2. Executes `forward` with symbolic tensors to record operations 3. Constructs a `Graph` representing the computation 4. Includes all module parameters as weights in the graph 5. Compiles and optimizes the graph for target hardware 6. Returns an executable function with the same signature as `forward` The input type specifications must match the signature of `forward`. Use positional arguments for positional parameters. Basic compilation with fixed shapes: ```python from max.dtype import DType from max.tensor import Tensor, TensorType, defaults from max.nn import Module, module_dataclass @module_dataclass class Linear(Module): weight: Tensor bias: Tensor def forward(self, x: Tensor) -> Tensor: return x @ self.weight.T + self.bias linear = Linear( weight=Tensor.zeros([10, 5]), bias=Tensor.zeros([10]) ) # Compile with fixed input shape _, device = defaults() input_type = TensorType(DType.float32, [3, 5], device=device) model = linear.compile(input_type) # Execute compiled model input_data = Tensor.ones([3, 5], dtype=DType.float32) result = model(input_data) print(result) ```
**Parameters:**
* \*input\_types ([Type](../graph/type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Type specifications for each positional argument to `forward`. Must match the number and order of arguments. Each should be a `max.graph.Type` (typically `TensorType`) describing the shape and dtype. * weights ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)] | None) – Mapping of parameter names to weight data. Weights should be on CPU and will be transferred to the target device as part of model initialization. If not passed, the model’s parameters will be used as the weights.
**Returns:**
Callable\[…, Any] A compiled executable function with the same signature as `forward`. This function runs the optimized graph and returns results with the same structure as `forward` (single `Tensor` or tuple of tensors).
**Raises:**
* [TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If input types don’t match `forward` signature or if operations in `forward` cannot be traced. * [RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If graph construction fails due to incompatible operations or parameter access issues.
**Return type:**
[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[…], [Any](https://docs.python.org/3/library/typing.html#typing.Any)]
### `descendants` {#max.nn.module.Module.descendants} > property descendants: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] Iterates over the `Module`’s descendant modules.
**Yields:**
`(name, module)` pairs, where `name` is the qualified path of the descendant with respect to the module.
### `forward()` {#max.nn.module.Module.forward} > forward(\*args, \*\*kwargs) Defines the computation performed by the module. Users must override this method in their subclass to define the module’s computation.
**Parameters:**
* \*args (\~\_P) – Positional arguments for the computation. * \*\*kwargs (\~\_P) – Keyword arguments for the computation.
**Returns:**
The result of applying the module to the input.
**Raises:**
[NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If the subclass does not override this method.
**Return type:**
\_R
### `load_state()` {#max.nn.module.Module.load_state} > load\_state(lookup) Replaces each parameter in the module and its descendants. The transformation is applied in-place, updating the module’s values and those of its descendants. For example, if we have a model with two parameters, `weight` and `bias`, we can load the state of the model from a dictionary with the following code: ```python from max.tensor import Tensor from max.nn import Linear model = Linear(2, 3) weights = { "weight": Tensor.zeros([3, 2]), "bias": Tensor.zeros([3]), } model.load_state(lambda name, _: weights[name]) ``` The lookup is defined as a function rather than a dictionary, allowing for functional remapping of names during this process to account for differences in common weight naming and storage conventions. For instance, certain representations may not store weights as transposed, or may need to be quantized, or split out from a shared qkv block, or may just have slightly different names or paths. This can also be used for instance to provide a default value for initializing LoRA weights.
**Parameters:**
lookup ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [DLPackArray](../driver.md#max.driver.DLPackArray)]) – The lookup function for each parameter: * The first argument is the qualified name of the parameter with respect to the module on which `load_state()` was called. * The second argument is the existing tensor value. * The return value of this function is the new value that will replace the value at that name in the module tree.
### `load_state_dict()` {#max.nn.module.Module.load_state_dict} > load\_state\_dict(state, strict=True) Loads parameter values from a dictionary into the module hierarchy. This method updates all module parameters in-place by loading values from the provided state dictionary. The dictionary maps qualified parameter names (dot-separated paths like `"fc1.weight"`) to tensor values. The `strict` mode (default) ensures all weights in the dictionary are actually used, catching errors from mismatched architectures or incorrect weight names. For example, the following loads weights from a dictionary into a model: ```python from max.tensor import Tensor from max.nn import Module, module_dataclass @module_dataclass class Linear(Module): weight: Tensor bias: Tensor def forward(self, x: Tensor) -> Tensor: return x @ self.weight.T + self.bias model = Linear( weight=Tensor.zeros([10, 5]), bias=Tensor.zeros([10]) ) # Load weights from dictionary weights = { "weight": Tensor.zeros([10, 5]), "bias": Tensor.zeros([10]), } model.load_state_dict(weights) ```
**Parameters:**
* state ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]) – Dictionary mapping qualified parameter names to tensor values. Keys should match the names from [`Module.parameters`](#max.nn.module.Module.parameters) property. Values should be DLPack-compatible arrays or `Tensor` objects. Their shapes and dtypes must match the existing parameters with the corresponding name, but they may be on a different device. In the case that the new value has a different device, it will be copied to the same device as the existing value, and the parameter will be set to the new copy. * strict ([bool](https://docs.python.org/3/library/functions.html#bool)) – If [`True`](https://docs.python.org/3/library/constants.html#True) (default), verify that all keys in `state` are used (i.e., match actual parameters). If [`False`](https://docs.python.org/3/library/constants.html#False), silently ignore extra keys that don’t match any parameters.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `strict=True` and some weights in `state` don’t match any model parameters (indicates architecture mismatch or incorrect weight names). * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a loaded tensor has a different dtype or shape than the existing parameter. * [KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If a required parameter name in the model is missing from `state` (regardless of `strict` setting).
**Return type:**
None
### `local_parameters` {#max.nn.module.Module.local_parameters} > property local\_parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)]] Iterates over the local parameters of the `Module`.
**Yields:**
`(name, tensor)` pairs, where `name` is the attribute name of the tensor on the module.
### `map_parameters()` {#max.nn.module.Module.map_parameters} > map\_parameters(f) Creates a new `Module` with its parameters transformed by the function. The transformation is functional rather than in-place. The module is deep-copied; its descendants are also replaced via the same transform without affecting the original module. For example: ```python from max.driver import Accelerator from max.nn import Linear model = Linear(2, 3) model_on_gpu = model.map_parameters(lambda _, t: t.to(Accelerator())) ```
**Parameters:**
f ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [Tensor](../tensor.md#max.tensor.Tensor)]) – The transformation to apply to each parameter. The transformation takes two arguments, a name and a tensor: * The name is the qualified name of the parameter with respect to the module on which `map_parameters()` was called. * The tensor is the current value of that parameter. The return value of this function is the new value that will replace the value at that name in the module tree.
**Returns:**
A new module tree of the same type resulting from mapping the transformation over all model parameters.
**Return type:**
[Self](https://docs.python.org/3/library/typing.html#typing.Self)
### `parameters` {#max.nn.module.Module.parameters} > property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)]] Iterates over all parameters in this module and its sub-modules. This property performs a depth-first traversal of the module hierarchy, yielding each parameter tensor with its qualified name. The qualified name uses dot-notation to represent the module tree structure (e.g., `"encoder.layer1.weight"`). Parameters are yielded in depth-first order: first the current module’s direct parameters, then recursively each sub-module’s parameters. Counting total parameters: ```python from max.tensor import Tensor from max.nn import Module, module_dataclass from max.nn import Linear @module_dataclass class MLP(Module): fc1: Linear fc2: Linear def forward(self, x: Tensor) -> Tensor: return self.fc2(self.fc1(x)) model = MLP( fc1=Linear(10, 20), fc2=Linear(20, 5) ) # Count parameters total_params = sum( param.num_elements() for name, param in model.parameters ) print(f"Total parameters: {total_params}") ```
**Yields:**
`(name, parameter)` tuples where `name` is the dot-separated qualified path of the parameter and `parameter` is the `Tensor`.
### `to()` {#max.nn.module.Module.to} > to(device) Updates the module’s parameters, transferring them to the specified device. ```python from max.driver import CPU from max.nn import Linear model = Linear(2, 3) model.to(CPU()) ```
**Parameters:**
device ([Device](../driver.md#max.driver.Device)) – The device to which all model parameters will be transferred.
**Returns:**
A reference to the model. The transfer is applied in place; internal parameters are moved to the specified device.
**Return type:**
[Self](https://docs.python.org/3/library/typing.html#typing.Self)
## `module_dataclass()` {#max.nn.module.module_dataclass} > max.nn.module.module\_dataclass(cls=None, /, \*, repr=False, \*\*kwargs) Converts a class into a MAX module with automatic parameter tracking. This decorator enables a regular Python class to function as a [`Module`](#max.nn.module.Module), providing automatic discovery and registration of parameters (Tensor fields) and nested modules. The decorated class gains all capabilities of [`Module`](#max.nn.module.Module), including parameter iteration, graph compilation via [`Module.compile()`](#max.nn.module.Module.compile), and hierarchical module composition. The decorator applies Python’s `@dataclass` decorator internally while preserving [`Module`](#max.nn.module.Module)’s specialized `__repr__` method for a better debugging experience when printing module structures.
**Parameters:**
* cls ([type](https://docs.python.org/3/library/functions.html#type)\[[Module](#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) – The class to decorate. Must define a `forward` method. When [`None`](https://docs.python.org/3/library/constants.html#None), returns a decorator function (supports using `@module_dataclass` with or without parentheses). * repr ([bool](https://docs.python.org/3/library/functions.html#bool)) – If [`True`](https://docs.python.org/3/library/constants.html#True), use dataclass’s default `__repr__` instead of [`Module`](#max.nn.module.Module)’s rich representation. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False). * \*\*kwargs – Additional keyword arguments forwarded to Python’s `@dataclass` decorator (e.g., `frozen`, `eq`).
**Returns:**
The decorated class as a [`Module`](#max.nn.module.Module) subclass with automatic parameter tracking and graph compilation capabilities. When `cls` is [`None`](https://docs.python.org/3/library/constants.html#None), returns a decorator function.
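The mechanics of the decorator can be illustrated with a minimal pure-Python analogue: apply `@dataclass`, then discover parameter fields by inspecting the instance. This is an instructional sketch only; `module_dataclass_sketch` and `FakeTensor` are hypothetical names, and the real decorator also handles nested modules and graph compilation.

```python
from dataclasses import dataclass, fields

class FakeTensor:
    """Stand-in for max.tensor.Tensor in this illustration."""
    def __init__(self, n):
        self.n = n
    def num_elements(self):
        return self.n

def module_dataclass_sketch(cls):
    # Apply @dataclass, then expose a `parameters` property that yields
    # (name, tensor) pairs for every tensor-valued field.
    cls = dataclass(cls)
    def parameters(self):
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, FakeTensor):
                yield f.name, value
    cls.parameters = property(parameters)
    return cls

@module_dataclass_sketch
class Toy:
    weight: FakeTensor
    bias: FakeTensor
    def forward(self, x):
        return x

t = Toy(weight=FakeTensor(6), bias=FakeTensor(3))
total = sum(p.num_elements() for _, p in t.parameters)
print(total)  # 9
```

The real decorator builds on the same idea, but registers `Tensor` fields and nested `Module` fields so that qualified names like `"fc1.weight"` compose across the module tree.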
--- ## GemmaRMSNorm ## `GemmaRMSNorm` {#max.nn.norm.rms_norm.GemmaRMSNorm} > class max.nn.norm.rms\_norm.GemmaRMSNorm(dim, eps=1e-06) Computes the Root Mean Square normalization on inputs. Differences from traditional RMSNorm: * x \* (1 + w) instead of x \* w. * (x \* w).to(orig\_dtype) instead of x.to(orig\_dtype) \* w.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * eps ([float](https://docs.python.org/3/library/functions.html#float))
### `forward()` {#max.nn.norm.rms_norm.GemmaRMSNorm.forward} > forward(x) Applies Gemma-style RMS normalization to the input tensor.
**Parameters:**
x ([Tensor](../../tensor.md#max.tensor.Tensor)) – The input tensor to normalize.
**Returns:**
The normalized tensor, with the same shape as the input.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
--- ## norm (Norm) Normalization layers for neural networks. ## Modules * [`rms_norm`](/max/api/python/nn/norm/rms_norm): Root Mean Square normalization layer. ## Classes * [`GemmaRMSNorm`](/max/api/python/nn/norm/GemmaRMSNorm): RMS normalization optimized for Gemma models. :::note Note For legacy normalization layers (LayerNorm, GroupNorm), see [legacy/norm](/max/api/python/nn/legacy/norm). ::: --- ## rms_norm Root mean square layer normalization. ## `RMSNorm` {#max.nn.norm.rms_norm.RMSNorm} > class max.nn.norm.rms\_norm.RMSNorm(dim, eps=1e-06) Computes the Root Mean Square normalization on inputs.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) * eps ([float](https://docs.python.org/3/library/functions.html#float))
### `dim` {#max.nn.norm.rms_norm.RMSNorm.dim} > property dim: [Dim](../../graph/dim.md#max.graph.dim.Dim) ### `eps` {#max.nn.norm.rms_norm.RMSNorm.eps} > eps: [float](https://docs.python.org/3/library/functions.html#float) ### `forward()` {#max.nn.norm.rms_norm.RMSNorm.forward} > forward(x) Applies RMS normalization to the input tensor.
**Parameters:**
x ([Tensor](../../tensor.md#max.tensor.Tensor)) – The input tensor to normalize.
**Returns:**
The normalized tensor, with the same shape as the input.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
### `weight` {#max.nn.norm.rms_norm.RMSNorm.weight} > weight: [Tensor](../../tensor.md#max.tensor.Tensor) ## `rms_norm()` {#max.nn.norm.rms_norm.rms_norm} > max.nn.norm.rms\_norm.rms\_norm(x, weight, eps, weight\_offset=0.0, multiply\_before\_cast=False) Applies Root Mean Square layer normalization to an input tensor.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – The input tensor * weight ([Tensor](../../tensor.md#max.tensor.Tensor)) – The weights for the normalization * eps ([float](https://docs.python.org/3/library/functions.html#float)) – A value added to the denominator of the normalization for numerical stability * weight\_offset ([float](https://docs.python.org/3/library/functions.html#float)) – A value added to the weights before normalization. Typically 1 for Gemma-like normalization and 0 otherwise. * multiply\_before\_cast ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to multiply before or after casting to the output dtype. Typically True for Gemma-like normalization and False otherwise.
**Returns:**
A layer-normalized tensor with the same shape and type as x.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
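The effect of `weight_offset` and `multiply_before_cast` can be sketched numerically. The following NumPy function is illustrative only (`rms_norm_sketch` is not part of the MAX API) and assumes the common convention of adding `eps` to the mean of squares before the square root:

```python
import numpy as np

def rms_norm_sketch(x, weight, eps=1e-6, weight_offset=0.0,
                    multiply_before_cast=False, out_dtype=np.float32):
    """Illustrative RMSNorm sketch, not the MAX kernel."""
    x32 = x.astype(np.float32)
    # Root mean square over the last dimension, stabilized by eps.
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    normed = x32 / rms
    w = weight.astype(np.float32) + weight_offset
    if multiply_before_cast:                 # Gemma-style: (x * w).to(dtype)
        return (normed * w).astype(out_dtype)
    return normed.astype(out_dtype) * w.astype(out_dtype)  # x.to(dtype) * w

x = np.array([[1.0, 2.0, 3.0]])
standard = rms_norm_sketch(x, np.ones(3))
# Gemma-style call: weights stored as w behave like 1 + w.
gemma = rms_norm_sketch(x, np.zeros(3), weight_offset=1.0,
                        multiply_before_cast=True)
```

With `weight_offset=1.0`, a weight tensor stored as `w` behaves like `1 + w`, matching the Gemma-style variant described for `GemmaRMSNorm` above.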
--- ## RotaryEmbedding ## `RotaryEmbedding` {#max.nn.rope.RotaryEmbedding} > class max.nn.rope.RotaryEmbedding(weight: [max.tensor.Tensor](../../tensor.md#max.tensor.Tensor))
**Parameters:**
weight ([Tensor](../../tensor.md#max.tensor.Tensor))
### `dim` {#max.nn.rope.RotaryEmbedding.dim} > property dim: [int](https://docs.python.org/3/library/functions.html#int) ### `forward()` {#max.nn.rope.RotaryEmbedding.forward} > forward(x, start\_pos=0) Applies rotary positional embeddings (RoPE) to x. seq\_len is inferred from the shape of x.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – Activation tensor with shape (batch, seq\_len, n\_kv\_heads, head\_dim). x is interpreted as a complex-valued tensor where the head\_dim dimension is alternating pairs of (real, imaginary) parts. * start\_pos ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The starting position of the input tensor. Defaults to 0.
**Returns:**
Input activation tensor with rotary positional embeddings applied and the same shape as x.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
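The interleaved (real, imaginary) interpretation can be sketched in NumPy. This is an illustrative model of the rotation, not the MAX kernel; the `freqs` array below is a hypothetical inverse-frequency table standing in for the precomputed weight the class is constructed with:

```python
import numpy as np

def rope_interleaved(x, freqs, start_pos=0):
    """Rotate alternating (real, imag) pairs along head_dim by
    position-dependent angles. Illustrative sketch only."""
    b, s, h, d = x.shape
    xc = x.reshape(b, s, h, d // 2, 2)
    xc = xc[..., 0] + 1j * xc[..., 1]            # view pairs as complex
    pos = np.arange(start_pos, start_pos + s)    # seq_len inferred from x
    angles = np.exp(1j * np.outer(pos, freqs))   # (seq_len, d // 2)
    xc = xc * angles[None, :, None, :]           # rotate each pair
    out = np.stack([xc.real, xc.imag], axis=-1)  # back to interleaved pairs
    return out.reshape(b, s, h, d)

freqs = 1.0 / (10000.0 ** (np.arange(4) / 4.0))
x = np.random.default_rng(0).normal(size=(1, 4, 2, 8))
y = rope_interleaved(x, freqs)
```

The rotation is norm-preserving, and with `start_pos=0` the first position is left unchanged because its rotation angle is zero.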
### `max_sequence_length` {#max.nn.rope.RotaryEmbedding.max_sequence_length} > property max\_sequence\_length: [int](https://docs.python.org/3/library/functions.html#int) ### `weight` {#max.nn.rope.RotaryEmbedding.weight} > weight: [Tensor](../../tensor.md#max.tensor.Tensor) --- ## TransposedRotaryEmbedding ## `TransposedRotaryEmbedding` {#max.nn.rope.rope.TransposedRotaryEmbedding} > class max.nn.rope.rope.TransposedRotaryEmbedding(weight)
**Parameters:**
weight ([Tensor](../../tensor.md#max.tensor.Tensor))
### `forward()` {#max.nn.rope.rope.TransposedRotaryEmbedding.forward} > forward(x, start\_pos=0) Applies rotary positional embeddings (RoPE) to x. The representation of x is transposed within the final dimension compared to traditional RotaryEmbedding. seq\_len is inferred from the shape of x.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – Activation tensor with shape (batch, seq\_len, n\_kv\_heads, head\_dim). x is interpreted as a complex-valued tensor where the first half of head\_dim are the real parts and the last half are the imaginary parts. * start\_pos ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The starting position of the input tensor. Defaults to 0.
**Returns:**
Input activation tensor with rotary positional embeddings applied and the same shape as x.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
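The transposed layout applies the same per-pair rotation, but with real parts in the first half of `head_dim` and imaginary parts in the second half. A NumPy sketch (illustrative only, not the MAX kernel; `freqs` is a hypothetical inverse-frequency table):

```python
import numpy as np

def rope_transposed(x, freqs, start_pos=0):
    """Same rotation as interleaved RoPE, but real parts occupy the first
    half of head_dim and imaginary parts the second half. Sketch only."""
    b, s, h, d = x.shape
    real, imag = x[..., : d // 2], x[..., d // 2:]
    pos = np.arange(start_pos, start_pos + s)
    angles = np.outer(pos, freqs)                  # (seq_len, d // 2)
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    out_real = real * cos - imag * sin             # complex multiply,
    out_imag = real * sin + imag * cos             # expanded into reals
    return np.concatenate([out_real, out_imag], axis=-1)

freqs = 1.0 / (10000.0 ** (np.arange(4) / 4.0))
x = np.random.default_rng(1).normal(size=(1, 4, 2, 8))
y = rope_transposed(x, freqs)
```

Only the memory layout differs from `RotaryEmbedding`: each (real, imaginary) pair is rotated by the same angle, so the two variants agree after permuting the last dimension.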
--- ## rope Rotary positional embedding (RoPE) implementations. ## Classes * [`RotaryEmbedding`](/max/api/python/nn/rope/RotaryEmbedding): Standard rotary position embedding implementation. * [`TransposedRotaryEmbedding`](/max/api/python/nn/rope/TransposedRotaryEmbedding): RoPE with transposed tensor layout. :::note Note For legacy RoPE variants (DynamicRotaryEmbedding, YarnRotaryEmbedding, etc.), see [legacy/rotary\_embedding](/max/api/python/nn/legacy/rotary_embedding). ::: --- ## sequential (Nn) :::note Note This module contains both `Sequential` and `ModuleList` containers. For the legacy graph-based sequential container, see [legacy/sequential](/max/api/python/nn/legacy/sequential). ::: A Module for a sequence of tensor transformations. ## `ModuleList` {#max.nn.sequential.ModuleList} > class max.nn.sequential.ModuleList(iterable=(), /) A `Module` subclass which is locally a list container. `ModuleList` instances will use the stringified integer index of their submodules as the name of the module for the purposes of qualified paths. For example: ```python from max.nn import Linear, Sequential model = Sequential( Linear(5, 10), Linear(10, 5), ) assert dict(model.parameters).keys() == { "0.weight", "0.bias", "1.weight", "1.bias" } ``` ### `children` {#max.nn.sequential.ModuleList.children} > property children: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](module.md#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] Iterates over the direct child modules of the `Module`.
**Yields:**
`(name, module)` pairs, where `name` is the attribute name of the child on the module.
## `Sequential` {#max.nn.sequential.Sequential} > class max.nn.sequential.Sequential(\*modules) A `Module` subclass which holds a sequence of unary modules. A unary `Module` is one whose `forward()` method has the signature: ```default def forward(self, x: Tensor) -> Tensor: ... ``` `Sequential` is itself a unary `Module`. Its `forward()` method computes the result of applying each of its child modules in sequence to its input. For example, this applies a linear transformation up to a dimension of 10, then a final linear transformation to reduce back to the input dimension of 5: ```python from max.tensor import Tensor from max.nn import Linear, Sequential model = Sequential( Linear(5, 10), Linear(10, 5), ) result = model(Tensor.ones([5])) assert result.shape == [5] ```
**Parameters:**
modules (T)
### `forward()` {#max.nn.sequential.Sequential.forward} > forward(x) Applies the contained modules in order. For example, this code creates a sequence of linear transformations which each increase the dimension of the input by 5. The input tensor must have dim 5. The intermediate applications will result in intermediate tensors of dim 10 and 15 respectively, and the final result will have dim 20: ```python from max.tensor import Tensor from max.nn import Linear, Sequential hidden_dims = [5, 10, 15, 20] model = Sequential(*( Linear(in_dim, out_dim) for in_dim, out_dim in zip(hidden_dims, hidden_dims[1:]) )) result = model(Tensor.ones([5])) assert result.shape == [20] ```
**Parameters:**
x ([Tensor](../tensor.md#max.tensor.Tensor)) – The input tensor.
**Returns:**
The result of iteratively applying each contained module in sequence.
**Return type:**
[Tensor](../tensor.md#max.tensor.Tensor)
--- ## architectures ## `register_all_models()` {#max.pipelines.architectures.register_all_models} > max.pipelines.architectures.register\_all\_models() Register all built-in model architectures with the global registry. This function imports each supported model architecture module (Llama, Mistral, Qwen, Gemma, DeepSeek, etc.) and registers their `SupportedArchitecture` definitions with `PIPELINE_REGISTRY`. This function is called automatically when `max.pipelines` is imported, so you typically don’t need to call it manually. It uses an internal flag to ensure architectures are only registered once, making repeated calls safe but unnecessary.
**Return type:**
None
--- ## config Standardized configuration for Pipeline Inference. ## `AudioGenerationConfig` {#max.pipelines.lib.config.AudioGenerationConfig} > class max.pipelines.lib.config.AudioGenerationConfig(audio\_decoder, audio\_decoder\_weights='', chunk\_size=None, buffer=0, block\_causal=False, prepend\_prompt\_speech\_tokens=PrependPromptSpeechTokens.NEVER, prepend\_prompt\_speech\_tokens\_causal=False, run\_model\_test\_mode=False, prometheus\_metrics\_mode=PrometheusMetricsMode.INSTRUMENT\_ONLY, \*, config\_file=None, section\_name=None, max\_length=None, pipeline\_role=PipelineRole.PrefillAndDecode, max\_batch\_size=None, max\_queue\_size\_tg=None, min\_batch\_size\_tg=None, ep\_size=1, ce\_delay\_ms=0.0, enable\_prioritize\_first\_decode=False, enable\_chunked\_prefill=True, enable\_in\_flight\_batching=False, max\_num\_steps=-1, max\_batch\_input\_tokens=8192, enable\_echo=False, pool\_embeddings=True, chat\_template=None, use\_experimental\_kernels='false', use\_vendor\_blas='false', pdl\_level='0', custom\_architectures=\, zmq\_endpoint\_base=\, execute\_empty\_batches=False, max\_batch\_total\_tokens=None, device\_graph\_capture=False, force=False, kvcache\_ce\_watermark=0.95, enable\_overlap\_scheduler=False, use\_legacy\_module=True, defer\_resolve=False, model=\, draft\_model=None, sampling=\, profiling=\, lora=None, speculative=None, audio\_decoder\_config=\)
**Parameters:**
* audio\_decoder ([str](https://docs.python.org/3/library/stdtypes.html#str)) * audio\_decoder\_weights ([str](https://docs.python.org/3/library/stdtypes.html#str)) * chunk\_size ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | None) * buffer ([int](https://docs.python.org/3/library/functions.html#int)) * block\_causal ([bool](https://docs.python.org/3/library/functions.html#bool)) * prepend\_prompt\_speech\_tokens ([PrependPromptSpeechTokens](#max.pipelines.lib.config.PrependPromptSpeechTokens)) * prepend\_prompt\_speech\_tokens\_causal ([bool](https://docs.python.org/3/library/functions.html#bool)) * run\_model\_test\_mode ([bool](https://docs.python.org/3/library/functions.html#bool)) * prometheus\_metrics\_mode ([PrometheusMetricsMode](#max.pipelines.lib.config.PrometheusMetricsMode)) * config\_file ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * section\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) * pipeline\_role (PipelineRole) * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int) | None) * max\_queue\_size\_tg ([int](https://docs.python.org/3/library/functions.html#int) | None) * min\_batch\_size\_tg ([int](https://docs.python.org/3/library/functions.html#int) | None) * ep\_size ([int](https://docs.python.org/3/library/functions.html#int)) * ce\_delay\_ms ([float](https://docs.python.org/3/library/functions.html#float)) * enable\_prioritize\_first\_decode ([bool](https://docs.python.org/3/library/functions.html#bool)) * enable\_chunked\_prefill ([bool](https://docs.python.org/3/library/functions.html#bool)) * enable\_in\_flight\_batching ([bool](https://docs.python.org/3/library/functions.html#bool)) * max\_num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * max\_batch\_input\_tokens 
([int](https://docs.python.org/3/library/functions.html#int)) * enable\_echo ([bool](https://docs.python.org/3/library/functions.html#bool)) * pool\_embeddings ([bool](https://docs.python.org/3/library/functions.html#bool)) * chat\_template (Path | None) * use\_experimental\_kernels ([str](https://docs.python.org/3/library/stdtypes.html#str)) * use\_vendor\_blas ([str](https://docs.python.org/3/library/stdtypes.html#str)) * pdl\_level ([str](https://docs.python.org/3/library/stdtypes.html#str)) * custom\_architectures ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) * zmq\_endpoint\_base ([str](https://docs.python.org/3/library/stdtypes.html#str)) * execute\_empty\_batches ([bool](https://docs.python.org/3/library/functions.html#bool)) * max\_batch\_total\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) * device\_graph\_capture ([bool](https://docs.python.org/3/library/functions.html#bool)) * force ([bool](https://docs.python.org/3/library/functions.html#bool)) * kvcache\_ce\_watermark ([float](https://docs.python.org/3/library/functions.html#float)) * enable\_overlap\_scheduler ([bool](https://docs.python.org/3/library/functions.html#bool)) * use\_legacy\_module ([bool](https://docs.python.org/3/library/functions.html#bool)) * defer\_resolve ([bool](https://docs.python.org/3/library/functions.html#bool)) * model ([MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig)) * draft\_model ([MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig) | None) * sampling (SamplingConfig) * profiling (ProfilingConfig) * lora ([LoRAConfig](lora_config.md#max.pipelines.lib.lora_config.LoRAConfig) | None) * speculative (SpeculativeConfig | None) * audio\_decoder\_config ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), 
[Any](https://docs.python.org/3/library/typing.html#typing.Any)])
### `audio_decoder` {#max.pipelines.lib.config.AudioGenerationConfig.audio_decoder} > audio\_decoder: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `audio_decoder_config` {#max.pipelines.lib.config.AudioGenerationConfig.audio_decoder_config} > audio\_decoder\_config: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), Any] ### `audio_decoder_weights` {#max.pipelines.lib.config.AudioGenerationConfig.audio_decoder_weights} > audio\_decoder\_weights: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `block_causal` {#max.pipelines.lib.config.AudioGenerationConfig.block_causal} > block\_causal: [bool](https://docs.python.org/3/library/functions.html#bool) ### `buffer` {#max.pipelines.lib.config.AudioGenerationConfig.buffer} > buffer: [int](https://docs.python.org/3/library/functions.html#int) ### `chunk_size` {#max.pipelines.lib.config.AudioGenerationConfig.chunk_size} > chunk\_size: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) ### `from_flags()` {#max.pipelines.lib.config.AudioGenerationConfig.from_flags} > classmethod from\_flags(audio\_flags, \*\*config\_flags) Builds an AudioGenerationConfig from audio CLI flags and config kwargs.
**Parameters:**
* audio\_flags ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str)]) * config\_flags ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
**Return type:**
[AudioGenerationConfig](#max.pipelines.lib.config.AudioGenerationConfig)
### `model_config` {#max.pipelines.lib.config.AudioGenerationConfig.model_config} > model\_config: ClassVar\[ConfigDict] = {} Configuration for the model; should be a dictionary conforming to pydantic’s `ConfigDict`. ### `model_post_init()` {#max.pipelines.lib.config.AudioGenerationConfig.model_post_init} > model\_post\_init(context, /) This function is meant to behave like a BaseModel method to initialise private attributes. It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance. * context (Any) – The context.
**Return type:**
None
### `prepend_prompt_speech_tokens` {#max.pipelines.lib.config.AudioGenerationConfig.prepend_prompt_speech_tokens} > prepend\_prompt\_speech\_tokens: [PrependPromptSpeechTokens](#max.pipelines.lib.config.PrependPromptSpeechTokens) ### `prepend_prompt_speech_tokens_causal` {#max.pipelines.lib.config.AudioGenerationConfig.prepend_prompt_speech_tokens_causal} > prepend\_prompt\_speech\_tokens\_causal: [bool](https://docs.python.org/3/library/functions.html#bool) ### `prometheus_metrics_mode` {#max.pipelines.lib.config.AudioGenerationConfig.prometheus_metrics_mode} > prometheus\_metrics\_mode: [PrometheusMetricsMode](#max.pipelines.lib.config.PrometheusMetricsMode) ## `PipelineConfig` {#max.pipelines.lib.config.PipelineConfig} > class max.pipelines.lib.config.PipelineConfig(\*, config\_file=None, section\_name=None, max\_length=None, pipeline\_role=PipelineRole.PrefillAndDecode, max\_batch\_size=None, max\_queue\_size\_tg=None, min\_batch\_size\_tg=None, ep\_size=1, ce\_delay\_ms=0.0, enable\_prioritize\_first\_decode=False, enable\_chunked\_prefill=True, enable\_in\_flight\_batching=False, max\_num\_steps=-1, max\_batch\_input\_tokens=8192, enable\_echo=False, pool\_embeddings=True, chat\_template=None, use\_experimental\_kernels='false', use\_vendor\_blas='false', pdl\_level='0', custom\_architectures=\, zmq\_endpoint\_base=\, execute\_empty\_batches=False, max\_batch\_total\_tokens=None, device\_graph\_capture=False, force=False, kvcache\_ce\_watermark=0.95, enable\_overlap\_scheduler=False, use\_legacy\_module=True, defer\_resolve=False, model=\, draft\_model=None, sampling=\, profiling=\, lora=None, speculative=None) Configuration for a pipeline. WIP - Once a PipelineConfig is fully initialized, it should be as immutable as possible (frozen=True). All underlying dataclass fields should have been initialized to their default values, be it user specified via some CLI flag, config file, environment variable, or internally set to a reasonable default.
**Parameters:**
* config\_file ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * section\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) * pipeline\_role (PipelineRole) * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int) | None) * max\_queue\_size\_tg ([int](https://docs.python.org/3/library/functions.html#int) | None) * min\_batch\_size\_tg ([int](https://docs.python.org/3/library/functions.html#int) | None) * ep\_size ([int](https://docs.python.org/3/library/functions.html#int)) * ce\_delay\_ms ([float](https://docs.python.org/3/library/functions.html#float)) * enable\_prioritize\_first\_decode ([bool](https://docs.python.org/3/library/functions.html#bool)) * enable\_chunked\_prefill ([bool](https://docs.python.org/3/library/functions.html#bool)) * enable\_in\_flight\_batching ([bool](https://docs.python.org/3/library/functions.html#bool)) * max\_num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * max\_batch\_input\_tokens ([int](https://docs.python.org/3/library/functions.html#int)) * enable\_echo ([bool](https://docs.python.org/3/library/functions.html#bool)) * pool\_embeddings ([bool](https://docs.python.org/3/library/functions.html#bool)) * chat\_template (Path | None) * use\_experimental\_kernels ([str](https://docs.python.org/3/library/stdtypes.html#str)) * use\_vendor\_blas ([str](https://docs.python.org/3/library/stdtypes.html#str)) * pdl\_level ([str](https://docs.python.org/3/library/stdtypes.html#str)) * custom\_architectures ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) * zmq\_endpoint\_base ([str](https://docs.python.org/3/library/stdtypes.html#str)) * execute\_empty\_batches ([bool](https://docs.python.org/3/library/functions.html#bool)) * max\_batch\_total\_tokens 
([int](https://docs.python.org/3/library/functions.html#int) | None) * device\_graph\_capture ([bool](https://docs.python.org/3/library/functions.html#bool)) * force ([bool](https://docs.python.org/3/library/functions.html#bool)) * kvcache\_ce\_watermark ([float](https://docs.python.org/3/library/functions.html#float)) * enable\_overlap\_scheduler ([bool](https://docs.python.org/3/library/functions.html#bool)) * use\_legacy\_module ([bool](https://docs.python.org/3/library/functions.html#bool)) * defer\_resolve ([bool](https://docs.python.org/3/library/functions.html#bool)) * model ([MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig)) * draft\_model ([MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig) | None) * sampling (SamplingConfig) * profiling (ProfilingConfig) * lora ([LoRAConfig](lora_config.md#max.pipelines.lib.lora_config.LoRAConfig) | None) * speculative (SpeculativeConfig | None)
### `ce_delay_ms` {#max.pipelines.lib.config.PipelineConfig.ce_delay_ms} > ce\_delay\_ms: [float](https://docs.python.org/3/library/functions.html#float) ### `chat_template` {#max.pipelines.lib.config.PipelineConfig.chat_template} > chat\_template: Path | [None](https://docs.python.org/3/library/constants.html#None) ### `configure_session()` {#max.pipelines.lib.config.PipelineConfig.configure_session} > configure\_session(session) Configure an InferenceSession with standard pipeline settings.
**Parameters:**
session ([InferenceSession](../engine.md#max.engine.InferenceSession))
**Return type:**
None
### `custom_architectures` {#max.pipelines.lib.config.PipelineConfig.custom_architectures} > custom\_architectures: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] ### `defer_resolve` {#max.pipelines.lib.config.PipelineConfig.defer_resolve} > defer\_resolve: [bool](https://docs.python.org/3/library/functions.html#bool) ### `device_graph_capture` {#max.pipelines.lib.config.PipelineConfig.device_graph_capture} > device\_graph\_capture: [bool](https://docs.python.org/3/library/functions.html#bool) ### `draft_model` {#max.pipelines.lib.config.PipelineConfig.draft_model} > draft\_model: [MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig) | [None](https://docs.python.org/3/library/constants.html#None) ### `enable_chunked_prefill` {#max.pipelines.lib.config.PipelineConfig.enable_chunked_prefill} > enable\_chunked\_prefill: [bool](https://docs.python.org/3/library/functions.html#bool) ### `enable_echo` {#max.pipelines.lib.config.PipelineConfig.enable_echo} > enable\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) ### `enable_in_flight_batching` {#max.pipelines.lib.config.PipelineConfig.enable_in_flight_batching} > enable\_in\_flight\_batching: [bool](https://docs.python.org/3/library/functions.html#bool) ### `enable_overlap_scheduler` {#max.pipelines.lib.config.PipelineConfig.enable_overlap_scheduler} > enable\_overlap\_scheduler: [bool](https://docs.python.org/3/library/functions.html#bool) ### `enable_prioritize_first_decode` {#max.pipelines.lib.config.PipelineConfig.enable_prioritize_first_decode} > enable\_prioritize\_first\_decode: [bool](https://docs.python.org/3/library/functions.html#bool) ### `ep_size` {#max.pipelines.lib.config.PipelineConfig.ep_size} > ep\_size: [int](https://docs.python.org/3/library/functions.html#int) ### `execute_empty_batches` {#max.pipelines.lib.config.PipelineConfig.execute_empty_batches} > execute\_empty\_batches: 
[bool](https://docs.python.org/3/library/functions.html#bool) ### `force` {#max.pipelines.lib.config.PipelineConfig.force} > force: [bool](https://docs.python.org/3/library/functions.html#bool) ### `graph_quantization_encoding` {#max.pipelines.lib.config.PipelineConfig.graph_quantization_encoding} > property graph\_quantization\_encoding: [QuantizationEncoding](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) Converts the CLI encoding to a MAX graph quantization encoding.
**Returns:**
The graph quantization encoding corresponding to the CLI encoding.
### `kvcache_ce_watermark` {#max.pipelines.lib.config.PipelineConfig.kvcache_ce_watermark} > kvcache\_ce\_watermark: [float](https://docs.python.org/3/library/functions.html#float) ### `log_basic_config()` {#max.pipelines.lib.config.PipelineConfig.log_basic_config} > log\_basic\_config() Logs minimal pipeline configuration information. Logs basic PipelineConfig options including model name, pipeline task, weight path, max\_batch\_size, max\_seq\_len, and reserved memory.
**Return type:**
None
### `log_pipeline_info()` {#max.pipelines.lib.config.PipelineConfig.log_pipeline_info} > log\_pipeline\_info() Logs comprehensive pipeline and KVCache configuration information. Retrieves all necessary information from self and the PIPELINE\_REGISTRY. Raises an error if architecture is not found (which should not happen after config resolution).
**Return type:**
None
### `lora` {#max.pipelines.lib.config.PipelineConfig.lora} > lora: [LoRAConfig](lora_config.md#max.pipelines.lib.lora_config.LoRAConfig) | [None](https://docs.python.org/3/library/constants.html#None) ### `max_batch_input_tokens` {#max.pipelines.lib.config.PipelineConfig.max_batch_input_tokens} > max\_batch\_input\_tokens: [int](https://docs.python.org/3/library/functions.html#int) ### `max_batch_size` {#max.pipelines.lib.config.PipelineConfig.max_batch_size} > max\_batch\_size: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `max_batch_total_tokens` {#max.pipelines.lib.config.PipelineConfig.max_batch_total_tokens} > max\_batch\_total\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `max_length` {#max.pipelines.lib.config.PipelineConfig.max_length} > max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `max_num_steps` {#max.pipelines.lib.config.PipelineConfig.max_num_steps} > max\_num\_steps: [int](https://docs.python.org/3/library/functions.html#int) ### `max_queue_size_tg` {#max.pipelines.lib.config.PipelineConfig.max_queue_size_tg} > max\_queue\_size\_tg: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `min_batch_size_tg` {#max.pipelines.lib.config.PipelineConfig.min_batch_size_tg} > min\_batch\_size\_tg: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `model` {#max.pipelines.lib.config.PipelineConfig.model} > model: [MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig) ### `model_config` {#max.pipelines.lib.config.PipelineConfig.model_config} > model\_config: ClassVar\[ConfigDict] = {} Configuration for the model, should be a 
dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. ### `model_post_init()` {#max.pipelines.lib.config.PipelineConfig.model_post_init} > model\_post\_init(context, /) This function is meant to behave like a BaseModel method to initialise private attributes. It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance. * context (Any) – The context.
**Return type:**
None
### `pdl_level` {#max.pipelines.lib.config.PipelineConfig.pdl_level} > pdl\_level: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `pipeline_role` {#max.pipelines.lib.config.PipelineConfig.pipeline_role} > pipeline\_role: PipelineRole ### `pool_embeddings` {#max.pipelines.lib.config.PipelineConfig.pool_embeddings} > pool\_embeddings: [bool](https://docs.python.org/3/library/functions.html#bool) ### `profiling` {#max.pipelines.lib.config.PipelineConfig.profiling} > profiling: ProfilingConfig ### `resolve()` {#max.pipelines.lib.config.PipelineConfig.resolve} > resolve() Validates and resolves the config. Called after the config is initialized to ensure all config fields are in a valid state.
**Return type:**
None
### `retrieve_chat_template()` {#max.pipelines.lib.config.PipelineConfig.retrieve_chat_template} > retrieve\_chat\_template() Returns the chat template string, or None if not set.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str) | None
### `sampling` {#max.pipelines.lib.config.PipelineConfig.sampling} > sampling: SamplingConfig ### `speculative` {#max.pipelines.lib.config.PipelineConfig.speculative} > speculative: SpeculativeConfig | [None](https://docs.python.org/3/library/constants.html#None) ### `use_experimental_kernels` {#max.pipelines.lib.config.PipelineConfig.use_experimental_kernels} > use\_experimental\_kernels: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `use_legacy_module` {#max.pipelines.lib.config.PipelineConfig.use_legacy_module} > use\_legacy\_module: [bool](https://docs.python.org/3/library/functions.html#bool) ### `use_vendor_blas` {#max.pipelines.lib.config.PipelineConfig.use_vendor_blas} > use\_vendor\_blas: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `zmq_endpoint_base` {#max.pipelines.lib.config.PipelineConfig.zmq_endpoint_base} > zmq\_endpoint\_base: [str](https://docs.python.org/3/library/stdtypes.html#str) ## `PrependPromptSpeechTokens` {#max.pipelines.lib.config.PrependPromptSpeechTokens} > class max.pipelines.lib.config.PrependPromptSpeechTokens(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `NEVER` {#max.pipelines.lib.config.PrependPromptSpeechTokens.NEVER} > NEVER = 'never' Never prepend the prompt speech tokens sent to the audio decoder. ### `ONCE` {#max.pipelines.lib.config.PrependPromptSpeechTokens.ONCE} > ONCE = 'once' Prepend the prompt speech tokens to the first block of the audio decoder. ### `ROLLING` {#max.pipelines.lib.config.PrependPromptSpeechTokens.ROLLING} > ROLLING = 'rolling' Prepend the prompt speech tokens to the first block of the audio decoder, and to later blocks to reach the requested buffer size.
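The three modes above differ only in which prompt speech tokens are prepended to each audio-decoder block. As a rough illustration of that contract (the `prefix_for_block` helper is hypothetical, not MAX's implementation, and the exact ROLLING top-up rule is an assumption):

```python
from enum import Enum

class PrependPromptSpeechTokens(Enum):
    NEVER = "never"
    ONCE = "once"
    ROLLING = "rolling"

def prefix_for_block(mode, prompt_tokens, block_idx, prior_tokens, buffer_size):
    """Hypothetical helper: which prompt tokens to prepend to one decoder block."""
    if mode is PrependPromptSpeechTokens.NEVER:
        return []                   # never prepend
    if block_idx == 0:
        return list(prompt_tokens)  # ONCE and ROLLING both cover the first block
    if mode is PrependPromptSpeechTokens.ONCE:
        return []                   # later blocks get nothing
    # ROLLING (assumed rule): top up later blocks with trailing prompt tokens
    # whenever the already-generated context alone can't fill the buffer.
    shortfall = buffer_size - len(prior_tokens)
    return list(prompt_tokens)[-shortfall:] if shortfall > 0 else []
```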
## `PrometheusMetricsMode` {#max.pipelines.lib.config.PrometheusMetricsMode} > class max.pipelines.lib.config.PrometheusMetricsMode(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `INSTRUMENT_ONLY` {#max.pipelines.lib.config.PrometheusMetricsMode.INSTRUMENT_ONLY} > INSTRUMENT\_ONLY = 'instrument\_only' Instrument metrics through the Prometheus client library, relying on the application to handle the metrics server. ### `LAUNCH_MULTIPROC_SERVER` {#max.pipelines.lib.config.PrometheusMetricsMode.LAUNCH_MULTIPROC_SERVER} > LAUNCH\_MULTIPROC\_SERVER = 'launch\_multiproc\_server' Launch a Prometheus server in multiprocess mode to report metrics. ### `LAUNCH_SERVER` {#max.pipelines.lib.config.PrometheusMetricsMode.LAUNCH_SERVER} > LAUNCH\_SERVER = 'launch\_server' Launch a Prometheus server to handle metrics requests. --- ## core ## `PixelContext` {#max.pipelines.core.PixelContext} > class max.pipelines.core.PixelContext(\*, tokens, request\_id=\<factory>, model\_name='', mask=None, tokens\_2=None, negative\_tokens=None, negative\_tokens\_2=None, extra\_params=\<factory>, timesteps=\<factory>, sigmas=\<factory>, latents=\<factory>, latent\_image\_ids=\<factory>, height=1024, width=1024, num\_inference\_steps=50, guidance\_scale=3.5, guidance=None, true\_cfg\_scale=1.0, num\_warmup\_steps=0, num\_images\_per\_prompt=1, status=GenerationStatus.ACTIVE) A model-ready context for image/video generation requests. Per the design doc, this class contains only numeric data that the model will execute against. User-facing strings (prompt, negative\_prompt) are consumed during tokenization and do not appear here. All preprocessing is performed by PixelGenerationTokenizer.new\_context(): * Prompt tokenization -> tokens field * Negative prompt tokenization -> negative\_tokens field * Timestep schedule computation -> timesteps field * Initial noise generation -> latents field
**Parameters:**
* tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) – Tokenized prompt IDs (TokenBuffer). * request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID)) – A unique identifier for this generation request. * model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Name of the model being used. * mask ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[bool](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool)]] | None) * tokens\_2 ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) * negative\_tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) – Tokenized negative prompt IDs (TokenBuffer). * negative\_tokens\_2 ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) * extra\_params ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) * timesteps ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) – Precomputed timestep schedule for denoising. 
* sigmas ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) * latents ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) – Precomputed initial noise (latents). * latent\_image\_ids ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) * height ([int](https://docs.python.org/3/library/functions.html#int)) – Height of the generated image/video in pixels. * width ([int](https://docs.python.org/3/library/functions.html#int)) – Width of the generated image/video in pixels. * num\_inference\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Number of denoising steps. * guidance\_scale ([float](https://docs.python.org/3/library/functions.html#float)) – Guidance scale for classifier-free guidance. 
* guidance ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] | None) * true\_cfg\_scale ([float](https://docs.python.org/3/library/functions.html#float)) * num\_warmup\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * num\_images\_per\_prompt ([int](https://docs.python.org/3/library/functions.html#int)) – Number of images/videos to generate per prompt. * status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus))
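Because all preprocessing happens in the tokenizer, the numeric fields above arrive precomputed. A framework-free sketch of that kind of precomputation (the linear schedule and the latent sizing below are illustrative assumptions, not MAX's actual formulas):

```python
import random

def make_schedule(num_inference_steps, t_max=1.0):
    """Illustrative linear timestep schedule, descending from t_max toward 0."""
    return [t_max * (num_inference_steps - i) / num_inference_steps
            for i in range(num_inference_steps)]

def make_latents(height, width, channels=4, patch=8, seed=0):
    """Illustrative Gaussian initial noise; real latent shapes are model-specific."""
    rng = random.Random(seed)
    n = channels * (height // patch) * (width // patch)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

# Values a tokenizer-style new_context() step might precompute up front.
timesteps = make_schedule(num_inference_steps=50)
latents = make_latents(height=1024, width=1024)
```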
### `compute_num_available_steps()` {#max.pipelines.core.PixelContext.compute_num_available_steps} > compute\_num\_available\_steps(max\_seq\_len) Compute number of available steps for scheduler compatibility. For image and video generation, this returns the number of inference steps.
**Parameters:**
max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `extra_params` {#max.pipelines.core.PixelContext.extra_params} > extra\_params: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] Model-specific numeric parameters (e.g., cfg\_normalization values). ### `guidance` {#max.pipelines.core.PixelContext.guidance} > guidance: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `guidance_scale` {#max.pipelines.core.PixelContext.guidance_scale} > guidance\_scale: [float](https://docs.python.org/3/library/functions.html#float) = 3.5 ### `height` {#max.pipelines.core.PixelContext.height} > height: [int](https://docs.python.org/3/library/functions.html#int) = 1024 ### `is_done` {#max.pipelines.core.PixelContext.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the request has completed generation. 
### `latent_image_ids` {#max.pipelines.core.PixelContext.latent_image_ids} > latent\_image\_ids: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Precomputed latent image IDs for generation. ### `latents` {#max.pipelines.core.PixelContext.latents} > latents: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Precomputed initial noise (latents) for generation. ### `mask` {#max.pipelines.core.PixelContext.mask} > mask: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[bool](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool)]] | [None](https://docs.python.org/3/library/constants.html#None) = None Mask for text encoder’s attention. ### `model_name` {#max.pipelines.core.PixelContext.model_name} > model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) = '' ### `negative_tokens` {#max.pipelines.core.PixelContext.negative_tokens} > negative\_tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Negative tokens for primary encoder. 
### `negative_tokens_2` {#max.pipelines.core.PixelContext.negative_tokens_2} > negative\_tokens\_2: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Negative tokens for secondary encoder. None for single-encoder models. ### `num_images_per_prompt` {#max.pipelines.core.PixelContext.num_images_per_prompt} > num\_images\_per\_prompt: [int](https://docs.python.org/3/library/functions.html#int) = 1 ### `num_inference_steps` {#max.pipelines.core.PixelContext.num_inference_steps} > num\_inference\_steps: [int](https://docs.python.org/3/library/functions.html#int) = 50 ### `num_warmup_steps` {#max.pipelines.core.PixelContext.num_warmup_steps} > num\_warmup\_steps: [int](https://docs.python.org/3/library/functions.html#int) = 0 ### `request_id` {#max.pipelines.core.PixelContext.request_id} > request\_id: [RequestID](../interfaces.md#max.interfaces.RequestID) ### `reset()` {#max.pipelines.core.PixelContext.reset} > reset() Resets the context’s state.
**Return type:**
None
### `sigmas` {#max.pipelines.core.PixelContext.sigmas} > sigmas: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Precomputed sigmas schedule for denoising. ### `status` {#max.pipelines.core.PixelContext.status} > status: [GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus) = 'active' ### `timesteps` {#max.pipelines.core.PixelContext.timesteps} > timesteps: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Precomputed timesteps schedule for denoising. ### `to_generation_output()` {#max.pipelines.core.PixelContext.to_generation_output} > to\_generation\_output() Convert this context to a GenerationOutput object.
**Return type:**
[GenerationOutput](../interfaces.md#max.interfaces.GenerationOutput)
### `tokens` {#max.pipelines.core.PixelContext.tokens} > tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) Primary encoder tokens. ### `tokens_2` {#max.pipelines.core.PixelContext.tokens_2} > tokens\_2: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Secondary encoder tokens. None for single-encoder models. ### `true_cfg_scale` {#max.pipelines.core.PixelContext.true_cfg_scale} > true\_cfg\_scale: [float](https://docs.python.org/3/library/functions.html#float) = 1.0 ### `update()` {#max.pipelines.core.PixelContext.update} > update(latents) Update the context with newly generated latents/image data.
**Parameters:**
latents ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]])
**Return type:**
None
### `width` {#max.pipelines.core.PixelContext.width} > width: [int](https://docs.python.org/3/library/functions.html#int) = 1024 ## `TTSContext` {#max.pipelines.core.TTSContext} > class max.pipelines.core.TTSContext(\*, max\_length, tokens, request\_id=\<factory>, eos\_token\_ids=\<factory>, eos\_sequences=\<factory>, log\_probabilities=0, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory>, model\_name='', \_matcher=None, status=GenerationStatus.ACTIVE, \_log\_probabilities\_data=\<factory>, \_is\_initial\_prompt=True, \_draft\_offset=0, target\_endpoint=None, audio\_prompt\_tokens=\<factory>, buffer\_speech\_tokens=None, audio\_buffer=None, prev\_samples\_beyond\_offset=0, streaming=False, \_speech\_token\_size=128, \_speech\_token\_end\_idx=0, \_speech\_tokens=\<factory>, decoded\_index=0, \_block\_counter=0, \_arrival\_time=\<factory>, audio\_generation\_status=GenerationStatus.ACTIVE) A context for Text-to-Speech (TTS) model inference. This class extends TextContext to handle speech token generation and management. It maintains buffers for audio prompt tokens and generated speech tokens, along with tracking indices for decoding progress.
**Parameters:**
* max\_length ([int](https://docs.python.org/3/library/functions.html#int)) * tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) * request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID)) * eos\_token\_ids ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]) * eos\_sequences ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]]) * log\_probabilities ([int](https://docs.python.org/3/library/functions.html#int)) * log\_probabilities\_echo ([bool](https://docs.python.org/3/library/functions.html#bool)) * ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool)) * json\_schema ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * sampling\_params ([SamplingParams](../interfaces.md#max.interfaces.SamplingParams)) * model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * \_matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any) | None) * status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus)) * \_log\_probabilities\_data ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities)]) * \_is\_initial\_prompt ([bool](https://docs.python.org/3/library/functions.html#bool)) * \_draft\_offset ([int](https://docs.python.org/3/library/functions.html#int)) * target\_endpoint ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * audio\_prompt\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], 
[dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Array of input audio prompt tokens used for voice cloning * buffer\_speech\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None) * audio\_buffer ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None) * prev\_samples\_beyond\_offset ([int](https://docs.python.org/3/library/functions.html#int)) * streaming ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether the request is streaming the audio to the client * \_speech\_token\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Size of the speech token buffer, defaults to SPEECH\_TOKEN\_audio\_chunk\_size * \_speech\_token\_end\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Index marking the end of valid speech tokens * \_speech\_tokens
([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Buffer containing the generated speech tokens * decoded\_index ([int](https://docs.python.org/3/library/functions.html#int)) * \_block\_counter ([int](https://docs.python.org/3/library/functions.html#int)) – Counter tracking number of speech token blocks generated * \_arrival\_time ([float](https://docs.python.org/3/library/functions.html#float)) * audio\_generation\_status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus))
### `audio_buffer` {#max.pipelines.core.TTSContext.audio_buffer} > audio\_buffer: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `audio_generation_status` {#max.pipelines.core.TTSContext.audio_generation_status} > audio\_generation\_status: [GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus) = 'active' ### `audio_prompt_tokens` {#max.pipelines.core.TTSContext.audio_prompt_tokens} > audio\_prompt\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] ### `block_counter` {#max.pipelines.core.TTSContext.block_counter} > property block\_counter: [int](https://docs.python.org/3/library/functions.html#int) ### `buffer_speech_tokens` {#max.pipelines.core.TTSContext.buffer_speech_tokens} > buffer\_speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], 
[dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `decoded_index` {#max.pipelines.core.TTSContext.decoded_index} > decoded\_index: [int](https://docs.python.org/3/library/functions.html#int) = 0 ### `is_done` {#max.pipelines.core.TTSContext.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `next_speech_tokens()` {#max.pipelines.core.TTSContext.next_speech_tokens} > next\_speech\_tokens(audio\_chunk\_size=None, buffer=None) Returns a chunk of the next unseen speech tokens. Calling this function will not update the index of the last seen token. This must be done by setting decoded\_index after the chunk is processed.
**Parameters:**
* audio\_chunk\_size ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of speech tokens to return. * buffer ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of previous speech tokens to pass to the audio decoder on each generation step.
**Returns:**
A tuple of (chunk of speech tokens, buffer).
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]], [int](https://docs.python.org/3/library/functions.html#int)]
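The documented contract — return the next unseen tokens together with up to `buffer` preceding tokens, without advancing `decoded_index` — can be sketched in plain Python (a mirror of the description above, not the actual MAX code):

```python
def next_speech_tokens(speech_tokens, decoded_index, audio_chunk_size=None, buffer=None):
    """Return (chunk, n_buffered): the next unseen tokens preceded by up to
    `buffer` already-decoded tokens. Does not advance decoded_index."""
    end = len(speech_tokens)
    if audio_chunk_size is not None:
        end = min(end, decoded_index + audio_chunk_size)
    n_buffered = min(buffer or 0, decoded_index)  # can't buffer more than was decoded
    start = decoded_index - n_buffered
    return speech_tokens[start:end], n_buffered

tokens = [10, 11, 12, 13, 14, 15]
chunk, buffered = next_speech_tokens(tokens, decoded_index=2,
                                     audio_chunk_size=3, buffer=1)
# chunk covers one buffered token (11) plus three new tokens (12, 13, 14)
```

After the chunk is processed, the caller would advance `decoded_index` past the new tokens (to 5 here) before requesting the next chunk.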
### `prev_samples_beyond_offset` {#max.pipelines.core.TTSContext.prev_samples_beyond_offset} > prev\_samples\_beyond\_offset: [int](https://docs.python.org/3/library/functions.html#int) = 0 ### `speech_tokens` {#max.pipelines.core.TTSContext.speech_tokens} > property speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] ### `streaming` {#max.pipelines.core.TTSContext.streaming} > streaming: [bool](https://docs.python.org/3/library/functions.html#bool) = False ### `update_speech_tokens()` {#max.pipelines.core.TTSContext.update_speech_tokens} > update\_speech\_tokens(new\_tokens) Updates the speech token buffer with new\_tokens.
**Parameters:**
new\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]])
**Return type:**
None
## `TextAndVisionContext` {#max.pipelines.core.TextAndVisionContext} > class max.pipelines.core.TextAndVisionContext(\*, max\_length, tokens, request\_id=\<factory>, eos\_token\_ids=\<factory>, eos\_sequences=\<factory>, log\_probabilities=0, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory>, model\_name='', \_matcher=None, status=GenerationStatus.ACTIVE, \_log\_probabilities\_data=\<factory>, \_is\_initial\_prompt=True, \_draft\_offset=0, target\_endpoint=None, vision\_token\_ids, images=\<factory>, extra\_model\_args=\<factory>) A base class for model context, specifically for Vision model variants. For example:

```default
- (image start token) = 97
- (image patch token) = 98
- (image end token)   = 99
```

Token array:

```default
- idx:       [  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 ]
- token_ids: [ 51 52 53 54 97 98 98 98 98 99 55 56 57 58 97 98 98 98 98 99 59 60 61 62 ]
                           ^-- img0 --^                  ^-- img1 --^
                                                 ^ start_idx=11 (image_idx=1)
```

Then we would have:

```default
- ImageMetadata(start_idx=5, end_idx=9, ...)   # img0
- ImageMetadata(start_idx=15, end_idx=19, ...) # img1
```

These image ranges should be non-overlapping. The image\_idx is determined by the value of start\_idx: it is the index of the first image that is not yet encoded. For example, in the diagram above, start\_idx=11 implies image\_idx=1. Currently, start\_idx and current\_position are restricted from falling in the middle of an image; this is verified by the \_validate\_state methods, which are called before and after mutating methods such as \_bump\_token\_indices.
**Parameters:**
* max\_length ([int](https://docs.python.org/3/library/functions.html#int)) * tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) * request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID)) * eos\_token\_ids ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]) * eos\_sequences ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]]) * log\_probabilities ([int](https://docs.python.org/3/library/functions.html#int)) * log\_probabilities\_echo ([bool](https://docs.python.org/3/library/functions.html#bool)) * ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool)) * json\_schema ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * sampling\_params ([SamplingParams](../interfaces.md#max.interfaces.SamplingParams)) * model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * \_matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any) | None) * status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus)) * \_log\_probabilities\_data ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities)]) * \_is\_initial\_prompt ([bool](https://docs.python.org/3/library/functions.html#bool)) * \_draft\_offset ([int](https://docs.python.org/3/library/functions.html#int)) * target\_endpoint ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * vision\_token\_ids ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * images ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](../interfaces.md#max.interfaces.ImageMetadata)]) * 
extra\_model\_args ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]])
### `compute_image_aligned_idx()` {#max.pipelines.core.TextAndVisionContext.compute_image_aligned_idx} > compute\_image\_aligned\_idx(idx) Aligns an index value downward if it lies in the middle of an image; otherwise returns it unchanged.
**Parameters:**
idx ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
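The alignment rule can be sketched as follows. This is an illustrative stand-in, not the actual implementation: `image_spans` is a hypothetical parameter here, whereas the real method tracks image boundaries internally on the context.

```python
def compute_image_aligned_idx(idx: int, image_spans: list[tuple[int, int]]) -> int:
    # If idx falls strictly inside an image's token span, snap it down to
    # the span's start so an image's tokens are never split mid-encoding.
    for start, end in image_spans:
        if start < idx < end:
            return start
    return idx

# A prompt whose tokens 3..8 belong to one image:
print(compute_image_aligned_idx(5, [(3, 8)]))  # aligned down to 3
print(compute_image_aligned_idx(2, [(3, 8)]))  # outside the image: 2
```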
### `extra_model_args` {#max.pipelines.core.TextAndVisionContext.extra_model_args} > extra\_model\_args: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] Extra model arguments for the vision model. These are model-specific arguments. ### `image_idx` {#max.pipelines.core.TextAndVisionContext.image_idx} > property image\_idx: [int](https://docs.python.org/3/library/functions.html#int) Index of the next unencoded image in the prompt. ### `images` {#max.pipelines.core.TextAndVisionContext.images} > images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](../interfaces.md#max.interfaces.ImageMetadata)] Metadata about each image in the prompt. ### `needs_vision_encoding` {#max.pipelines.core.TextAndVisionContext.needs_vision_encoding} > property needs\_vision\_encoding: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether vision encoding is needed for this context. ### `next_images` {#max.pipelines.core.TextAndVisionContext.next_images} > property next\_images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](../interfaces.md#max.interfaces.ImageMetadata)] Returns the images that are not yet encoded. ### `update()` {#max.pipelines.core.TextAndVisionContext.update} > update(new\_token, log\_probabilities=None) Updates the next\_tokens and extends existing tokens to include all generated tokens.
**Parameters:**
* new\_token ([int](https://docs.python.org/3/library/functions.html#int)) * log\_probabilities ([LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities) | None)
**Return type:**
None
### `vision_token_ids` {#max.pipelines.core.TextAndVisionContext.vision_token_ids} > vision\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] The ID of the image special token. The reason this is a list is primarily due to Pixtral, which also has an image\_break\_token\_id. ## `TextContext` {#max.pipelines.core.TextContext} > class max.pipelines.core.TextContext(\*, max\_length, tokens, request\_id=\<factory\>, eos\_token\_ids=\<factory\>, eos\_sequences=\<factory\>, log\_probabilities=0, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory\>, model\_name='', \_matcher=None, status=GenerationStatus.ACTIVE, \_log\_probabilities\_data=\<factory\>, \_is\_initial\_prompt=True, \_draft\_offset=0, target\_endpoint=None) A base class for model context, specifically for text model variants. This class manages the state and processing of text generation, including token management, caching, and generation parameters.
**Parameters:**
* max\_length ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum allowed length of the generated sequence * tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) – NumPy array containing the token IDs * request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID)) – A unique identifier for this sequence. * eos\_token\_ids ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Set of token IDs that indicate end of sequence * eos\_sequences ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]]) * log\_probabilities ([int](https://docs.python.org/3/library/functions.html#int)) – Whether to return token log probabilities * log\_probabilities\_echo ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to return log probabilities for prompt tokens * ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to ignore end of sequence tokens and continue generating * json\_schema ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional JSON schema for structured output * sampling\_params ([SamplingParams](../interfaces.md#max.interfaces.SamplingParams)) – Parameters controlling the token sampling strategy * model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * \_matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any) | None) * status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus)) * \_log\_probabilities\_data ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities)]) – Token log probabilities data * \_is\_initial\_prompt 
([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether this is the initial prompt encoding * \_draft\_offset ([int](https://docs.python.org/3/library/functions.html#int)) – Offset for draft decoding * target\_endpoint ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional target endpoint identifier for routing requests
### `apply_processing_offset()` {#max.pipelines.core.TextContext.apply_processing_offset} > apply\_processing\_offset(offset)
**Parameters:**
offset ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
None
### `compute_num_available_steps()` {#max.pipelines.core.TextContext.compute_num_available_steps} > compute\_num\_available\_steps(max\_seq\_len) Compute the max number of steps we can execute for a given context without exceeding the max\_seq\_len.
**Parameters:**
max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
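The computation amounts to the remaining headroom in the sequence. A minimal, dependency-free sketch of the documented behavior (`current_length` is a hypothetical stand-in for the context's current token count):

```python
def compute_num_available_steps(current_length: int, max_seq_len: int) -> int:
    # Number of generation steps that fit without exceeding max_seq_len.
    return max(0, max_seq_len - current_length)

print(compute_num_available_steps(100, 128))  # 28
```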
### `eos_sequences` {#max.pipelines.core.TextContext.eos_sequences} > eos\_sequences: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]] ### `eos_token_ids` {#max.pipelines.core.TextContext.eos_token_ids} > eos\_token\_ids: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] ### `get_min_token_logit_mask()` {#max.pipelines.core.TextContext.get_min_token_logit_mask} > get\_min\_token\_logit\_mask(num\_steps) Returns a set of indices for the tokens in the output that should be masked. This is primarily used for the min\_tokens setting, where we mask eos tokens in the logits to avoid generating them before we reach min\_tokens.
**Returns:**
A set of indices for the tokens in the output that should be masked.
**Parameters:**
num\_steps ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int32]]]
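The masking behavior can be illustrated with a standalone sketch. The real method derives the generated-token count and EOS set from the context and returns NumPy int32 index arrays; here both are parameters and plain lists are used to keep the sketch dependency-free:

```python
def min_token_logit_mask(
    num_generated: int, min_tokens: int, eos_token_ids: set[int], num_steps: int
) -> list[list[int]]:
    # For each upcoming step, emit the EOS token IDs to mask in the
    # logits while the sequence is still shorter than min_tokens;
    # once min_tokens is reached, nothing is masked.
    masks = []
    for step in range(num_steps):
        if num_generated + step < min_tokens:
            masks.append(sorted(eos_token_ids))
        else:
            masks.append([])
    return masks

# 2 tokens generated, min_tokens=4: EOS stays masked for 2 more steps.
print(min_token_logit_mask(2, 4, {0, 2}, 3))  # [[0, 2], [0, 2], []]
```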
### `ignore_eos` {#max.pipelines.core.TextContext.ignore_eos} > ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) = False ### `is_done` {#max.pipelines.core.TextContext.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_initial_prompt` {#max.pipelines.core.TextContext.is_initial_prompt} > property is\_initial\_prompt: [bool](https://docs.python.org/3/library/functions.html#bool) Returns true if the context has not been updated with tokens. ### `json_schema` {#max.pipelines.core.TextContext.json_schema} > json\_schema: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `jump_ahead()` {#max.pipelines.core.TextContext.jump_ahead} > jump\_ahead(new\_token) Updates the token array, while ensuring the new token is returned to the user.
**Parameters:**
new\_token ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
None
### `log_probabilities` {#max.pipelines.core.TextContext.log_probabilities} > log\_probabilities: [int](https://docs.python.org/3/library/functions.html#int) = 0 ### `log_probabilities_echo` {#max.pipelines.core.TextContext.log_probabilities_echo} > log\_probabilities\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) = False ### `matcher` {#max.pipelines.core.TextContext.matcher} > property matcher: LLMatcher | [None](https://docs.python.org/3/library/constants.html#None) ### `max_length` {#max.pipelines.core.TextContext.max_length} > max\_length: [int](https://docs.python.org/3/library/functions.html#int) ### `min_tokens` {#max.pipelines.core.TextContext.min_tokens} > property min\_tokens: [int](https://docs.python.org/3/library/functions.html#int) The minimum number of new tokens to generate. ### `model_name` {#max.pipelines.core.TextContext.model_name} > model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) = '' ### `realize_future_token()` {#max.pipelines.core.TextContext.realize_future_token} > realize\_future\_token(new\_token, log\_probabilities=None) Overwrite the placeholder future token with the actual token. This is primarily used for overlap scheduling.
**Parameters:**
* new\_token ([int](https://docs.python.org/3/library/functions.html#int)) * log\_probabilities ([LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities) | None)
**Return type:**
None
### `request_id` {#max.pipelines.core.TextContext.request_id} > request\_id: [RequestID](../interfaces.md#max.interfaces.RequestID) ### `reset()` {#max.pipelines.core.TextContext.reset} > reset() Resets the context’s state by combining all tokens into a new prompt.
**Return type:**
None
### `sampling_params` {#max.pipelines.core.TextContext.sampling_params} > sampling\_params: [SamplingParams](../interfaces.md#max.interfaces.SamplingParams) ### `set_matcher()` {#max.pipelines.core.TextContext.set_matcher} > set\_matcher(matcher)
**Parameters:**
matcher (LLMatcher)
**Return type:**
None
### `status` {#max.pipelines.core.TextContext.status} > status: [GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus) = 'active' ### `target_endpoint` {#max.pipelines.core.TextContext.target_endpoint} > target\_endpoint: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `to_generation_output()` {#max.pipelines.core.TextContext.to_generation_output} > to\_generation\_output() Get completion tokens that are ready to be returned to the user. This method retrieves tokens that have been generated but not yet delivered to the user, along with their associated log probability data.
**Returns:**
The completion tokens and their associated log probabilities, if available.
**Return type:**
[TextGenerationOutput](../interfaces.md#max.interfaces.TextGenerationOutput)
### `tokens` {#max.pipelines.core.TextContext.tokens} > tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) ### `update()` {#max.pipelines.core.TextContext.update} > update(new\_token, log\_probabilities=None) Updates the next\_tokens and extends existing tokens to include all generated tokens.
**Parameters:**
* new\_token ([int](https://docs.python.org/3/library/functions.html#int)) * log\_probabilities ([LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities) | None)
**Return type:**
None
### `update_with_future_token()` {#max.pipelines.core.TextContext.update_with_future_token} > update\_with\_future\_token() Append a placeholder future token to the generated tokens. This is primarily used for overlap scheduling.
**Return type:**
None
## `reserve_token_space_for_batch()` {#max.pipelines.core.reserve_token_space_for_batch} > max.pipelines.core.reserve\_token\_space\_for\_batch(batch, num\_tokens) Temporarily reserves token space for each context in a batch by incrementing the \_active\_idx and \_end\_idx attributes by num\_tokens for the duration of the context manager. These indices are restored to their original values upon exit.
**Yields:**
None
**Parameters:**
* batch ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContext](#max.pipelines.core.TextContext)]) – List of TextContext objects to reserve space for. * num\_tokens ([int](https://docs.python.org/3/library/functions.html#int)) – Number of tokens to reserve for each context.
**Return type:**
[Iterator](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterator)\[None]
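The documented contract (reserve on entry, restore on exit) matches the standard `contextlib.contextmanager` pattern. A dependency-free sketch, using plain dicts in place of TextContext objects:

```python
from contextlib import contextmanager

@contextmanager
def reserve_token_space_for_batch(batch, num_tokens):
    # Bump each context's indices for the duration of the with-block,
    # then restore them even if the body raises.
    for ctx in batch:
        ctx["_active_idx"] += num_tokens
        ctx["_end_idx"] += num_tokens
    try:
        yield
    finally:
        for ctx in batch:
            ctx["_active_idx"] -= num_tokens
            ctx["_end_idx"] -= num_tokens

batch = [{"_active_idx": 4, "_end_idx": 4}]
with reserve_token_space_for_batch(batch, 8):
    print(batch[0]["_end_idx"])  # 12 while reserved
print(batch[0]["_end_idx"])      # 4 after exit
```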
## `validate_aspect_ratio_args()` {#max.pipelines.core.validate_aspect_ratio_args} > max.pipelines.core.validate\_aspect\_ratio\_args(context) Validates that required aspect ratio arguments are present for vision input.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If required aspect ratio arguments are missing.
**Return type:**
None
## `validate_image_grid_thw_args()` {#max.pipelines.core.validate_image_grid_thw_args} > max.pipelines.core.validate\_image\_grid\_thw\_args(context) Validates that image\_grid\_thw is present when vision encoding is needed.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If image\_grid\_thw is missing from extra\_model\_args when vision encoding is needed.
**Return type:**
None
## `validate_image_shape_5d()` {#max.pipelines.core.validate_image_shape_5d} > max.pipelines.core.validate\_image\_shape\_5d(context) Validates that images have the expected 5-dimensional shape.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If the image shape is not 5-dimensional.
**Return type:**
None
## `validate_initial_prompt_has_image()` {#max.pipelines.core.validate_initial_prompt_has_image} > max.pipelines.core.validate\_initial\_prompt\_has\_image(context) Validates that initial prompts contain an image for vision models.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If the initial prompt doesn’t contain an image.
**Return type:**
None
## `validate_only_one_image()` {#max.pipelines.core.validate_only_one_image} > max.pipelines.core.validate\_only\_one\_image(context) Validates that at most one image is provided in the context.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If more than one image is provided.
**Return type:**
None
## `validate_requires_vision_context()` {#max.pipelines.core.validate_requires_vision_context} > max.pipelines.core.validate\_requires\_vision\_context(context) Validates that the context is a TextAndVisionContext.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If the context is not a TextAndVisionContext.
**Return type:**
None
## `validate_vision_position_ids()` {#max.pipelines.core.validate_vision_position_ids} > max.pipelines.core.validate\_vision\_position\_ids(context) Validates that vision\_position\_ids is present when vision encoding is needed.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If vision\_position\_ids is missing from extra\_model\_args when vision encoding is needed.
**Return type:**
None
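These validators all share one shape: inspect the context and raise InputError on violation. An illustrative sketch of two of them (the `InputError` class and the simplified argument forms are stand-ins; the real validators take a context object):

```python
class InputError(Exception):
    """Stand-in for the pipeline's InputError."""

def validate_only_one_image(images: list) -> None:
    # Mirrors the documented contract: at most one image per context.
    if len(images) > 1:
        raise InputError("at most one image may be provided")

def validate_image_shape_5d(image_shape: tuple) -> None:
    # Mirrors the documented contract: images must be 5-dimensional.
    if len(image_shape) != 5:
        raise InputError(f"expected a 5-D image, got {len(image_shape)}-D")
```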
--- ## hf_utils Utilities for interacting with Hugging Face Files/Repos. ## `HuggingFaceRepo` {#max.pipelines.lib.hf_utils.HuggingFaceRepo} > class max.pipelines.lib.hf\_utils.HuggingFaceRepo(repo\_id, revision='main', trust\_remote\_code=False, repo\_type=None) Handle for interacting with a Hugging Face repository (remote or local).
**Parameters:**
* repo\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) * revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) * trust\_remote\_code ([bool](https://docs.python.org/3/library/functions.html#bool)) * repo\_type (RepoType | None)
### `encoding_for_file()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.encoding_for_file} > encoding\_for\_file(file) Infers the supported encoding for a given weight file path.
**Parameters:**
file ([str](https://docs.python.org/3/library/stdtypes.html#str) | Path)
**Return type:**
SupportedEncoding
### `file_exists()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.file_exists} > file\_exists(filename) Returns whether the given file exists in the repo.
**Parameters:**
filename ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
### `files_for_encoding()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.files_for_encoding} > files\_for\_encoding(encoding, weights\_format=None) Returns paths to weight files for the given encoding (and optionally format).
**Parameters:**
* encoding (SupportedEncoding) * weights\_format ([WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat) | None)
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]]
### `formats_available` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.formats_available} > property formats\_available: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)] Returns the weight formats available in this repo. ### `info` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.info} > property info: ModelInfo Returns Hugging Face model info (online repos only). ### `repo_id` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.repo_id} > repo\_id: [str](https://docs.python.org/3/library/stdtypes.html#str) The Hugging Face repo ID. Despite the name, this can be either a remote Hugging Face repo ID or a local path. ### `repo_type` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.repo_type} > repo\_type: RepoType | [None](https://docs.python.org/3/library/constants.html#None) = None The type of repo. This is inferred from the repo\_id. ### `revision` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.revision} > revision: [str](https://docs.python.org/3/library/stdtypes.html#str) = 'main' The revision to use for the repo. ### `size_of()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.size_of} > size\_of(filename) Returns file size in bytes for online repos, or None.
**Parameters:**
filename ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int) | None
### `supported_encodings` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.supported_encodings} > property supported\_encodings: [list](https://docs.python.org/3/library/stdtypes.html#list)\[SupportedEncoding] Returns encodings supported by this repo’s weight files. ### `trust_remote_code` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.trust_remote_code} > trust\_remote\_code: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether to trust remote code. ### `weight_files` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.weight_files} > property weight\_files: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]] Returns weight file paths grouped by format (safetensors, gguf). ## `download_weight_files()` {#max.pipelines.lib.hf_utils.download_weight_files} > max.pipelines.lib.hf\_utils.download\_weight\_files(huggingface\_model\_id, filenames, revision=None, force\_download=False, max\_workers=8) Downloads weight files for a Hugging Face model and returns local paths.
**Parameters:**
* huggingface\_model\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The Hugging Face model identifier, e.g. modularai/Llama-3.1-8B-Instruct-GGUF * filenames ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) – A list of file paths relative to the root of the Hugging Face repo. If the files are already available locally, the download is skipped and the local files are used. * revision ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – The Hugging Face revision to use. If provided, the local cache is checked directly, saving a network call to Hugging Face. * force\_download ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to force the files to be redownloaded, even if they are already available in the local cache or at a provided path. * max\_workers ([int](https://docs.python.org/3/library/functions.html#int)) – The number of worker threads used to download files concurrently.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]
## `generate_local_model_path()` {#max.pipelines.lib.hf_utils.generate_local_model_path} > max.pipelines.lib.hf\_utils.generate\_local\_model\_path(repo\_id, revision) Generate the local filesystem path where a Hugging Face model repo is cached. This function uses Hugging Face’s official snapshot\_download with local\_files\_only=True to resolve the local cache path for a model repository.
**Parameters:**
* repo\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The Hugging Face repository ID in the format “org/model” (e.g. “HuggingFaceTB/SmolLM2-135M”) * revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The specific model revision hash to use, typically from a repo lock file
**Returns:**
The absolute path to the cached model files for the specified revision.
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
**Raises:**
[FileNotFoundError](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) – If the model is not found in the local cache
## `is_diffusion_pipeline()` {#max.pipelines.lib.hf_utils.is_diffusion_pipeline} > max.pipelines.lib.hf\_utils.is\_diffusion\_pipeline(repo) Check if a Hugging Face repository is a diffusion pipeline. Diffusion pipelines typically have a model\_index.json file that describes the pipeline components.
**Parameters:**
repo ([HuggingFaceRepo](#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo to check.
**Returns:**
True if the repository appears to be a diffusion pipeline, False otherwise.
**Return type:**
[bool](https://docs.python.org/3/library/functions.html#bool)
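The heuristic reduces to a file-presence check. Sketched here against a plain list of filenames for illustration; the real function inspects a HuggingFaceRepo:

```python
def is_diffusion_pipeline(repo_files: list[str]) -> bool:
    # Diffusion pipelines ship a model_index.json that describes
    # the pipeline components (unet, vae, scheduler, ...).
    return "model_index.json" in repo_files

print(is_diffusion_pipeline(["model_index.json", "unet/config.json"]))  # True
print(is_diffusion_pipeline(["config.json", "model.safetensors"]))      # False
```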
## `try_to_load_from_cache()` {#max.pipelines.lib.hf_utils.try_to_load_from_cache} > max.pipelines.lib.hf\_utils.try\_to\_load\_from\_cache(repo\_id, filename, revision) Wrapper around `huggingface_hub.try_to_load_from_cache` that calls `validate_hf_repo_access` first to ensure the repo exists.
**Parameters:**
* repo\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) * filename ([str](https://docs.python.org/3/library/stdtypes.html#str)) * revision ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str) | [Any](https://docs.python.org/3/library/typing.html#typing.Any) | None
## `validate_hf_repo_access()` {#max.pipelines.lib.hf_utils.validate_hf_repo_access} > max.pipelines.lib.hf\_utils.validate\_hf\_repo\_access(repo\_id, revision) Validates repository access and raises clear, user-friendly errors. Results are cached to avoid redundant Hugging Face API calls when the same repository is validated multiple times within a process.
**Parameters:**
* repo\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The Hugging Face repository ID to validate * revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The revision/branch to validate
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – With user-friendly error messages for various access issues
**Return type:**
None
--- ## pipelines The pipelines package provides end-to-end implementations for text generation, embeddings, audio generation, and speech processing that automatically convert Hugging Face models into performance-optimized MAX graphs. Each pipeline can be served via OpenAI-compatible endpoints for production deployment. ## Modules * [`architectures`](/max/api/python/pipelines/architectures) * [`config`](/max/api/python/pipelines/config) * [`core`](/max/api/python/pipelines/core) * [`hf_utils`](/max/api/python/pipelines/hf_utils) * [`interfaces`](/max/api/python/pipelines/interfaces) * [`lora_config`](/max/api/python/pipelines/lora_config) * [`model_config`](/max/api/python/pipelines/model_config) * [`pipeline`](/max/api/python/pipelines/pipeline) * [`registry`](/max/api/python/pipelines/registry) * [`sampling`](/max/api/python/pipelines/sampling) * [`tokenizer`](/max/api/python/pipelines/tokenizer) --- ## interfaces (Pipelines) Interfaces for MAX pipelines. ## `AlwaysSignalBuffersMixin` {#max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin} > class max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin Bases: [`object`](https://docs.python.org/3/library/functions.html#object) Mixin for models that always require signal buffers. Use this for models that use VocabParallelEmbedding or other distributed components that always perform allreduce, even on single-device setups. Models using this mixin build graphs that always include signal buffer inputs, regardless of device count. This is typically because they use distributed embedding layers or other components that call allreduce operations unconditionally. ### `devices` {#max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin.devices} > devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](../driver.md#max.driver.Device)] Device list that must be provided by the model class. 
### `signal_buffers` {#max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin.signal_buffers} > property signal\_buffers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] Override to always create signal buffers. Models using this mixin have distributed components that always perform allreduce, even for single-device setups. Therefore, signal buffers are always required to match the graph inputs. In compile-only mode (virtual device mode), returns an empty list to avoid GPU memory allocation which is not supported.
**Returns:**
List of signal buffer tensors, one per device, or empty list in compile-only mode.
## `ArchConfig` {#max.pipelines.lib.interfaces.ArchConfig} > class max.pipelines.lib.interfaces.ArchConfig(\*args, \*\*kwargs) Bases: [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol) Config for a model architecture. ### `get_max_seq_len()` {#max.pipelines.lib.interfaces.ArchConfig.get_max_seq_len} > get\_max\_seq\_len() Returns the default maximum sequence length for the model. Subclasses should determine whether this value can be overridden by setting the `--max-length` (`pipeline_config.max_length`) flag.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `initialize()` {#max.pipelines.lib.interfaces.ArchConfig.initialize} > classmethod initialize(pipeline\_config) Initialize the config from a PipelineConfig.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig))
**Return type:**
Self
## `ArchConfigWithAttentionKVCache` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache} > class max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache(dtype, devices=\<factory\>, cache\_dtype=None, kv\_cache=\<factory\>, data\_parallel\_degree=1, user\_provided\_max\_length=None, huggingface\_config=None, \_kv\_params=None) Bases: [`ArchConfigWithKVCache`](#max.pipelines.lib.interfaces.ArchConfigWithKVCache), [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC) Predefined configuration for architectures that use attention KV cache blocks. Subclasses must define the following attributes: * num\_key\_value\_heads: int * head\_dim: int * num\_layers: int * model\_max\_seq\_len: int
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../graph/ops.md#max.graph.ops.DeviceRef)]) * cache\_dtype ([DType](../dtype.md#max.dtype.DType) | None) * kv\_cache (KVCacheConfig) * data\_parallel\_degree ([int](https://docs.python.org/3/library/functions.html#int)) * user\_provided\_max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) * huggingface\_config (AutoConfig | None) * \_kv\_params ([KVCacheParams](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams) | None)
### `cache_dtype` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.cache_dtype} > cache\_dtype: [DType](../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None The data type to use for the KV cache. ### `data_parallel_degree` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.data_parallel_degree} > data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) = 1 The data parallel degree to use when running the model. ### `devices` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.devices} > devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../graph/ops.md#max.graph.ops.DeviceRef)] The physical devices to use when running the model. ### `dtype` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.dtype} > dtype: [DType](../dtype.md#max.dtype.DType) The data type to use for the model. ### `get_kv_params()` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.get_kv_params} > get\_kv\_params() Returns the KV cache parameters for this architecture.
**Return type:**
[KVCacheParams](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)
### `get_max_seq_len()` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.get_max_seq_len} > get\_max\_seq\_len() Returns the maximum sequence length the model can process. Returns `max_length` if set, otherwise `model_max_seq_len`. Raises ValueError if `max_length` exceeds `model_max_seq_len`.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
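The documented resolution order (user-provided override if set, capped by the model's own limit) can be restated as a standalone sketch:

```python
def resolve_max_seq_len(user_provided_max_length, model_max_seq_len):
    # max_length wins if set, but may not exceed the model's limit.
    if user_provided_max_length is None:
        return model_max_seq_len
    if user_provided_max_length > model_max_seq_len:
        raise ValueError(
            f"max_length ({user_provided_max_length}) exceeds the model's "
            f"maximum sequence length ({model_max_seq_len})"
        )
    return user_provided_max_length

print(resolve_max_seq_len(None, 4096))  # 4096
print(resolve_max_seq_len(2048, 4096))  # 2048
```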
### `head_dim` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.head_dim} > abstract property head\_dim: [int](https://docs.python.org/3/library/functions.html#int) Dimensionality of each attention head. ### `huggingface_config` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.huggingface_config} > huggingface\_config: AutoConfig | [None](https://docs.python.org/3/library/constants.html#None) = None ### `initialize()` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.initialize} > classmethod initialize(pipeline\_config) Initialize the config from a PipelineConfig.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig))
**Return type:**
Self
### `kv_cache` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.kv_cache} > kv\_cache: KVCacheConfig The KV cache configuration to use when running the model. ### `model_max_seq_len` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.model_max_seq_len} > abstract property model\_max\_seq\_len: [int](https://docs.python.org/3/library/functions.html#int) The maximum sequence length that can be processed by the model. ### `num_key_value_heads` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.num_key_value_heads} > abstract property num\_key\_value\_heads: [int](https://docs.python.org/3/library/functions.html#int) Number of key-value heads to use for the KV cache. ### `num_layers` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.num_layers} > abstract property num\_layers: [int](https://docs.python.org/3/library/functions.html#int) Number of hidden layers in the model. ### `user_provided_max_length` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.user_provided_max_length} > user\_provided\_max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Override for the maximum sequence length. ## `ArchConfigWithKVCache` {#max.pipelines.lib.interfaces.ArchConfigWithKVCache} > class max.pipelines.lib.interfaces.ArchConfigWithKVCache(\*args, \*\*kwargs) Bases: [`ArchConfig`](#max.pipelines.lib.interfaces.ArchConfig), [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol) Config for a model architecture that uses a KV cache. ### `get_kv_params()` {#max.pipelines.lib.interfaces.ArchConfigWithKVCache.get_kv_params} > get\_kv\_params() KV cache parameters to use when running the model.
**Return type:**
[KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)
## `ComponentModel` {#max.pipelines.lib.interfaces.ComponentModel} > class max.pipelines.lib.interfaces.ComponentModel(config, encoding, devices, weights) Bases: [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC) Base interface for component models with weight-backed execution.
**Parameters:**
* config (Any) * encoding (SupportedEncoding) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](../driver.md#max.driver.Device)]) * weights ([Weights](../graph/weights.md#max.graph.weights.Weights))
### `load_model()` {#max.pipelines.lib.interfaces.ComponentModel.load_model} > abstract load\_model() Load and return a runtime model instance.
**Return type:**
[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[…], [Any](https://docs.python.org/3/library/typing.html#typing.Any)]
## `DiffusionPipeline` {#max.pipelines.lib.interfaces.DiffusionPipeline} > class max.pipelines.lib.interfaces.DiffusionPipeline(pipeline\_config, session, devices, weight\_paths, \*\*kwargs) Bases: [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC) Base class for diffusion pipelines. Subclasses must define components mapping component names to ComponentModel types.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * session ([InferenceSession](../engine.md#max.engine.InferenceSession)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](../driver.md#max.driver.Device)]) * weight\_paths ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]) * kwargs (Any)
### `components` {#max.pipelines.lib.interfaces.DiffusionPipeline.components} > components: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [type](https://docs.python.org/3/library/functions.html#type)\[[ComponentModel](#max.pipelines.lib.interfaces.ComponentModel)]] | [None](https://docs.python.org/3/library/constants.html#None) = None ### `execute()` {#max.pipelines.lib.interfaces.DiffusionPipeline.execute} > abstract execute(model\_inputs, \*\*kwargs) Execute the pipeline with the given model inputs.
**Parameters:**
* model\_inputs ([PixelModelInputs](#max.pipelines.lib.interfaces.PixelModelInputs)) – Prepared model inputs from prepare\_inputs. * \*\*kwargs ([Any](https://docs.python.org/3/library/typing.html#typing.Any)) – Additional pipeline-specific execution parameters.
**Returns:**
Pipeline-specific output (e.g., generated images).
**Return type:**
[Any](https://docs.python.org/3/library/typing.html#typing.Any)
### `finalize_pipeline_config()` {#max.pipelines.lib.interfaces.DiffusionPipeline.finalize_pipeline_config} > classmethod finalize\_pipeline\_config(pipeline\_config) Hook for finalizing pipeline configuration. Override if needed.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig))
**Return type:**
None
### `init_remaining_components()` {#max.pipelines.lib.interfaces.DiffusionPipeline.init_remaining_components} > abstract init\_remaining\_components() Initialize non-ComponentModel components (e.g., image processors).
**Return type:**
None
### `prepare_inputs()` {#max.pipelines.lib.interfaces.DiffusionPipeline.prepare_inputs} > abstract prepare\_inputs(context) Prepare inputs for the pipeline.
**Parameters:**
context ([PixelGenerationContext](../interfaces.md#max.interfaces.PixelGenerationContext))
**Return type:**
[PixelModelInputs](#max.pipelines.lib.interfaces.PixelModelInputs)
## `GenerateMixin` {#max.pipelines.lib.interfaces.GenerateMixin} > class max.pipelines.lib.interfaces.GenerateMixin(\*args, \*\*kwargs) Bases: [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol)\[`TextGenerationContextType`, `RequestType`] Protocol for pipelines that support text generation. ### `execute()` {#max.pipelines.lib.interfaces.GenerateMixin.execute} > execute(inputs) Executes the pipeline for the given inputs.
**Parameters:**
inputs ([TextGenerationInputs](../interfaces.md#max.interfaces.TextGenerationInputs)\[TextGenerationContextType])
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](../interfaces.md#max.interfaces.RequestID), [TextGenerationOutput](../interfaces.md#max.interfaces.TextGenerationOutput)]
### `generate()` {#max.pipelines.lib.interfaces.GenerateMixin.generate} > generate(prompts) Generates outputs for the given prompts.
**Parameters:**
prompts (RequestType | [list](https://docs.python.org/3/library/stdtypes.html#list)\[RequestType])
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationOutput](../interfaces.md#max.interfaces.TextGenerationOutput)]
### `generate_async()` {#max.pipelines.lib.interfaces.GenerateMixin.generate_async} > async generate\_async(prompts) Generates outputs asynchronously for the given prompts.
**Parameters:**
prompts (RequestType | [list](https://docs.python.org/3/library/stdtypes.html#list)\[RequestType])
**Return type:**
[Any](https://docs.python.org/3/library/typing.html#typing.Any)
### `kv_managers` {#max.pipelines.lib.interfaces.GenerateMixin.kv_managers} > property kv\_managers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[PagedKVCacheManager](../kv_cache/paged_kv_cache/cache_manager.md#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager)] Returns the KV cache managers for this pipeline. ### `pipeline_config` {#max.pipelines.lib.interfaces.GenerateMixin.pipeline_config} > property pipeline\_config: [PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig) Returns the pipeline configuration. ### `release()` {#max.pipelines.lib.interfaces.GenerateMixin.release} > release(request\_id) Releases resources for the given request.
**Parameters:**
request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID))
**Return type:**
None
### `tokenizer` {#max.pipelines.lib.interfaces.GenerateMixin.tokenizer} > property tokenizer: [PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[TextGenerationContextType, [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]], RequestType] Returns the tokenizer for this pipeline. ## `InputKey` {#max.pipelines.lib.interfaces.InputKey} > class max.pipelines.lib.interfaces.InputKey(\*inputs) Bases: [`object`](https://docs.python.org/3/library/functions.html#object)
**Parameters:**
inputs ([Buffer](../driver.md#max.driver.Buffer))
## `KVCacheMixin` {#max.pipelines.lib.interfaces.KVCacheMixin} > class max.pipelines.lib.interfaces.KVCacheMixin(\*args, \*\*kwargs) Bases: [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol) ### `get_kv_params()` {#max.pipelines.lib.interfaces.KVCacheMixin.get_kv_params} > abstract classmethod get\_kv\_params(huggingface\_config, pipeline\_config, devices, kv\_cache\_config, cache\_dtype) Returns the KV cache params for the pipeline model.
**Parameters:**
* huggingface\_config (AutoConfig) * pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../graph/ops.md#max.graph.ops.DeviceRef)]) * kv\_cache\_config (KVCacheConfig) * cache\_dtype ([DType](../dtype.md#max.dtype.DType))
**Return type:**
[KVCacheParams](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)
### `load_kv_managers()` {#max.pipelines.lib.interfaces.KVCacheMixin.load_kv_managers} > load\_kv\_managers(kv\_params, max\_batch\_size, max\_seq\_len, session, available\_cache\_memory) Given KV cache parameters and an inference session, loads the KV cache managers.
**Parameters:**
* kv\_params ([KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) – KV cache parameters. * max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum batch size of the model. * max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum sequence length of the model. * session ([InferenceSession](../engine.md#max.engine.InferenceSession)) – Inference session to compile and init the KV cache. * available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – Amount of memory available to the KV cache, in bytes.
**Returns:**
A list of KV cache managers.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[PagedKVCacheManager](../kv_cache/paged_kv_cache/cache_manager.md#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager)]
## `ModelInputs` {#max.pipelines.lib.interfaces.ModelInputs} > class max.pipelines.lib.interfaces.ModelInputs Bases: [`object`](https://docs.python.org/3/library/functions.html#object) Base class for model inputs. Use this class to encapsulate inputs for your model; you may store any number of dataclass fields. The following example demonstrates how to create a custom inputs class: ```python class ReplitInputs(ModelInputs): tokens: Buffer input_row_offsets: Buffer def __init__(self, tokens: Buffer, input_row_offsets: Buffer): self.tokens = tokens self.input_row_offsets = input_row_offsets tokens = Buffer.zeros((1, 2, 3), DType.int64) input_row_offsets = Buffer.zeros((1, 1, 1), DType.int64) # Initialize inputs inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets) # Access tensors list(inputs) == [tokens, input_row_offsets] # Output: True ``` ### `hidden_states` {#max.pipelines.lib.interfaces.ModelInputs.hidden_states} > hidden\_states: [Buffer](../driver.md#max.driver.Buffer) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] | [None](https://docs.python.org/3/library/constants.html#None) = None Hidden states for a variable number of tokens per sequence. For data parallel models, this can be a list of Buffers where each Buffer contains hidden states for the sequences assigned to that device. ### `kv_cache_inputs` {#max.pipelines.lib.interfaces.ModelInputs.kv_cache_inputs} > kv\_cache\_inputs: KVCacheInputs | [None](https://docs.python.org/3/library/constants.html#None) = None ### `lora_ids` {#max.pipelines.lib.interfaces.ModelInputs.lora_ids} > lora\_ids: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Buffer containing the LoRA ids. 
### `lora_ranks` {#max.pipelines.lib.interfaces.ModelInputs.lora_ranks} > lora\_ranks: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Buffer containing the LoRA ranks. ### `update()` {#max.pipelines.lib.interfaces.ModelInputs.update} > update(\*\*kwargs) Updates attributes from keyword arguments. Only attributes that already exist are updated, and None values are ignored.
**Return type:**
None
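One plausible reading of the `update()` semantics ("only existing, non-None") can be sketched with a minimal stand-in class; the field names and implementation are illustrative assumptions, not the actual `ModelInputs` code:

```python
class ModelInputsSketch:
    """Illustrative stand-in for ModelInputs.update() semantics."""

    def __init__(self, tokens=None, kv_cache_inputs=None):
        self.tokens = tokens
        self.kv_cache_inputs = kv_cache_inputs

    def update(self, **kwargs):
        for name, value in kwargs.items():
            # Only touch attributes that already exist, and skip None values.
            if value is not None and hasattr(self, name):
                setattr(self, name, value)
```

Under this reading, `update(tokens=new_tokens, kv_cache_inputs=None)` replaces `tokens` but leaves `kv_cache_inputs` untouched, and unknown keyword arguments are silently ignored.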
## `ModelOutputs` {#max.pipelines.lib.interfaces.ModelOutputs} > class max.pipelines.lib.interfaces.ModelOutputs(logits: 'Buffer', next\_token\_logits: 'Buffer | None' = None, logit\_offsets: 'Buffer | None' = None, hidden\_states: 'Buffer | list\[Buffer] | None' = None) Bases: [`object`](https://docs.python.org/3/library/functions.html#object)
**Parameters:**
* logits ([Buffer](../driver.md#max.driver.Buffer)) * next\_token\_logits ([Buffer](../driver.md#max.driver.Buffer) | None) * logit\_offsets ([Buffer](../driver.md#max.driver.Buffer) | None) * hidden\_states ([Buffer](../driver.md#max.driver.Buffer) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] | None)
### `hidden_states` {#max.pipelines.lib.interfaces.ModelOutputs.hidden_states} > hidden\_states: [Buffer](../driver.md#max.driver.Buffer) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] | [None](https://docs.python.org/3/library/constants.html#None) = None Hidden states for a variable number of tokens per sequence. For data parallel models, this can be a list of Buffers where each Buffer contains hidden states for the sequences assigned to that device. ### `logit_offsets` {#max.pipelines.lib.interfaces.ModelOutputs.logit_offsets} > logit\_offsets: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Offsets to access variable length logits for each sequence. ### `logits` {#max.pipelines.lib.interfaces.ModelOutputs.logits} > logits: [Buffer](../driver.md#max.driver.Buffer) Logits for a variable number of tokens per sequence. ### `next_token_logits` {#max.pipelines.lib.interfaces.ModelOutputs.next_token_logits} > next\_token\_logits: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Logits for just the next token. ## `PipelineModel` {#max.pipelines.lib.interfaces.PipelineModel} > class max.pipelines.lib.interfaces.PipelineModel(pipeline\_config, session, huggingface\_config, encoding, devices, kv\_cache\_config, weights, adapter, return\_logits, return\_hidden\_states=ReturnHiddenStates.NONE) Bases: [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC), [`Generic`](https://docs.python.org/3/library/typing.html#typing.Generic)\[`BaseContextType`] A pipeline model with setup, input preparation and execution methods.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * session ([InferenceSession](../engine.md#max.engine.InferenceSession)) * huggingface\_config (AutoConfig) * encoding (SupportedEncoding) * devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](../driver.md#max.driver.Device)]) * kv\_cache\_config (KVCacheConfig) * weights ([Weights](../graph/weights.md#max.graph.weights.Weights)) * adapter (WeightsAdapter | None) * return\_logits ([ReturnLogits](../nn/legacy/transformer/transformer.md#max.nn.legacy.transformer.transformer.ReturnLogits)) * return\_hidden\_states ([ReturnHiddenStates](../nn/legacy/transformer/transformer.md#max.nn.legacy.transformer.transformer.ReturnHiddenStates))
### `calculate_max_seq_len()` {#max.pipelines.lib.interfaces.PipelineModel.calculate_max_seq_len} > abstract classmethod calculate\_max\_seq\_len(pipeline\_config, huggingface\_config) Calculates the optimal max sequence length for the model. Models are expected to implement this method. The following example shows how to implement it for a Mistral model: ```python class MistralModel(PipelineModel): @classmethod def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int: try: return upper_bounded_default( upper_bound=huggingface_config.max_seq_len, default=pipeline_config.max_length, ) except ValueError as e: raise ValueError( "Unable to infer max_length for Mistral, the provided " f"max_length ({pipeline_config.max_length}) exceeds the " f"model's max_seq_len ({huggingface_config.max_seq_len})." ) from e ```
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Configuration for the pipeline. * huggingface\_config (AutoConfig) – Hugging Face model configuration.
**Returns:**
The maximum sequence length to use.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `compute_log_probabilities()` {#max.pipelines.lib.interfaces.PipelineModel.compute_log_probabilities} > compute\_log\_probabilities(session, model\_inputs, model\_outputs, next\_tokens, batch\_top\_n, batch\_echo) Optional method that can be overridden to compute log probabilities.
**Parameters:**
* session ([InferenceSession](../engine.md#max.engine.InferenceSession)) – Inference session to compute log probabilities within. * model\_inputs ([ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)) – Inputs to the model returned by prepare\_\*\_token\_inputs(). * model\_outputs ([ModelOutputs](#max.pipelines.lib.interfaces.ModelOutputs)) – Outputs returned by execute(). * next\_tokens ([Buffer](../driver.md#max.driver.Buffer)) – Sampled tokens. Should have shape=\[batch size] * batch\_top\_n ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Number of top log probabilities to return per input in the batch. For any element where top\_n == 0, the LogProbabilities is skipped. * batch\_echo ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[bool](https://docs.python.org/3/library/functions.html#bool)]) – Whether to include input tokens in the returned log probabilities.
**Returns:**
List of log probabilities.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities) | None]
### `dtype` {#max.pipelines.lib.interfaces.PipelineModel.dtype} > property dtype: [DType](../dtype.md#max.dtype.DType) Returns the model data type (from encoding or pipeline config). ### `estimate_activation_memory()` {#max.pipelines.lib.interfaces.PipelineModel.estimate_activation_memory} > classmethod estimate\_activation\_memory(pipeline\_config, huggingface\_config) Estimates the activation memory required for model execution. This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers. The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Pipeline configuration * huggingface\_config (AutoConfig) – Hugging Face model configuration
**Returns:**
Estimated activation memory in bytes
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `estimate_weights_size()` {#max.pipelines.lib.interfaces.PipelineModel.estimate_weights_size} > classmethod estimate\_weights\_size(pipeline\_config) Calculates the estimated memory consumption of the model's weights.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig))
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `execute()` {#max.pipelines.lib.interfaces.PipelineModel.execute} > abstract execute(model\_inputs) Executes the graph with the given inputs.
**Parameters:**
model\_inputs ([ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)) – The model inputs to execute, containing tensors and any other required data for model execution.
**Returns:**
ModelOutputs containing the pipeline’s output tensors.
**Return type:**
[ModelOutputs](#max.pipelines.lib.interfaces.ModelOutputs)
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic. ### `execute_with_capture()` {#max.pipelines.lib.interfaces.PipelineModel.execute_with_capture} > execute\_with\_capture(model\_inputs, batch\_size) Executes the model with optional capture handling. Subclasses can override this to integrate device graph capture/replay.
**Parameters:**
* model\_inputs ([ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)) * batch\_size ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[ModelOutputs](#max.pipelines.lib.interfaces.ModelOutputs)
### `finalize_pipeline_config()` {#max.pipelines.lib.interfaces.PipelineModel.finalize_pipeline_config} > classmethod finalize\_pipeline\_config(pipeline\_config) Finalizes the pipeline configuration. This method is called after the pipeline configuration is resolved. It can be overridden to perform any finalization steps that are needed.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig))
**Return type:**
None
### `lora_manager` {#max.pipelines.lib.interfaces.PipelineModel.lora_manager} > property lora\_manager: LoRAManager | [None](https://docs.python.org/3/library/constants.html#None) Returns the LoRA manager if LoRA is enabled, otherwise None. ### `pre_capture_execution_trace()` {#max.pipelines.lib.interfaces.PipelineModel.pre_capture_execution_trace} > pre\_capture\_execution\_trace(model\_inputs, batch\_size) Captures execution traces for device graph replay when enabled.
**Parameters:**
* model\_inputs ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)]) * batch\_size ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
None
### `prepare_initial_token_inputs()` {#max.pipelines.lib.interfaces.PipelineModel.prepare_initial_token_inputs} > abstract prepare\_initial\_token\_inputs(replica\_batches, kv\_cache\_inputs=None, return\_n\_logits=1) Prepares the initial inputs to be passed to `.execute()`. The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and `kv_cache_inputs` (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.
**Parameters:**
* replica\_batches ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[BaseContextType]]) * kv\_cache\_inputs (KVCacheInputs | None) * return\_n\_logits ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)
### `prepare_next_token_inputs()` {#max.pipelines.lib.interfaces.PipelineModel.prepare_next_token_inputs} > abstract prepare\_next\_token\_inputs(next\_tokens, prev\_model\_inputs) Prepares the secondary inputs to be passed to `.execute()`. While prepare\_initial\_token\_inputs manages the initial inputs, this method updates them for each step in a multi-step execution pattern.
**Parameters:**
* next\_tokens ([Buffer](../driver.md#max.driver.Buffer)) * prev\_model\_inputs ([ModelInputs](#max.pipelines.lib.interfaces.ModelInputs))
**Return type:**
[ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)
### `signal_buffers` {#max.pipelines.lib.interfaces.PipelineModel.signal_buffers} > property signal\_buffers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] Lazily initialize signal buffers for multi-GPU communication collectives. Signal buffers are only needed during model execution, not during compilation. By deferring their allocation, we avoid memory allocation in compile-only mode.
**Returns:**
List of signal buffer tensors, one per device for multi-device setups, or an empty list for single-device setups or compile-only mode.
## `PixelModelInputs` {#max.pipelines.lib.interfaces.PixelModelInputs} > class max.pipelines.lib.interfaces.PixelModelInputs(\*, tokens, tokens\_2=None, negative\_tokens=None, negative\_tokens\_2=None, extra\_params=\<factory>, timesteps=\<factory>, sigmas=\<factory>, latents=\<factory>, latent\_image\_ids=\<factory>, height=1024, width=1024, num\_inference\_steps=50, guidance\_scale=3.5, guidance=None, true\_cfg\_scale=1.0, num\_warmup\_steps=0, num\_images\_per\_prompt=1) Bases: [`object`](https://docs.python.org/3/library/functions.html#object) Common input container for pixel-generation models. Provides a consistent set of fields used across multiple pixel pipelines and models.
**Parameters:**
* tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) * tokens\_2 ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) * negative\_tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) * negative\_tokens\_2 ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) * extra\_params ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) * timesteps ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) * sigmas ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) * latents ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) * latent\_image\_ids 
([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) * height ([int](https://docs.python.org/3/library/functions.html#int)) * width ([int](https://docs.python.org/3/library/functions.html#int)) * num\_inference\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * guidance\_scale ([float](https://docs.python.org/3/library/functions.html#float)) * guidance ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] | None) * true\_cfg\_scale ([float](https://docs.python.org/3/library/functions.html#float)) * num\_warmup\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * num\_images\_per\_prompt ([int](https://docs.python.org/3/library/functions.html#int))
### `extra_params` {#max.pipelines.lib.interfaces.PixelModelInputs.extra_params} > extra\_params: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] A bag of model-specific numeric parameters not represented as explicit fields. Typical uses: * Architecture-specific knobs (e.g., cfg\_normalization arrays, scaling vectors) * Precomputed per-step values not worth standardizing across all models * Small numeric tensors that are easier to carry as named extras than formal fields Values are expected to be numpy arrays (ndarray) to keep the contract consistent, but you can relax this if your codebase needs non-array values. ### `from_context()` {#max.pipelines.lib.interfaces.PixelModelInputs.from_context} > classmethod from\_context(context) Build an instance from a context-like dict. Policy: * If a key is missing: the dataclass default applies automatically. * If a key is present with value None: treat as missing and substitute the class default (including subclass overrides).
**Parameters:**
context ([PixelGenerationContext](../interfaces.md#max.interfaces.PixelGenerationContext))
**Return type:**
[Self](https://docs.python.org/3/library/typing.html#typing.Self)
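The `from_context()` policy above (missing key → dataclass default; present-but-None → class default) can be sketched with a small dataclass; the field names are illustrative and only mirror a couple of the scalar fields documented here:

```python
from dataclasses import dataclass, fields

@dataclass
class InputsSketch:
    # Illustrative scalar fields with class-level defaults.
    height: int = 1024
    num_inference_steps: int = 50

    @classmethod
    def from_context(cls, context: dict):
        kwargs = {}
        for f in fields(cls):
            # Present and non-None: use the context's value.
            # Missing, or present with None: let the dataclass default apply.
            if f.name in context and context[f.name] is not None:
                kwargs[f.name] = context[f.name]
        return cls(**kwargs)
```

So a context of `{"height": None, "num_inference_steps": 28}` yields `height=1024` (default substituted) and `num_inference_steps=28`.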
### `guidance` {#max.pipelines.lib.interfaces.PixelModelInputs.guidance} > guidance: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] | [None](https://docs.python.org/3/library/constants.html#None) = None Optional guidance tensor. * Some pipelines precompute guidance weights/tensors (e.g., per-token weights, per-step weights). * None is meaningful here: it means “no explicit guidance tensor supplied”. * Unlike scalar fields, None is preserved (not replaced). ### `guidance_scale` {#max.pipelines.lib.interfaces.PixelModelInputs.guidance_scale} > guidance\_scale: [float](https://docs.python.org/3/library/functions.html#float) = 3.5 Guidance scale for classifier-free guidance (CFG). * A higher value typically increases adherence to the prompt but can reduce diversity. * This is expected to be a real float (not None). * If a context provides guidance\_scale=None, from\_context() substitutes the default. ### `height` {#max.pipelines.lib.interfaces.PixelModelInputs.height} > height: [int](https://docs.python.org/3/library/functions.html#int) = 1024 Output height in pixels. * This is a required scalar (not None). * If a context provides height=None, from\_context() treats that as “not provided” and substitutes this default value (or a subclass override). 
### `latent_image_ids` {#max.pipelines.lib.interfaces.PixelModelInputs.latent_image_ids} > latent\_image\_ids: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Optional latent image IDs / positional identifiers for latents. * Some pipelines attach per-latent identifiers for caching, routing, or conditioning. * Often used to avoid recomputation of image-id embeddings across steps. * If unused, it may remain empty. ### `latents` {#max.pipelines.lib.interfaces.PixelModelInputs.latents} > latents: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Initial latent noise tensor (or initial latent state). * For diffusion/flow models, this is typically random noise seeded per request. * Shape depends on model: commonly \[B, C, H/8, W/8] for image latents, or \[B, T, C, H/8, W/8] for video latents. * If your pipeline generates latents internally, you may leave it empty. (Model-specific subclasses can enforce non-empty via \_\_post\_init\_\_.) ### `negative_tokens` {#max.pipelines.lib.interfaces.PixelModelInputs.negative_tokens} > negative\_tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Negative prompt tokens for the primary encoder. Used for classifier-free guidance (CFG) or similar conditioning schemes. If your pipeline does not use negative prompts, leave as None. 
### `negative_tokens_2` {#max.pipelines.lib.interfaces.PixelModelInputs.negative_tokens_2} > negative\_tokens\_2: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Negative prompt tokens for the secondary encoder (for dual-encoder models). If the model is single-encoder or you do not use negative prompts, leave as None. ### `num_images_per_prompt` {#max.pipelines.lib.interfaces.PixelModelInputs.num_images_per_prompt} > num\_images\_per\_prompt: [int](https://docs.python.org/3/library/functions.html#int) = 1 Number of images/videos to generate per prompt. * Commonly used for “same prompt, multiple samples” behavior. * Must be > 0. * For video generation, the name is kept for historical compatibility. ### `num_inference_steps` {#max.pipelines.lib.interfaces.PixelModelInputs.num_inference_steps} > num\_inference\_steps: [int](https://docs.python.org/3/library/functions.html#int) = 50 Number of denoising/inference steps. * This is a required scalar (not None). * If a context provides num\_inference\_steps=None, from\_context() treats that as “not provided” and substitutes this default value (or a subclass override). ### `num_warmup_steps` {#max.pipelines.lib.interfaces.PixelModelInputs.num_warmup_steps} > num\_warmup\_steps: [int](https://docs.python.org/3/library/functions.html#int) = 0 Number of warmup steps. * Used in some schedulers/pipelines to handle initial steps differently (e.g., scheduler stabilization, cache warmup, etc.). * Must be >= 0.
### `sigmas` {#max.pipelines.lib.interfaces.PixelModelInputs.sigmas} > sigmas: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Precomputed sigma schedule for denoising. * Usually a 1D float32 numpy array of length num\_inference\_steps corresponding to the noise level per step. * Some schedulers are sigma-based; others are timestep-based; some use both. * If unused, it may remain empty unless your model subclass requires it. ### `timesteps` {#max.pipelines.lib.interfaces.PixelModelInputs.timesteps} > timesteps: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] Precomputed denoising timestep schedule. * Usually a 1D float32 numpy array of length num\_inference\_steps (exact semantics depend on your scheduler). * If your pipeline precomputes the scheduler trajectory, you pass it here. * Some models may not require explicit timesteps; in that case it may remain empty. (Model-specific subclasses can enforce non-empty via \_\_post\_init\_\_.) ### `tokens` {#max.pipelines.lib.interfaces.PixelModelInputs.tokens} > tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) Primary encoder token buffer. This is the main prompt representation consumed by the model’s text encoder. Required for all models. 
### `tokens_2` {#max.pipelines.lib.interfaces.PixelModelInputs.tokens_2} > tokens\_2: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None Secondary encoder token buffer (for dual-encoder models). Examples: architectures that have a second text encoder stream or pooled embeddings. If the model is single-encoder, leave as None. ### `true_cfg_scale` {#max.pipelines.lib.interfaces.PixelModelInputs.true_cfg_scale} > true\_cfg\_scale: [float](https://docs.python.org/3/library/functions.html#float) = 1.0 “True CFG” scale used by certain pipelines/models. * Some architectures distinguish between the user-facing guidance\_scale and an internal scale applied to a different normalization or conditioning pathway. * Defaults to 1.0 for pipelines that do not use this feature. ### `width` {#max.pipelines.lib.interfaces.PixelModelInputs.width} > width: [int](https://docs.python.org/3/library/functions.html#int) = 1024 Output width in pixels. * This is a required scalar (not None). * If a context provides width=None, from\_context() treats that as “not provided” and substitutes this default value (or a subclass override). --- ## log_probabilities ## `compute_log_probabilities_ragged()` {#max.pipelines.lib.log_probabilities.compute_log_probabilities_ragged} > max.pipelines.lib.log\_probabilities.compute\_log\_probabilities\_ragged(device, model, \*, input\_row\_offsets, logits, next\_token\_logits, tokens, sampled\_tokens, batch\_top\_n, batch\_echo) Computes the log probabilities for ragged model outputs.
**Parameters:**
* device ([Device](../driver.md#max.driver.Device)) – Device on which to do the bulk of the log probabilities computation. A small amount of computation still occurs on the host regardless of this setting. * model ([Model](../engine.md#max.engine.Model)) – A compiled version of a graph from the ‘log\_probabilities\_ragged\_graph’ function. * input\_row\_offsets ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Token offsets into token-indexed buffers, by batch index. Should have 1 more element than there are batches (batch n is token indices \[input\_row\_offsets\[n], input\_row\_offsets\[n+1])). * logits ([Buffer](../driver.md#max.driver.Buffer) | None) – (tokens, vocab\_dim) tensor of logits. The token dimension is mapped to batches using input\_row\_offsets. May be omitted only if all ‘batch\_echo’ values are False. * next\_token\_logits ([Buffer](../driver.md#max.driver.Buffer)) – (batch\_dim, vocab\_dim) tensor of logits for the next token in each batch item. * tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – (total\_tokens,) flat token array for the batch; indices per batch given by input\_row\_offsets.
* sampled\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – (batch\_dim,) tensor of the sampled token for each batch item. * batch\_top\_n ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Number of top log probabilities to return per input in the batch. For any element where top\_n == 0, the corresponding LogProbabilities output is skipped (None is returned for that element). * batch\_echo ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[bool](https://docs.python.org/3/library/functions.html#bool)]) – Whether to include input tokens in the returned log probabilities.
**Returns:**
Computed log probabilities for each item in the batch.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities) | None]
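The ragged layout above can be illustrated with plain lists (the real function takes numpy arrays and device buffers): batch n owns the token slice between consecutive offsets.

```python
def split_ragged(tokens, input_row_offsets):
    # input_row_offsets has one more element than there are batches;
    # batch n owns tokens[input_row_offsets[n]:input_row_offsets[n + 1]].
    return [tokens[input_row_offsets[n]:input_row_offsets[n + 1]]
            for n in range(len(input_row_offsets) - 1)]

tokens = [11, 12, 13, 21, 31, 32]   # three sequences flattened together
offsets = [0, 3, 4, 6]              # lengths 3, 1, 2
rows = split_ragged(tokens, offsets)  # [[11, 12, 13], [21], [31, 32]]
```

This is why `logits` is token-indexed while `next_token_logits` and `sampled_tokens` are batch-indexed: the former follows the flat token layout, the latter have exactly one entry per sequence.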
## `log_probabilities_ragged_graph()` {#max.pipelines.lib.log_probabilities.log_probabilities_ragged_graph} > max.pipelines.lib.log\_probabilities.log\_probabilities\_ragged\_graph(device, \*, levels) Create a graph to compute log probabilities over ragged inputs. A model compiled from this graph is a required input to ‘compute\_log\_probabilities\_ragged’.
**Parameters:**
* device ([DeviceRef](../graph/type.md#max.graph.type.DeviceRef)) – The type of device this graph will need to run on. * levels ([int](https://docs.python.org/3/library/functions.html#int)) – log2(max\_k+1) for the desired maximum top-k you’d like to support. To support the OpenAI API maximum of 5 logprobs, use levels=3. Higher levels can be used to support higher k.
**Return type:**
[Graph](../graph/Graph.md#max.graph.Graph)
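The relationship between `levels` and the maximum supported top-k follows directly from levels = log2(max\_k + 1):

```python
def max_top_k(levels: int) -> int:
    # levels = log2(max_k + 1)  =>  max_k = 2**levels - 1
    return 2 ** levels - 1

k = max_top_k(3)  # 7, enough for the OpenAI API maximum of 5 logprobs
```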
--- ## lora_config MAX LoRA configuration. ## `LoRAConfig` {#max.pipelines.lib.lora_config.LoRAConfig} > class max.pipelines.lib.lora\_config.LoRAConfig(\*, config\_file=None, section\_name=None, enable\_lora=False, lora\_paths=\<factory\>, max\_lora\_rank=16, max\_num\_loras=1)
**Parameters:**
* config\_file ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * section\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * enable\_lora ([bool](https://docs.python.org/3/library/functions.html#bool)) * lora\_paths ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) * max\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int)) * max\_num\_loras ([int](https://docs.python.org/3/library/functions.html#int))
### `enable_lora` {#max.pipelines.lib.lora_config.LoRAConfig.enable_lora} > enable\_lora: [bool](https://docs.python.org/3/library/functions.html#bool) ### `lora_paths` {#max.pipelines.lib.lora_config.LoRAConfig.lora_paths} > lora\_paths: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] ### `max_lora_rank` {#max.pipelines.lib.lora_config.LoRAConfig.max_lora_rank} > max\_lora\_rank: [int](https://docs.python.org/3/library/functions.html#int) ### `max_num_loras` {#max.pipelines.lib.lora_config.LoRAConfig.max_num_loras} > max\_num\_loras: [int](https://docs.python.org/3/library/functions.html#int) ### `model_config` {#max.pipelines.lib.lora_config.LoRAConfig.model_config} > model\_config: ClassVar\[ConfigDict] = {} Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. ### `model_post_init()` {#max.pipelines.lib.lora_config.LoRAConfig.model_post_init} > model\_post\_init(context, /) This function is meant to behave like a BaseModel method to initialise private attributes. It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance. * context (Any) – The context.
**Return type:**
None
--- ## model_config MAX model config classes. ## `MAXModelConfig` {#max.pipelines.lib.model_config.MAXModelConfig} > class max.pipelines.lib.model\_config.MAXModelConfig(\*, config\_file=None, section\_name=None, use\_subgraphs=True, data\_parallel\_degree=1, model\_path='', served\_model\_name=None, weight\_path=\<factory\>, quantization\_encoding=None, allow\_safetensors\_weights\_fp32\_bf6\_bidirectional\_cast=False, huggingface\_model\_revision='main', huggingface\_weight\_revision='main', trust\_remote\_code=False, device\_specs=\<factory\>, force\_download=False, vision\_config\_overrides=\<factory\>, rope\_type=None, kv\_cache=\<factory\>) Bases: [`MAXModelConfigBase`](#max.pipelines.lib.model_config.MAXModelConfigBase)
**Parameters:**
* config\_file ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * section\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * use\_subgraphs ([bool](https://docs.python.org/3/library/functions.html#bool)) * data\_parallel\_degree ([int](https://docs.python.org/3/library/functions.html#int)) * model\_path ([str](https://docs.python.org/3/library/stdtypes.html#str)) * served\_model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * weight\_path ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]) * quantization\_encoding (SupportedEncoding | None) * allow\_safetensors\_weights\_fp32\_bf6\_bidirectional\_cast ([bool](https://docs.python.org/3/library/functions.html#bool)) * huggingface\_model\_revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) * huggingface\_weight\_revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) * trust\_remote\_code ([bool](https://docs.python.org/3/library/functions.html#bool)) * device\_specs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](../driver.md#max.driver.DeviceSpec)]) * force\_download ([bool](https://docs.python.org/3/library/functions.html#bool)) * vision\_config\_overrides ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]) * rope\_type (RopeType | None) * kv\_cache (KVCacheConfig)
### `allow_safetensors_weights_fp32_bf6_bidirectional_cast` {#max.pipelines.lib.model_config.MAXModelConfig.allow_safetensors_weights_fp32_bf6_bidirectional_cast} > allow\_safetensors\_weights\_fp32\_bf6\_bidirectional\_cast: [bool](https://docs.python.org/3/library/functions.html#bool) ### `create_kv_cache_config()` {#max.pipelines.lib.model_config.MAXModelConfig.create_kv_cache_config} > create\_kv\_cache\_config(\*\*kv\_cache\_kwargs) Create and set the KV cache configuration with the given parameters. This method creates a new KVCacheConfig from the provided keyword arguments and automatically sets the cache\_dtype based on the model’s quantization encoding (or any explicit override in kv\_cache\_kwargs).
**Parameters:**
\*\*kv\_cache\_kwargs – Keyword arguments to pass to KVCacheConfig constructor. Common options include: * cache\_strategy: The KV cache strategy (continuous, paged, etc.) * kv\_cache\_page\_size: Number of tokens per page for paged cache * enable\_prefix\_caching: Whether to enable prefix caching * device\_memory\_utilization: Fraction of device memory to use * cache\_dtype: Override for the cache data type
**Return type:**
None
### `data_parallel_degree` {#max.pipelines.lib.model_config.MAXModelConfig.data_parallel_degree} > data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) ### `default_device_spec` {#max.pipelines.lib.model_config.MAXModelConfig.default_device_spec} > property default\_device\_spec: [DeviceSpec](../driver.md#max.driver.DeviceSpec) Returns the default device spec for the model. This is the first device spec in the list, used for device spec checks throughout config validation.
**Returns:**
The default device spec for the model.
### `device_specs` {#max.pipelines.lib.model_config.MAXModelConfig.device_specs} > device\_specs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](../driver.md#max.driver.DeviceSpec)] ### `diffusers_config` {#max.pipelines.lib.model_config.MAXModelConfig.diffusers_config} > property diffusers\_config: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [None](https://docs.python.org/3/library/constants.html#None) Retrieve the diffusers config for diffusion pipelines. Note: For multiprocessing, \_\_getstate\_\_ clears \_diffusers\_config before pickling. Each worker process will reload the config fresh.
**Returns:**
The diffusers config dict if this is a diffusion pipeline, None otherwise. The dict will have a structure with “\_class\_name” and “components” keys, where each component includes “class\_name” and “config\_dict” fields.
### `force_download` {#max.pipelines.lib.model_config.MAXModelConfig.force_download} > force\_download: [bool](https://docs.python.org/3/library/functions.html#bool) ### `generation_config` {#max.pipelines.lib.model_config.MAXModelConfig.generation_config} > property generation\_config: GenerationConfig Retrieve the Hugging Face GenerationConfig for this model. This property lazily loads the GenerationConfig from the model repository and caches it to avoid repeated remote fetches.
**Returns:**
The GenerationConfig for the model, containing generation parameters like max\_length, temperature, top\_p, etc. If loading fails, returns a default GenerationConfig.
### `graph_quantization_encoding` {#max.pipelines.lib.model_config.MAXModelConfig.graph_quantization_encoding} > property graph\_quantization\_encoding: [QuantizationEncoding](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) Converts the CLI encoding to a MAX Graph quantization encoding.
**Returns:**
The graph quantization encoding corresponding to the CLI encoding.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no CLI encoding was specified.
### `huggingface_config` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_config} > property huggingface\_config: AutoConfig | [None](https://docs.python.org/3/library/constants.html#None) Returns the Hugging Face model config (loaded on first access). ### `huggingface_model_repo` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_model_repo} > property huggingface\_model\_repo: [HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo) Returns the Hugging Face repo handle for the model. ### `huggingface_model_revision` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_model_revision} > huggingface\_model\_revision: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `huggingface_weight_repo` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_weight_repo} > property huggingface\_weight\_repo: [HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo) Returns the Hugging Face repo handle for weight files. ### `huggingface_weight_repo_id` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_weight_repo_id} > property huggingface\_weight\_repo\_id: [str](https://docs.python.org/3/library/stdtypes.html#str) Returns the Hugging Face repo ID used for weight files. ### `huggingface_weight_revision` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_weight_revision} > huggingface\_weight\_revision: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `kv_cache` {#max.pipelines.lib.model_config.MAXModelConfig.kv_cache} > kv\_cache: KVCacheConfig ### `model_config` {#max.pipelines.lib.model_config.MAXModelConfig.model_config} > model\_config: ClassVar\[ConfigDict] = {'arbitrary\_types\_allowed': True} Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. 
### `model_name` {#max.pipelines.lib.model_config.MAXModelConfig.model_name} > property model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) Returns the served model name or model path. ### `model_path` {#max.pipelines.lib.model_config.MAXModelConfig.model_path} > model\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `model_post_init()` {#max.pipelines.lib.model_config.MAXModelConfig.model_post_init} > model\_post\_init(context, /) This function is meant to behave like a BaseModel method to initialise private attributes. It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance. * context (Any) – The context.
**Return type:**
None
### `quantization_encoding` {#max.pipelines.lib.model_config.MAXModelConfig.quantization_encoding} > quantization\_encoding: SupportedEncoding | [None](https://docs.python.org/3/library/constants.html#None) ### `resolve()` {#max.pipelines.lib.model_config.MAXModelConfig.resolve} > resolve() Validates and resolves the config. This method is called after the model config is initialized, to ensure that all config fields have been initialized to a valid state. It will also set and update other fields which may not be determined / initialized in the default factory. In order: 1. Validate that the device\_specs provided are available 2. Parse the weight path(s) and initialize the \_weights\_repo\_id
**Return type:**
None
### `rope_type` {#max.pipelines.lib.model_config.MAXModelConfig.rope_type} > rope\_type: RopeType | [None](https://docs.python.org/3/library/constants.html#None) ### `sampling_params_defaults` {#max.pipelines.lib.model_config.MAXModelConfig.sampling_params_defaults} > property sampling\_params\_defaults: [SamplingParamsGenerationConfigDefaults](../interfaces.md#max.interfaces.SamplingParamsGenerationConfigDefaults) Returns sampling defaults derived from the generation config. ### `served_model_name` {#max.pipelines.lib.model_config.MAXModelConfig.served_model_name} > served\_model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) ### `set_cache_dtype_given_quantization_encoding()` {#max.pipelines.lib.model_config.MAXModelConfig.set_cache_dtype_given_quantization_encoding} > set\_cache\_dtype\_given\_quantization\_encoding() Determine the KV cache dtype based on quantization encoding configuration. The dtype is determined in the following priority order: 1. Explicit override from kv\_cache.kv\_cache\_format (if set) 2. Derived from the model’s quantization\_encoding 3. Falls back to float32 if no encoding is specified
**Returns:**
The DType to use for the KV cache. Typical values are: * DType.float32 for float32, q4\_k, q4\_0, q6\_k encodings * DType.bfloat16 for bfloat16, float8\_e4m3fn, float4\_e2m1fnx2, gptq encodings
**Return type:**
DType
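The priority order can be sketched as a simple lookup, using the encoding names documented above. Strings stand in for the real SupportedEncoding and DType values, and the parameter names mirror, but are not, the actual config attributes.

```python
_ENCODING_TO_CACHE_DTYPE = {
    # float32-cache encodings
    "float32": "float32", "q4_k": "float32", "q4_0": "float32", "q6_k": "float32",
    # bfloat16-cache encodings
    "bfloat16": "bfloat16", "float8_e4m3fn": "bfloat16",
    "float4_e2m1fnx2": "bfloat16", "gptq": "bfloat16",
}

def resolve_cache_dtype(kv_cache_format=None, quantization_encoding=None):
    if kv_cache_format is not None:        # 1. explicit override wins
        return kv_cache_format
    if quantization_encoding is not None:  # 2. derive from the encoding
        return _ENCODING_TO_CACHE_DTYPE[quantization_encoding]
    return "float32"                       # 3. fallback when unspecified

dtype = resolve_cache_dtype(quantization_encoding="gptq")  # "bfloat16"
```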
### `trust_remote_code` {#max.pipelines.lib.model_config.MAXModelConfig.trust_remote_code} > trust\_remote\_code: [bool](https://docs.python.org/3/library/functions.html#bool) ### `use_subgraphs` {#max.pipelines.lib.model_config.MAXModelConfig.use_subgraphs} > use\_subgraphs: [bool](https://docs.python.org/3/library/functions.html#bool) ### `validate_and_resolve_quantization_encoding_weight_path()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_and_resolve_quantization_encoding_weight_path} > validate\_and\_resolve\_quantization\_encoding\_weight\_path(default\_encoding) Verifies that the quantization encoding and weight path are consistent.
**Parameters:**
default\_encoding (SupportedEncoding) – The default encoding to use if no encoding is provided.
**Return type:**
None
### `validate_and_resolve_rope_type()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_and_resolve_rope_type} > validate\_and\_resolve\_rope\_type(arch\_rope\_type) Resolves rope\_type from architecture default if not set.
**Parameters:**
arch\_rope\_type (RopeType)
**Return type:**
None
### `validate_and_resolve_with_resolved_quantization_encoding()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_and_resolve_with_resolved_quantization_encoding} > validate\_and\_resolve\_with\_resolved\_quantization\_encoding(supported\_encodings, default\_weights\_format) Validates model path and weight path against resolved quantization encoding. Also resolves the KV cache strategy and finalizes the encoding config.
**Parameters:**
* supported\_encodings ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[SupportedEncoding, [list](https://docs.python.org/3/library/stdtypes.html#list)\[[KVCacheStrategy](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)]]) – A dictionary of supported encodings and their corresponding KV cache strategies. * default\_weights\_format ([WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)) – The default weights format to use if no weights format is provided.
**Return type:**
None
### `validate_lora_compatibility()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_lora_compatibility} > validate\_lora\_compatibility() Validates that LoRA configuration is compatible with model settings.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If LoRA is enabled but incompatible with current model configuration.
**Return type:**
None
### `validate_multi_gpu_supported()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_multi_gpu_supported} > validate\_multi\_gpu\_supported(multi\_gpu\_supported) Validates that the model architecture supports multi-GPU inference.
**Parameters:**
multi\_gpu\_supported ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether the model architecture supports multi-GPU inference.
**Return type:**
None
### `vision_config_overrides` {#max.pipelines.lib.model_config.MAXModelConfig.vision_config_overrides} > vision\_config\_overrides: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), Any] ### `weight_path` {#max.pipelines.lib.model_config.MAXModelConfig.weight_path} > weight\_path: [list](https://docs.python.org/3/library/stdtypes.html#list)\[Path] ### `weights_size()` {#max.pipelines.lib.model_config.MAXModelConfig.weights_size} > weights\_size() Calculates the total size in bytes of all weight files in `weight_path`. Attempts to find the weights locally first to avoid network calls, checking in the following order: 1. If repo\_type is `RepoType.local`, it checks if the path in weight\_path exists directly as a local file path. 2. Otherwise, if repo\_type is `RepoType.online`, it first checks the local Hugging Face cache using `huggingface_hub.try_to_load_from_cache()`. If not found in the cache, it falls back to querying the Hugging Face Hub API via `HuggingFaceRepo.size_of()`.
**Returns:**
The total size of all weight files in bytes.
**Raises:**
* [FileNotFoundError](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) – If repo\_type is `RepoType.local` and a file specified in weight\_path is not found within the local repo directory. * [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `HuggingFaceRepo.size_of()` fails to retrieve the file size from the Hugging Face Hub API (e.g., file metadata not available or API error). * [RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the determined repo\_type is unexpected.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
## `MAXModelConfigBase` {#max.pipelines.lib.model_config.MAXModelConfigBase} > class max.pipelines.lib.model\_config.MAXModelConfigBase(\*, config\_file=None, section\_name=None) Bases: `ConfigFileModel` Abstract base class for all (required) MAX model configs. This base class configures the model a pipeline uses, and it is also handy for sidestepping the need to pass optional fields when subclassing MAXModelConfig.
**Parameters:**
* config\_file ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * section\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
### `model_config` {#max.pipelines.lib.model_config.MAXModelConfigBase.model_config} > model\_config: ClassVar\[ConfigDict] = {'arbitrary\_types\_allowed': True} Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict]. --- ## pipeline MAX pipeline for model inference and generation (Text Generation variant). ## `BatchInfo` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo} > class max.pipelines.lib.pipeline\_variants.text\_generation.BatchInfo(past\_seq\_lens, seq\_lens, num\_steps) Information about a batch of requests passed to the pipeline.
**Parameters:**
* past\_seq\_lens ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * seq\_lens ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) * num\_steps ([int](https://docs.python.org/3/library/functions.html#int))
### `num_steps` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo.num_steps} > num\_steps: [int](https://docs.python.org/3/library/functions.html#int) Number of steps to do in the pipeline ### `past_seq_lens` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo.past_seq_lens} > past\_seq\_lens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Coordinated list of past sequence lengths (i.e. context lengths) ### `seq_lens` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo.seq_lens} > seq\_lens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Coordinated list of sequence lengths, i.e. prompt\_len or 1 ## `TextGenerationPipeline` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline} > class max.pipelines.lib.pipeline\_variants.text\_generation.TextGenerationPipeline(pipeline\_config, pipeline\_model, eos\_token\_id, weight\_adapters, tokenizer) Generalized token generator pipeline.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * pipeline\_model ([type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[TextGenerationContextType]]) * eos\_token\_id ([int](https://docs.python.org/3/library/functions.html#int)) * weight\_adapters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), WeightsAdapter]) * tokenizer ([PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[TextGenerationContextType, npt.NDArray\[np.integer\[Any]], [TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest)])
### `execute()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.execute} > execute(inputs) Processes the batch and returns decoded tokens. Given a batch, executes the graph for num\_steps in a multi-step scenario, then decodes the tokens and returns the list of decoded tokens.
**Parameters:**
inputs ([TextGenerationInputs](../interfaces.md#max.interfaces.TextGenerationInputs)\[TextGenerationContextType])
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](../interfaces.md#max.interfaces.RequestID), [TextGenerationOutput](../interfaces.md#max.interfaces.TextGenerationOutput)]
### `initialize_bitmask()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.initialize_bitmask} > initialize\_bitmask(batch) Allocates a per-request token bitmask for structured decoding.
**Parameters:**
batch ([list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]) – The generation contexts for the batch.
**Returns:**
A bitmask array of shape \[batch\_size, vocab\_size] if structured output is enabled; otherwise `None`.
**Return type:**
[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int32]] | None
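To make the bitmask mechanics concrete, here is a minimal sketch of how a per-request token bitmask constrains sampling. This is an illustration of the technique only, not the MAX implementation (which may, for example, pack vocabulary bits into int32 words): each row marks which vocabulary entries structured decoding currently allows, and disallowed logits are forced to `-inf` so they can never be sampled.

```python
import numpy as np

def apply_token_bitmask(logits: np.ndarray, bitmask: np.ndarray) -> np.ndarray:
    """Set logits of disallowed tokens to -inf so they can never be sampled."""
    masked = logits.copy()
    masked[bitmask == 0] = -np.inf
    return masked

# Two requests over a toy 4-token vocabulary.
logits = np.array([[1.0, 2.0, 3.0, 0.5],
                   [0.1, 0.2, 0.3, 0.4]])
# 1 = token allowed by the grammar, 0 = disallowed.
bitmask = np.array([[1, 0, 1, 0],
                    [0, 1, 1, 1]], dtype=np.int32)

constrained = apply_token_bitmask(logits, bitmask)
print(constrained.argmax(axis=-1))  # highest-scoring *allowed* token per request
```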
### `kv_managers` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.kv_managers} > property kv\_managers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] Return the list of KV cache managers backing this pipeline. ### `pipeline_config` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.pipeline_config} > property pipeline\_config: [PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig) Return the pipeline configuration. ### `prepare_batch()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.prepare_batch} > prepare\_batch(batches, num\_steps) Prepare model inputs and ancillary state for multi-step execution. This flattens replica batches, optionally initializes constrained decoding bitmasks, ensures KV-cache reservations, clamps `num_steps` per context, and builds initial model inputs.
**Parameters:**
* batches ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]]) – Per-replica list of contexts. * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Desired number of steps to run.
**Returns:**
* ModelInputs: Prepared inputs for the first step. * int: The clamped number of steps to run. * Optional\[np.ndarray]: The structured decoding bitmask or None. * list\[TextGenerationContextType]: The flattened context batch.
**Return type:**
A tuple of (ModelInputs, int, np.ndarray | None, list\[TextGenerationContextType])
### `release()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.release} > release(request\_id) Mark the context as complete, releasing the cache slot from the KV manager. Note: KV cache lifecycle is now managed by the scheduler. This method is kept for interface compatibility but is a no-op for regular pipelines.
**Parameters:**
request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID))
**Return type:**
None
### `tokenizer` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.tokenizer} > property tokenizer: [PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[TextGenerationContextType, [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]], [TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest)] Return the tokenizer used for building contexts and decoding. ### `update_for_structured_output()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.update_for_structured_output} > update\_for\_structured\_output(context, bitmask, index) Update context and logits bitmask for structured output. If a `json_schema` is present and no matcher is set, this compiles a grammar matcher and installs it on the context. It may also jump ahead in generation and fill the per-request token bitmask used to constrain the next-token distribution.
**Parameters:**
* context (TextGenerationContextType) – Request context to update. * bitmask ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int32]]) – Optional preallocated bitmask buffer; updated in-place. * index ([int](https://docs.python.org/3/library/functions.html#int)) – Global position into the bitmask for this request.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a JSON schema is provided but structured output is not enabled via sampling configuration.
**Return type:**
None
## `StandaloneSpeculativeDecodingPipeline` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline} > final class max.pipelines.lib.speculative\_decoding.StandaloneSpeculativeDecodingPipeline(pipeline\_config, pipeline\_model, eos\_token\_id, weight\_adapters, tokenizer, draft\_pipeline\_model=None, draft\_weight\_adapters=None) Bases: `SpeculativeDecodingPipelineBase` Standalone speculative decoding, where the draft model runs independently. In this approach, the draft model generates tokens without any information from the target model, then the target model verifies these tokens.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * pipeline\_model ([type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[TextContext](core.md#max.pipelines.core.TextContext)]]) * eos\_token\_id ([int](https://docs.python.org/3/library/functions.html#int)) * weight\_adapters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), WeightsAdapter]) * tokenizer ([PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[[TextContext](core.md#max.pipelines.core.TextContext), npt.NDArray\[np.integer\[Any]], [TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest)]) * draft\_pipeline\_model ([type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[TextContext](core.md#max.pipelines.core.TextContext)]] | None) * draft\_weight\_adapters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), WeightsAdapter] | None)
### `execute()` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline.execute} > execute(inputs) Execute standalone speculative decoding. In standalone mode: 1. Draft model generates tokens independently 2. Target model verifies draft tokens 3. Apply rejection sampling to accept/reject tokens
**Parameters:**
inputs ([TextGenerationInputs](../interfaces.md#max.interfaces.TextGenerationInputs)\[[TextContext](core.md#max.pipelines.core.TextContext)])
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](../interfaces.md#max.interfaces.RequestID), [TextGenerationOutput](../interfaces.md#max.interfaces.TextGenerationOutput)]
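The three steps above hinge on the verification stage. The real pipeline applies rejection sampling over the draft and target probability distributions; a simplified greedy-verification variant (a sketch, not MAX's implementation) still shows the accept/reject flow: accepted draft tokens cost only one target-model pass, and the first mismatch is replaced by the target's own prediction.

```python
def verify_greedy(draft_tokens: list[int], target_argmax: list[int]) -> list[int]:
    """Accept draft tokens while the target model's argmax agrees; on the
    first mismatch, substitute the target's token and stop. If every draft
    token is accepted, the target's extra position yields one bonus token.
    """
    accepted = []
    for drafted, expected in zip(draft_tokens, target_argmax):
        if drafted == expected:
            accepted.append(drafted)
        else:
            accepted.append(expected)  # target's correction replaces the reject
            break
    else:
        accepted.append(target_argmax[len(draft_tokens)])  # free bonus token
    return accepted

draft = [5, 9, 2]      # tokens proposed by the draft model
target = [5, 9, 7, 4]  # target model's argmax at each position
print(verify_greedy(draft, target))  # prefix [5, 9] accepted, 2 rejected for 7
```

Either way, every returned token matches what the target model would have produced step by step; the draft model only determines how many of those tokens are obtained per target pass.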
### `generate_draft_tokens()` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline.generate_draft_tokens} > generate\_draft\_tokens(batch, num\_steps, model\_inputs) Generates draft tokens for the batch using the draft model.
**Parameters:**
* batch ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContext](core.md#max.pipelines.core.TextContext)]) * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * model\_inputs ([ModelInputs](interfaces.md#max.pipelines.lib.interfaces.ModelInputs))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [Buffer](../driver.md#max.driver.Buffer), [Buffer](../driver.md#max.driver.Buffer), [ModelInputs](interfaces.md#max.pipelines.lib.interfaces.ModelInputs), [Buffer](../driver.md#max.driver.Buffer)]
### `prepare_batch()` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline.prepare_batch} > prepare\_batch(model, batch, replica\_batches, num\_steps, return\_n\_logits, is\_draft=False, draft\_inputs=None, merged\_draft\_tokens=None, merged\_draft\_offsets=None) Prepares batch inputs and KV cache for draft or target model.
**Parameters:**
* model ([PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[TextContext](core.md#max.pipelines.core.TextContext)]) * batch ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContext](core.md#max.pipelines.core.TextContext)]) * replica\_batches ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContext](core.md#max.pipelines.core.TextContext)]]) * num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) * return\_n\_logits ([int](https://docs.python.org/3/library/functions.html#int)) * is\_draft ([bool](https://docs.python.org/3/library/functions.html#bool)) * draft\_inputs ([ModelInputs](interfaces.md#max.pipelines.lib.interfaces.ModelInputs) | None) * merged\_draft\_tokens ([Buffer](../driver.md#max.driver.Buffer) | None) * merged\_draft\_offsets ([Buffer](../driver.md#max.driver.Buffer) | None)
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[ModelInputs](interfaces.md#max.pipelines.lib.interfaces.ModelInputs), [int](https://docs.python.org/3/library/functions.html#int)]
### `verify_draft_tokens_with_target_model()` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline.verify_draft_tokens_with_target_model} > verify\_draft\_tokens\_with\_target\_model(draft\_inputs, context\_batch, replica\_batches, num\_draft\_tokens\_generated, draft\_tokens, draft\_logits, merged\_draft\_tokens, merged\_draft\_offsets, all\_draft\_logits) Verifies draft tokens against the target model and returns merged outputs.
**Parameters:**
* draft\_inputs ([ModelInputs](interfaces.md#max.pipelines.lib.interfaces.ModelInputs)) * context\_batch ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContext](core.md#max.pipelines.core.TextContext)]) * replica\_batches ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextContext](core.md#max.pipelines.core.TextContext)]]) * num\_draft\_tokens\_generated ([int](https://docs.python.org/3/library/functions.html#int)) * draft\_tokens ([Buffer](../driver.md#max.driver.Buffer)) * draft\_logits ([Buffer](../driver.md#max.driver.Buffer)) * merged\_draft\_tokens ([Buffer](../driver.md#max.driver.Buffer)) * merged\_draft\_offsets ([Buffer](../driver.md#max.driver.Buffer)) * all\_draft\_logits ([Buffer](../driver.md#max.driver.Buffer))
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Buffer](../driver.md#max.driver.Buffer), [Buffer](../driver.md#max.driver.Buffer), [Buffer](../driver.md#max.driver.Buffer)]
## `EmbeddingsPipeline` {#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline} > final class max.pipelines.lib.embeddings\_pipeline.EmbeddingsPipeline(pipeline\_config, pipeline\_model, eos\_token\_id, weight\_adapters, tokenizer) Bases: [`Pipeline`](../interfaces.md#max.interfaces.Pipeline)\[[`EmbeddingsGenerationInputs`](../interfaces.md#max.interfaces.EmbeddingsGenerationInputs), [`EmbeddingsGenerationOutput`](../interfaces.md#max.interfaces.EmbeddingsGenerationOutput)] Generalized embeddings pipeline.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * pipeline\_model ([type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[EmbeddingsContext](../interfaces.md#max.interfaces.EmbeddingsContext)]]) * eos\_token\_id ([int](https://docs.python.org/3/library/functions.html#int)) * weight\_adapters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), WeightsAdapter]) * tokenizer ([PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[BaseContextType, npt.NDArray\[np.integer\[Any]], [TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest)])
### `execute()` {#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline.execute} > execute(inputs) Processes the batch and returns embeddings. Given a batch, executes the graph and returns the list of embedding outputs per request.
**Parameters:**
inputs ([EmbeddingsGenerationInputs](../interfaces.md#max.interfaces.EmbeddingsGenerationInputs))
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](../interfaces.md#max.interfaces.RequestID), [EmbeddingsGenerationOutput](../interfaces.md#max.interfaces.EmbeddingsGenerationOutput)]
### `release()` {#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline.release} > release(request\_id) Releases resources for the request (no-op for embeddings).
**Parameters:**
request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID))
**Return type:**
None
--- ## registry (Pipelines) Model registry, for tracking various model variants. ## `PipelineRegistry` {#max.pipelines.lib.registry.PipelineRegistry} > class max.pipelines.lib.registry.PipelineRegistry(architectures) Registry for managing supported model architectures and their pipelines. This class maintains a collection of [`SupportedArchitecture`](#max.pipelines.lib.registry.SupportedArchitecture) instances, each defining how a particular model architecture should be loaded, configured, and executed. :::note Note Do not instantiate this class directly. Always use the global [`PIPELINE_REGISTRY`](#max.pipelines.lib.registry.PIPELINE_REGISTRY) singleton, which is automatically populated with all built-in architectures when you import `max.pipelines`. ::: Use [`PIPELINE_REGISTRY`](#max.pipelines.lib.registry.PIPELINE_REGISTRY) when you want to: * **Register custom architectures**: Call [`register()`](#max.pipelines.lib.registry.PipelineRegistry.register) to add a new MAX model architecture to the registry before loading it. * **Query supported models**: Call [`retrieve_architecture()`](#max.pipelines.lib.registry.PipelineRegistry.retrieve_architecture) to check if a Hugging Face model repository is supported before attempting to load it. * **Access cached configs**: Methods like [`get_active_huggingface_config()`](#max.pipelines.lib.registry.PipelineRegistry.get_active_huggingface_config) and [`get_active_tokenizer()`](#max.pipelines.lib.registry.PipelineRegistry.get_active_tokenizer) provide cached access to model configurations and tokenizers.
**Parameters:**
architectures ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[SupportedArchitecture](#max.pipelines.lib.registry.SupportedArchitecture)])
### `get_active_diffusers_config()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_diffusers_config} > get\_active\_diffusers\_config(huggingface\_repo) Retrieves or creates a cached diffusers config for the given repository. This method checks if the repository is a diffusion pipeline by looking for model\_index.json. If found, it downloads and caches the config. If not found, returns None.
**Parameters:**
huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo containing the model.
**Returns:**
The diffusers config dict if this is a diffusion pipeline, None otherwise.
**Return type:**
[dict](https://docs.python.org/3/library/stdtypes.html#dict) | None
### `get_active_huggingface_config()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_huggingface_config} > get\_active\_huggingface\_config(huggingface\_repo) Retrieves or creates a cached Hugging Face AutoConfig for the given model. Maintains a cache of Hugging Face configurations to avoid unnecessary reloads, each of which incurs a Hugging Face Hub API call. If a config for the given model hasn’t been loaded before, it will create a new one using AutoConfig.from\_pretrained() with the model’s settings. Note: The cache key (HuggingFaceRepo) includes trust\_remote\_code in its hash, so configs with different trust settings are cached separately. For multiprocessing, each worker process has its own registry instance with an empty cache, so configs are loaded fresh in each worker.
**Parameters:**
huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo containing the model.
**Returns:**
The Hugging Face configuration object for the model.
**Return type:**
AutoConfig
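The caching behavior described above follows a standard keyed-memoization pattern. A minimal sketch (illustrative class and helper names, not the real `PipelineRegistry` internals) shows why `trust_remote_code` belongs in the cache key: two lookups for the same repo with different trust settings must not share an entry.

```python
class ConfigCache:
    """Sketch of the config-caching pattern: the expensive load runs once
    per (repo_id, trust_remote_code) key and is served from a dict after.
    `loader` stands in for the real AutoConfig.from_pretrained() call.
    """
    def __init__(self, loader):
        self._loader = loader
        self._cache = {}

    def get(self, repo_id: str, trust_remote_code: bool = False):
        # trust_remote_code is part of the key, so configs loaded with
        # different trust settings are cached separately.
        key = (repo_id, trust_remote_code)
        if key not in self._cache:
            self._cache[key] = self._loader(repo_id, trust_remote_code)
        return self._cache[key]

calls = []
def fake_loader(repo_id, trust):
    calls.append(repo_id)  # count how often the "Hub API" is hit
    return {"repo": repo_id, "trust": trust}

cache = ConfigCache(fake_loader)
cache.get("org/model")
cache.get("org/model")                           # cache hit, no second load
cache.get("org/model", trust_remote_code=True)   # distinct key, loads again
print(len(calls))  # → 2
```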
### `get_active_tokenizer()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_tokenizer} > get\_active\_tokenizer(huggingface\_repo) Retrieves or creates a cached Hugging Face AutoTokenizer for the given model. Maintains a cache of Hugging Face tokenizers to avoid unnecessary reloads, each of which incurs a Hugging Face Hub API call. If a tokenizer for the given model hasn’t been loaded before, it will create a new one using AutoTokenizer.from\_pretrained() with the model’s settings.
**Parameters:**
huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo containing the model.
**Returns:**
The Hugging Face tokenizer for the model.
**Return type:**
PreTrainedTokenizer | PreTrainedTokenizerFast
### `register()` {#max.pipelines.lib.registry.PipelineRegistry.register} > register(architecture, \*, allow\_override=False) Add a new architecture to the registry. If multiple architectures share the same name but have different tasks, they are registered in a secondary lookup table keyed by (name, task).
**Parameters:**
* architecture ([SupportedArchitecture](#max.pipelines.lib.registry.SupportedArchitecture)) * allow\_override ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
None
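The dual-lookup scheme that `register()` describes, where a shared name with differing tasks falls back to a secondary `(name, task)` table, can be sketched as follows. All names here are illustrative; this is not the real `PipelineRegistry` implementation.

```python
class MiniRegistry:
    """Sketch of name-based registration with a secondary (name, task)
    table used only when two architectures share a name across tasks."""
    def __init__(self):
        self._by_name = {}        # name -> (task, arch)
        self._by_name_task = {}   # (name, task) -> arch

    def register(self, name, task, arch, allow_override=False):
        if name in self._by_name and not allow_override:
            existing_task, existing_arch = self._by_name[name]
            if existing_task != task:
                # Same name, different task: keep both, disambiguated
                # through the secondary (name, task) table.
                self._by_name_task[(name, existing_task)] = existing_arch
                self._by_name_task[(name, task)] = arch
                return
            raise ValueError(f"{name} already registered")
        self._by_name[name] = (task, arch)

    def retrieve(self, name, task=None):
        if task is not None and (name, task) in self._by_name_task:
            return self._by_name_task[(name, task)]
        return self._by_name[name][1]

reg = MiniRegistry()
reg.register("MyModel", "text-generation", "text-arch")
reg.register("MyModel", "embeddings", "embed-arch")
print(reg.retrieve("MyModel", task="embeddings"))  # → embed-arch
```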
### `reset()` {#max.pipelines.lib.registry.PipelineRegistry.reset} > reset() Clears all registered architectures (mainly for tests).
**Return type:**
None
### `retrieve()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve} > retrieve(pipeline\_config, task=PipelineTask.TEXT\_GENERATION, override\_architecture=None) Retrieves the tokenizer and an instantiated pipeline for the config.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask)) * override\_architecture ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[Any, Any, Any], PipelineTypes]
### `retrieve_architecture()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_architecture} > retrieve\_architecture(huggingface\_repo, use\_legacy\_module=True, task=None) Retrieve architecture matching the Hugging Face model config.
**Parameters:**
* huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The Hugging Face repository to match against. * use\_legacy\_module ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to use legacy Module architecture (default=True). When True, appends “\_Legacy” suffix to find legacy graph-based architecture. When False, uses the standard Hugging Face architecture name for new API. * task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask) | None) – Optional task to disambiguate when multiple architectures share the same name. If not provided and multiple architectures share the same name, the task will be inferred from the Hugging Face Hub’s pipeline\_tag.
**Returns:**
The matching SupportedArchitecture or None if no match found.
**Return type:**
[SupportedArchitecture](#max.pipelines.lib.registry.SupportedArchitecture) | None
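The `use_legacy_module` lookup rule above reduces to a suffix on the architecture name. A one-line sketch (the helper name is illustrative, not part of the registry API):

```python
def resolve_architecture_name(hf_class_name: str, use_legacy_module: bool) -> str:
    # Legacy graph-based architectures are registered under a "_Legacy" suffix;
    # the new API uses the Hugging Face class name unchanged.
    return hf_class_name + "_Legacy" if use_legacy_module else hf_class_name

print(resolve_architecture_name("LlamaForCausalLM", use_legacy_module=True))
print(resolve_architecture_name("LlamaForCausalLM", use_legacy_module=False))
```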
### `retrieve_context_type()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_context_type} > retrieve\_context\_type(pipeline\_config, override\_architecture=None, task=None) Retrieve the context class type associated with the architecture for the given pipeline configuration. The context type defines how the pipeline manages request state and inputs during model execution. Different architectures may use different context implementations that adhere to either the TextGenerationContext or EmbeddingsContext protocol.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – The configuration for the pipeline. * override\_architecture ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional architecture name to use instead of looking up based on the model repository. This is useful for cases like audio generation where the pipeline uses a different architecture (e.g., audio decoder) than the underlying model repository. * task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask) | None) – Optional pipeline task to disambiguate when multiple architectures share the same name but serve different tasks.
**Returns:**
The context class type associated with the architecture, which implements either the TextGenerationContext or EmbeddingsContext protocol.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no supported architecture is found for the given model repository or override architecture name.
**Return type:**
[type](https://docs.python.org/3/library/functions.html#type)\[[TextGenerationContext](../interfaces.md#max.interfaces.TextGenerationContext)] | [type](https://docs.python.org/3/library/functions.html#type)\[[EmbeddingsContext](../interfaces.md#max.interfaces.EmbeddingsContext)]
### `retrieve_factory()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_factory} > retrieve\_factory(pipeline\_config, task=PipelineTask.TEXT\_GENERATION, override\_architecture=None) Retrieves the tokenizer and a factory that creates the pipeline instance.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask)) * override\_architecture ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[Any, Any, Any], [Callable](../graph/ops.md#max.graph.ops.Callable)\[\[], PipelineTypes]]
### `retrieve_pipeline_task()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_pipeline_task} > retrieve\_pipeline\_task(pipeline\_config) Retrieves the pipeline task for the given pipeline configuration.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – The configuration for the pipeline.
**Returns:**
The task associated with the architecture.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no supported architecture is found for the given model repository.
**Return type:**
[PipelineTask](../interfaces.md#max.interfaces.PipelineTask)
### `retrieve_tokenizer()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_tokenizer} > retrieve\_tokenizer(pipeline\_config, override\_architecture=None, task=None) Retrieves a tokenizer for the given pipeline configuration.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Configuration for the pipeline * override\_architecture ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional architecture override string * task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask) | None) – Optional pipeline task to disambiguate when multiple architectures share the same name but serve different tasks.
**Returns:**
The configured tokenizer
**Return type:**
[PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no architecture is found
## `SupportedArchitecture` {#max.pipelines.lib.registry.SupportedArchitecture} > class max.pipelines.lib.registry.SupportedArchitecture(name, example\_repo\_ids, default\_encoding, supported\_encodings, pipeline\_model, task, tokenizer, default\_weights\_format, context\_type, config, rope\_type=RopeType.none, weight\_adapters=\<factory>, multi\_gpu\_supported=False, required\_arguments=\<factory>, context\_validators=\<factory>, supports\_empty\_batches=False, requires\_max\_batch\_context\_length=False) Represents a model architecture configuration for MAX pipelines. Defines the components and settings required to support a specific model architecture within the MAX pipeline system. Each SupportedArchitecture instance encapsulates the model implementation, tokenizer, supported encodings, and other architecture-specific configuration. New architectures should be registered into the [`PipelineRegistry`](#max.pipelines.lib.registry.PipelineRegistry) using the [`register()`](#max.pipelines.lib.registry.PipelineRegistry.register) method. **Example:** ```python
my_architecture = SupportedArchitecture(
    name="MyModelForCausalLM",  # Must match your Hugging Face model class name
    example_repo_ids=[
        "your-org/your-model-name",  # Add example model repository IDs
    ],
    default_encoding=SupportedEncoding.q4_k,
    supported_encodings={
        SupportedEncoding.q4_k: [KVCacheStrategy.PAGED],
        SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED],
        # Add other encodings your model supports
    },
    pipeline_model=MyModel,
    tokenizer=TextTokenizer,
    context_type=TextContext,
    config=MyModelConfig,  # Architecture-specific config class
    default_weights_format=WeightsFormat.safetensors,
    rope_type=RopeType.none,
    weight_adapters={
        WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
        # Add other weight formats if needed
    },
    multi_gpu_supported=True,  # Set based on your implementation capabilities
    required_arguments={"some_arg": True},
    task=PipelineTask.TEXT_GENERATION,
)
```
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) * example\_repo\_ids ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) * default\_encoding (SupportedEncoding) * supported\_encodings ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[SupportedEncoding, [list](https://docs.python.org/3/library/stdtypes.html#list)\[[KVCacheStrategy](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)]]) * pipeline\_model ([type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask)) * tokenizer ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[...], [PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * default\_weights\_format ([WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)) * context\_type ([type](https://docs.python.org/3/library/functions.html#type)\[[TextGenerationContext](../interfaces.md#max.interfaces.TextGenerationContext)] | [type](https://docs.python.org/3/library/functions.html#type)\[[EmbeddingsContext](../interfaces.md#max.interfaces.EmbeddingsContext)]) * config ([type](https://docs.python.org/3/library/functions.html#type)\[[ArchConfig](interfaces.md#max.pipelines.lib.interfaces.ArchConfig)]) * rope\_type (RopeType) * weight\_adapters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [Callable](../graph/ops.md#max.graph.ops.Callable)\[\[...], 
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [WeightData](../graph/weights.md#max.graph.weights.WeightData)]]]) * multi\_gpu\_supported ([bool](https://docs.python.org/3/library/functions.html#bool)) * required\_arguments ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float)]) * context\_validators ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[TextContext](core.md#max.pipelines.core.TextContext) | [TextAndVisionContext](core.md#max.pipelines.core.TextAndVisionContext)], None]]) * supports\_empty\_batches ([bool](https://docs.python.org/3/library/functions.html#bool)) * requires\_max\_batch\_context\_length ([bool](https://docs.python.org/3/library/functions.html#bool))
### `config` {#max.pipelines.lib.registry.SupportedArchitecture.config} > config: [type](https://docs.python.org/3/library/functions.html#type)\[[ArchConfig](interfaces.md#max.pipelines.lib.interfaces.ArchConfig)] The architecture-specific configuration class for the model. This class must implement the `ArchConfig` protocol, providing an `initialize` method that creates a configuration instance from a `PipelineConfig`. For models with KV cache, this should be a class implementing `ArchConfigWithKVCache` to enable KV cache memory estimation. ### `context_type` {#max.pipelines.lib.registry.SupportedArchitecture.context_type} > context\_type: [type](https://docs.python.org/3/library/functions.html#type)\[[TextGenerationContext](../interfaces.md#max.interfaces.TextGenerationContext)] | [type](https://docs.python.org/3/library/functions.html#type)\[[EmbeddingsContext](../interfaces.md#max.interfaces.EmbeddingsContext)] The context class type that this architecture uses for managing request state and inputs. This should be a class (not an instance) that implements either the TextGenerationContext or EmbeddingsContext protocol, defining how the pipeline processes and tracks requests. ### `context_validators` {#max.pipelines.lib.registry.SupportedArchitecture.context_validators} > context\_validators: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[TextContext](core.md#max.pipelines.core.TextContext) | [TextAndVisionContext](core.md#max.pipelines.core.TextAndVisionContext)], [None](https://docs.python.org/3/library/constants.html#None)]] A list of callable validators that verify context inputs before model execution. These validators are called during context creation to ensure inputs meet model-specific requirements. Validators should raise InputError for invalid inputs, providing early error detection before expensive model operations. 
```python
def validate_single_image(context: TextContext | TextAndVisionContext) -> None:
    if isinstance(context, TextAndVisionContext):
        if context.pixel_values and len(context.pixel_values) > 1:
            raise InputError(
                f"Model supports only 1 image, got {len(context.pixel_values)}"
            )

my_architecture = SupportedArchitecture(
    # ... other fields ...
    context_validators=[validate_single_image],
)
```

### `default_encoding` {#max.pipelines.lib.registry.SupportedArchitecture.default_encoding} > default\_encoding: SupportedEncoding The default quantization encoding to use when no specific encoding is requested. ### `default_weights_format` {#max.pipelines.lib.registry.SupportedArchitecture.default_weights_format} > default\_weights\_format: [WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat) The weights format expected by the pipeline\_model. ### `example_repo_ids` {#max.pipelines.lib.registry.SupportedArchitecture.example_repo_ids} > example\_repo\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] A list of Hugging Face repository IDs that use this architecture for testing and validation purposes. ### `multi_gpu_supported` {#max.pipelines.lib.registry.SupportedArchitecture.multi_gpu_supported} > multi\_gpu\_supported: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether the architecture supports multi-GPU execution. ### `name` {#max.pipelines.lib.registry.SupportedArchitecture.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the model architecture that must match the Hugging Face model class name.
### `pipeline_model` {#max.pipelines.lib.registry.SupportedArchitecture.pipeline_model} > pipeline\_model: [type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] The PipelineModel class that defines the model graph structure and execution logic. ### `required_arguments` {#max.pipelines.lib.registry.SupportedArchitecture.required_arguments} > required\_arguments: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float)] A dictionary specifying required values for PipelineConfig options. ### `requires_max_batch_context_length` {#max.pipelines.lib.registry.SupportedArchitecture.requires_max_batch_context_length} > requires\_max\_batch\_context\_length: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether the architecture requires a max batch context length to be specified. If True and max\_batch\_context\_length is not specified, we will default to the max sequence length of the model. ### `rope_type` {#max.pipelines.lib.registry.SupportedArchitecture.rope_type} > rope\_type: RopeType = 'none' The type of RoPE (Rotary Position Embedding) used by the model. ### `supported_encodings` {#max.pipelines.lib.registry.SupportedArchitecture.supported_encodings} > supported\_encodings: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[SupportedEncoding, [list](https://docs.python.org/3/library/stdtypes.html#list)\[[KVCacheStrategy](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)]] A dictionary mapping supported quantization encodings to their compatible KV cache strategies. 
### `supports_empty_batches` {#max.pipelines.lib.registry.SupportedArchitecture.supports_empty_batches} > supports\_empty\_batches: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether the architecture can handle empty batches during inference. When set to True, the pipeline can process requests with zero-sized batches without errors. This is useful for certain execution modes and expert parallelism. Most architectures do not require empty batch support and should leave this as False. ### `task` {#max.pipelines.lib.registry.SupportedArchitecture.task} > task: [PipelineTask](../interfaces.md#max.interfaces.PipelineTask) The pipeline task type that this architecture supports. ### `tokenizer` {#max.pipelines.lib.registry.SupportedArchitecture.tokenizer} > tokenizer: [Callable](../graph/ops.md#max.graph.ops.Callable)\[\[...], [PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]] A callable that returns a PipelineTokenizer instance for preprocessing model inputs. ### `tokenizer_cls` {#max.pipelines.lib.registry.SupportedArchitecture.tokenizer_cls} > property tokenizer\_cls: [type](https://docs.python.org/3/library/functions.html#type)\[[PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]] Returns the tokenizer class for this architecture. 
### `weight_adapters` {#max.pipelines.lib.registry.SupportedArchitecture.weight_adapters} > weight\_adapters: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [Callable](../graph/ops.md#max.graph.ops.Callable)\[\[...], [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [WeightData](../graph/weights.md#max.graph.weights.WeightData)]]] A dictionary of weight format adapters for converting checkpoints from different formats to the default format. ## `get_pipeline_for_task()` {#max.pipelines.lib.registry.get_pipeline_for_task} > max.pipelines.lib.registry.get\_pipeline\_for\_task(task, pipeline\_config)
**Parameters:**
* task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask)) * pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig))
**Return type:**
[type](https://docs.python.org/3/library/functions.html#type)\[[TextGenerationPipeline](pipeline.md#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline)\[[TextContext](core.md#max.pipelines.core.TextContext)]] | [type](https://docs.python.org/3/library/functions.html#type)\[[EmbeddingsPipeline](pipeline.md#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline)] | [type](https://docs.python.org/3/library/functions.html#type)\[AudioGeneratorPipeline] | [type](https://docs.python.org/3/library/functions.html#type)\[[StandaloneSpeculativeDecodingPipeline](pipeline.md#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline)] | [type](https://docs.python.org/3/library/functions.html#type)\[SpeechTokenGenerationPipeline] | [type](https://docs.python.org/3/library/functions.html#type)\[EAGLESpeculativeDecodingPipeline] | [type](https://docs.python.org/3/library/functions.html#type)\[OverlapTextGenerationPipeline\[[TextContext](core.md#max.pipelines.core.TextContext)]]
## `PIPELINE_REGISTRY` {#max.pipelines.lib.registry.PIPELINE_REGISTRY} > max.pipelines.lib.registry.PIPELINE\_REGISTRY: [PipelineRegistry](#max.pipelines.lib.registry.PipelineRegistry) Global registry of supported model architectures. This is the singleton [`PipelineRegistry`](#max.pipelines.lib.registry.PipelineRegistry) instance you can use to register new MAX model architectures and query supported models. --- ## sampling (Pipelines) ## `rejection_sampler()` {#max.pipelines.lib.sampling.sampling.rejection_sampler} > max.pipelines.lib.sampling.sampling.rejection\_sampler(device, \*, seed=0)
**Parameters:**
* device ([DeviceRef](../graph/type.md#max.graph.type.DeviceRef)) * seed ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[Graph](../graph/Graph.md#max.graph.Graph)
## `rejection_sampler_with_residuals()` {#max.pipelines.lib.sampling.sampling.rejection_sampler_with_residuals} > max.pipelines.lib.sampling.sampling.rejection\_sampler\_with\_residuals(device, \*, seed=0, debug=False) Builds a rejection sampler with residual sampling for speculative decoding. Computes acceptance ratios for draft tokens, finds first rejection, samples from residual distribution (target - draft), and generates bonus tokens.
**Parameters:**
* device ([DeviceRef](../graph/type.md#max.graph.type.DeviceRef)) * seed ([int](https://docs.python.org/3/library/functions.html#int)) * debug ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[Graph](../graph/Graph.md#max.graph.Graph)
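The acceptance/residual logic described above can be sketched in plain Python. This is an illustrative toy, not the MAX graph implementation: the function name and the list-of-lists probability layout are assumptions made for the example.

```python
import random

def rejection_sample(draft_probs, target_probs, draft_tokens):
    """Toy speculative-decoding rejection sampler.

    draft_probs[i][t] / target_probs[i][t] give the draft/target model's
    probability of token t at draft position i. Each draft token is accepted
    with probability min(1, p_target / p_draft); on the first rejection, a
    replacement is drawn from the residual distribution max(0, target - draft).
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p_t, p_d = target_probs[i][tok], draft_probs[i][tok]
        if random.random() < min(1.0, p_t / p_d):
            accepted.append(tok)
            continue
        # First rejection: sample from the renormalized residual distribution.
        residual = [max(0.0, t - d)
                    for t, d in zip(target_probs[i], draft_probs[i])]
        r = random.random() * sum(residual)
        acc = 0.0
        for t, w in enumerate(residual):
            if w == 0.0:
                continue
            acc += w
            if r <= acc:
                accepted.append(t)
                break
        return accepted
    # All draft tokens accepted: a bonus token would additionally be sampled
    # from the target distribution at the final position (omitted here).
    return accepted
```

When draft and target agree exactly, every draft token is accepted; when the target assigns zero probability to a drafted token, the residual sample replaces it deterministically.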
## `token_sampler()` {#max.pipelines.lib.sampling.sampling.token_sampler} > max.pipelines.lib.sampling.sampling.token\_sampler(sampling\_config, device, return\_logits=False)
**Parameters:**
* sampling\_config (SamplingConfig) * device ([DeviceRef](../graph/type.md#max.graph.type.DeviceRef)) * return\_logits ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[Graph](../graph/Graph.md#max.graph.Graph)
--- ## tokenizer Implementations of provided tokenizers. ## `IdentityPipelineTokenizer` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer} > class max.pipelines.lib.tokenizer.IdentityPipelineTokenizer(\*args, \*\*kwargs) ### `decode()` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.decode} > async decode(encoded, \*\*kwargs) Returns the encoded string unchanged (identity decoding).
**Parameters:**
encoded ([str](https://docs.python.org/3/library/stdtypes.html#str))
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `encode()` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.encode} > async encode(prompt, add\_special\_tokens=False) Returns the prompt unchanged (identity encoding).
**Parameters:**
* prompt ([str](https://docs.python.org/3/library/stdtypes.html#str)) * add\_special\_tokens ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `eos` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.eos} > property eos: [int](https://docs.python.org/3/library/functions.html#int) Returns the end-of-sequence token ID (0 for identity). ### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.expects_content_wrapping} > property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether this tokenizer expects content wrapping. ## `PreTrainedPipelineTokenizer` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer} > class max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer(delegate)
**Parameters:**
delegate (PreTrainedTokenizer | PreTrainedTokenizerFast)
### `apply_chat_template()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.apply_chat_template} > apply\_chat\_template(messages) Applies the delegate’s chat template to the messages.
**Parameters:**
messages ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestMessage](../interfaces.md#max.interfaces.TextGenerationRequestMessage)])
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `decode()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.decode} > async decode(encoded, \*\*kwargs) Decodes token ids to text via the delegate.
**Parameters:**
encoded ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]])
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `encode()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.encode} > async encode(prompt, add\_special\_tokens=False) Encodes the prompt to token ids via the delegate.
**Parameters:**
* prompt ([str](https://docs.python.org/3/library/stdtypes.html#str)) * add\_special\_tokens ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
### `eos` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.eos} > property eos: [int](https://docs.python.org/3/library/functions.html#int) Returns the end-of-sequence token ID from the delegate. ### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.expects_content_wrapping} > property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether this tokenizer expects content wrapping. ## `TextAndVisionTokenizer` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer} > class max.pipelines.lib.tokenizer.TextAndVisionTokenizer(model\_path, pipeline\_config, \*, revision=None, max\_length=None, trust\_remote\_code=False, context\_validators=None, \*\*unused\_kwargs) Encapsulates creation of TextAndVisionContext and specific token encode/decode logic.
**Parameters:**
* model\_path ([str](https://docs.python.org/3/library/stdtypes.html#str)) * pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) * revision ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) * trust\_remote\_code ([bool](https://docs.python.org/3/library/functions.html#bool)) * context\_validators ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[TextAndVisionContext](core.md#max.pipelines.core.TextAndVisionContext)], None]] | None)
### `apply_chat_template()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.apply_chat_template} > apply\_chat\_template(messages) Applies the processor’s chat template to the messages.
**Parameters:**
messages ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestMessage](../interfaces.md#max.interfaces.TextGenerationRequestMessage)])
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `decode()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.decode} > async decode(encoded, \*\*kwargs) Transforms a provided encoded token array back into readable text.
**Parameters:**
encoded ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]])
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `encode()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.encode} > async encode(prompt, add\_special\_tokens=True) Transforms the provided prompt into a token array.
**Parameters:**
* prompt ([str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)]) * add\_special\_tokens ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
### `eos` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.eos} > property eos: [int](https://docs.python.org/3/library/functions.html#int) Returns the end-of-sequence token ID from the delegate. ### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.expects_content_wrapping} > property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether this tokenizer expects content wrapping. ### `new_context()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.new_context} > async new\_context(request) Create a new TextAndVisionContext object, leveraging necessary information from TextGenerationRequest.
**Parameters:**
request ([TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest))
**Return type:**
[TextAndVisionContext](core.md#max.pipelines.core.TextAndVisionContext)
## `TextTokenizer` {#max.pipelines.lib.tokenizer.TextTokenizer} > class max.pipelines.lib.tokenizer.TextTokenizer(model\_path, pipeline\_config, \*, revision=None, max\_length=None, trust\_remote\_code=False, enable\_llama\_whitespace\_fix=False, chat\_template=None, context\_validators=None, \*\*unused\_kwargs) Encapsulates creation of TextContext and specific token encode/decode logic.
**Parameters:**
* model\_path ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Path to the model/tokenizer * revision ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Git revision/branch to use * max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) – Maximum sequence length * trust\_remote\_code ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to trust remote code from the model * enable\_llama\_whitespace\_fix ([bool](https://docs.python.org/3/library/functions.html#bool)) – Enable whitespace fix for Llama tokenizers * pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Optional pipeline configuration * chat\_template ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional custom chat template string to override the one shipped with the Hugging Face model config. This allows customizing the prompt formatting for different use cases. * context\_validators ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[TextContext](core.md#max.pipelines.core.TextContext)], None]] | None)
### `apply_chat_template()` {#max.pipelines.lib.tokenizer.TextTokenizer.apply_chat_template} > apply\_chat\_template(messages, tools, chat\_template\_options=None) Applies the delegate chat template to messages (and optional tools).
**Parameters:**
* messages ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestMessage](../interfaces.md#max.interfaces.TextGenerationRequestMessage)]) * tools ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestTool](../interfaces.md#max.interfaces.TextGenerationRequestTool)] | None) * chat\_template\_options ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None)
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `decode()` {#max.pipelines.lib.tokenizer.TextTokenizer.decode} > async decode(encoded, \*\*kwargs) Transforms a provided encoded token array back into readable text.
**Parameters:**
encoded ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]])
**Return type:**
[str](https://docs.python.org/3/library/stdtypes.html#str)
### `encode()` {#max.pipelines.lib.tokenizer.TextTokenizer.encode} > async encode(prompt, add\_special\_tokens=True) Transforms the provided prompt into a token array.
**Parameters:**
* prompt ([str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)]) * add\_special\_tokens ([bool](https://docs.python.org/3/library/functions.html#bool))
**Return type:**
[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), …], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
### `eos` {#max.pipelines.lib.tokenizer.TextTokenizer.eos} > property eos: [int](https://docs.python.org/3/library/functions.html#int) Returns the end-of-sequence token ID from the delegate. ### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.TextTokenizer.expects_content_wrapping} > property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether this tokenizer expects content wrapping. ### `new_context()` {#max.pipelines.lib.tokenizer.TextTokenizer.new_context} > async new\_context(request) Create a new TextContext object, leveraging necessary information from TextGenerationRequest.
**Parameters:**
request ([TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest))
**Return type:**
[TextContext](core.md#max.pipelines.core.TextContext)
## `max_tokens_to_generate()` {#max.pipelines.lib.tokenizer.max_tokens_to_generate} > max.pipelines.lib.tokenizer.max\_tokens\_to\_generate(prompt\_size, max\_length, max\_new\_tokens=None) Returns the max number of new tokens to generate.
**Parameters:**
* prompt\_size ([int](https://docs.python.org/3/library/functions.html#int)) * max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) * max\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None)
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int) | None
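The clamping semantics can be sketched as follows. This is a plausible reimplementation for illustration only (the function name and exact edge-case behavior here are assumptions, not the actual MAX source):

```python
def max_tokens_budget(prompt_size, max_length, max_new_tokens=None):
    """Hypothetical sketch: tokens allowed under max_length, further
    capped by max_new_tokens when both limits are given."""
    if max_length is None:
        # No model length cap: the caller's max_new_tokens (possibly None) stands.
        return max_new_tokens
    # Room left in the context window after the prompt, never negative.
    budget = max(max_length - prompt_size, 0)
    return budget if max_new_tokens is None else min(max_new_tokens, budget)
```

For example, a 10-token prompt against a 100-token limit leaves a budget of 90, which a `max_new_tokens=5` request reduces to 5.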
## `run_with_default_executor()` {#max.pipelines.lib.tokenizer.run_with_default_executor} > async max.pipelines.lib.tokenizer.run\_with\_default\_executor(fn, \*args, \*\*kwargs)
**Parameters:**
* fn ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[\~\_P], \_R]) * args (\~\_P) * kwargs (\~\_P)
**Return type:**
\_R
--- ## profiler Performance profiling and tracing utilities for MAX. This module provides tools for profiling and tracing MAX operations to analyze performance characteristics. Profiling captures timing information for code execution, which helps identify bottlenecks and optimize your models. To enable profiling, set the `MODULAR_ENABLE_PROFILING=1` environment variable before running your code. Without this variable, profiling calls will be no-ops with minimal overhead. The profiler supports three usage patterns: 1. **Context manager**: Use [`Tracer`](#max.profiler.Tracer) as a context manager to profile a code block. 2. **Decorator**: Use [`@traced`](#max.profiler.traced) to profile entire functions. 3. **Manual stack**: Use [`Tracer`](#max.profiler.Tracer) methods to explicitly control profiling spans. ## `Tracer` {#max.profiler.Tracer} > class max.profiler.Tracer(message=None, color='modular\_purple') A stack-based profiling manager for creating nested profiling spans. Manages a stack of profiling spans that allows for nested tracing without requiring deeply nested `with Trace(name):` statements. This is especially useful when you need to dynamically create and manage profiling spans based on runtime conditions or when profiling spans don’t align with your code’s block structure. The `Tracer` can be used both as a context manager and as a manual stack manager. As a context manager, it ensures all pushed spans are properly closed when the context exits.

```python
from max.profiler import Tracer

tracer = Tracer("parent_operation", color="modular_purple")
tracer.push("child_operation")
# ... perform work ...
tracer.pop()

# Context manager with manual stack
with Tracer("parent_operation", color="modular_purple") as tracer:
    # The parent span is named "parent_operation"
    tracer.push("child_operation")
    # ... perform work ...
    tracer.pop()
# All spans are automatically closed on context exit
```
**Parameters:**
* message ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) * color ([str](https://docs.python.org/3/library/stdtypes.html#str))
### `cleanup()` {#max.profiler.Tracer.cleanup} > cleanup() Closes all remaining profiling spans. Pops and closes all profiling spans that were pushed onto the stack. This method is automatically called when the tracer is used as a context manager or when the object is deleted.
**Return type:**
None
### `mark()` {#max.profiler.Tracer.mark} > mark() Marks the current profiling span with a timestamp. Records a timestamp event within the current profiling span. This is useful for marking significant events or milestones within a longer operation.
**Raises:**
[AssertionError](https://docs.python.org/3/library/exceptions.html#AssertionError) – If the stack is empty when mark is called.
**Return type:**
None
### `next()` {#max.profiler.Tracer.next} > next(message, color='modular\_purple') Transitions to the next profiling span. Pops the current profiling span and immediately pushes a new one with the specified message. This is a convenience method for sequential operations at the same nesting level.
**Parameters:**
* message ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name of the new profiling span. * color ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The color of the profiling span for visualization tools.
**Return type:**
None
### `pop()` {#max.profiler.Tracer.pop} > pop(exc\_type=None, exc\_value=None, traceback=None) Pops a profiling span off the stack and closes it. Removes the most recently pushed profiling span from the stack and closes it, recording its execution time. Exception information can be passed through for proper error handling in context managers.
**Parameters:**
* exc\_type ([type](https://docs.python.org/3/library/functions.html#type)\[[BaseException](https://docs.python.org/3/library/exceptions.html#BaseException)] | None) – The exception type if an exception occurred, or None. * exc\_value ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException) | None) – The exception instance if an exception occurred, or None. * traceback ([TracebackType](https://docs.python.org/3/library/types.html#types.TracebackType) | None) – The traceback object if an exception occurred, or None.
**Return type:**
None
### `push()` {#max.profiler.Tracer.push} > push(message=None, color='modular\_purple') Pushes a new profiling span onto the stack. Creates and activates a new profiling span. If profiling is disabled or no message is provided, pushes a None placeholder to maintain stack consistency.
**Parameters:**
* message ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – The name of the profiling span. If None, no span is created. * color ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The color of the profiling span for visualization tools.
**Return type:**
None
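The push/next/pop discipline can be pictured with a toy pure-Python span stack. This illustrates only the stack semantics; the real `Tracer` records spans into MAX's profiling backend, and the class and attribute names below are invented for the example:

```python
import time

class SpanStack:
    """Toy stand-in illustrating Tracer's stack discipline (not the MAX API)."""

    def __init__(self):
        self._stack = []

    def push(self, name):
        # Open a new nested span.
        self._stack.append((name, time.perf_counter()))

    def pop(self):
        # Close the most recently opened span and report its duration.
        name, start = self._stack.pop()
        return name, time.perf_counter() - start

    def next(self, name):
        # Close the current span and immediately open a sibling at the
        # same nesting level -- the convenience that Tracer.next() provides.
        closed = self.pop()
        self.push(name)
        return closed
```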
## `traced()` {#max.profiler.traced} > max.profiler.traced(func=None, \*, message=None, color='modular\_purple') Decorator for creating a profiling span for a function. Creates a profiling span that measures the execution time of the decorated function. This is useful for identifying performance bottlenecks without modifying the function’s internal code. The decorator supports both synchronous and asynchronous functions.

```python
from max.profiler import traced

# Decorator with custom span name
@traced(message="inference", color="red")
def run_model() -> None:
    # The profiling span is named "inference"
    model.execute()

# Decorator with default span name (uses function name)
@traced
def preprocess_data() -> None:
    # The profiling span is named "preprocess_data"
    data.normalize()
```
**Parameters:**
* func (\_FuncType | None) – The function to profile. * message ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – The name of the profiling span. If None, uses the function name. * color ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The color of the profiling span for visualization tools.
**Returns:**
The decorated function wrapped in a trace object.
**Return type:**
[Callable](graph/ops.md#max.graph.ops.Callable)
--- ## random Provides random tensor generation utilities. This module provides functions for generating random tensors with various distributions. All functions support specifying data type and device, with sensible defaults based on the target device. You can generate random tensors using different distributions:

```default
from max import random
from max.dtype import DType
from max.driver import CPU

tensor1 = random.uniform((2, 3), dtype=DType.float32, device=CPU())
tensor2 = random.uniform((4, 4), range=(0, 1), dtype=DType.float32, device=CPU())
```

## `gaussian()` {#max.random.gaussian} > max.random.gaussian(shape=(), mean=0.0, std=1.0, \*, dtype=None, device=None) Creates a tensor filled with random values from a Gaussian (normal) distribution. Generates a tensor with values sampled from a normal (Gaussian) distribution with the specified mean and standard deviation. This is commonly used for weight initialization using techniques like Xavier/Glorot or He initialization. Create tensors with random values from a Gaussian distribution:

```default
from max import random
from max.driver import CPU
from max.dtype import DType

# Standard normal distribution
tensor = random.gaussian((2, 3), dtype=DType.float32, device=CPU())
```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Defaults to scalar (empty tuple). * mean ([float](https://docs.python.org/3/library/functions.html#float)) – The mean (center) of the Gaussian distribution. This determines where the distribution is centered. Defaults to `0.0`. * std ([float](https://docs.python.org/3/library/functions.html#float)) – The standard deviation (spread) of the Gaussian distribution. Must be positive. Larger values create more spread in the distribution. Defaults to `1.0`. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type of the output tensor. If `None`, uses the default dtype for the specified device (float32 for CPU, bfloat16 for accelerators). Defaults to `None`. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If `None`, uses the default device (accelerator if available, otherwise CPU). Defaults to `None`.
**Returns:**
A [`Tensor`](tensor.md#max.tensor.Tensor) with random values sampled from the Gaussian distribution.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If std <= 0.
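The He-initialization use case mentioned above amounts to drawing weights from a zero-mean Gaussian with `std = sqrt(2 / fan_in)`. A minimal NumPy sketch of the statistics involved (the `fan_in` and output size are made-up values; with MAX the draw would be `random.gaussian((fan_in, n_out), mean=0.0, std=std)`):

```python
import numpy as np

# He initialization: zero-mean Gaussian with std = sqrt(2 / fan_in).
fan_in = 512
std = float(np.sqrt(2.0 / fan_in))

rng = np.random.default_rng(0)
weights = rng.normal(loc=0.0, scale=std, size=(fan_in, 256))

print(weights.shape)                    # (512, 256)
print(abs(weights.std() - std) < 0.01)  # sample std tracks the target
```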
## `normal()` {#max.random.normal} > max.random.normal(shape=(), mean=0.0, std=1.0, \*, dtype=None, device=None) Alias for [`gaussian()`](#max.random.gaussian). Creates a tensor with values from a normal (Gaussian) distribution.
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) * mean ([float](https://docs.python.org/3/library/functions.html#float)) * std ([float](https://docs.python.org/3/library/functions.html#float)) * dtype ([DType](dtype.md#max.dtype.DType) | None) * device ([Device](driver.md#max.driver.Device) | None)
## `seed()` {#max.random.seed} > max.random.seed() Gets the global random seed tensor. Returns the global seed tensor used for random number generation in eager execution mode. Creates the seed tensor on first access, initialized with the dtype, shape, and device specified by `ops.random.SeedType`.
**Returns:**
The global seed tensor for random number generation.
**Return type:**
[Tensor](tensor.md#max.tensor.Tensor)
## `set_seed()` {#max.random.set_seed} > max.random.set\_seed(value) Sets the global random seed value. Updates the global random seed to the specified value. This affects all subsequent random number generation in eager execution mode.
**Parameters:**
value ([int](https://docs.python.org/3/library/functions.html#int)) – The integer seed value to set.
**Return type:**
None
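Since `set_seed()` affects all subsequent eager random generation, fixing the seed before a sequence of random ops makes a run reproducible. The same pattern, sketched with NumPy's generator API for comparison (the `make_weights` helper is hypothetical):

```python
import numpy as np

def make_weights(seed: int) -> np.ndarray:
    # Seeding first makes every draw after it deterministic,
    # analogous to calling set_seed(seed) before max.random ops.
    rng = np.random.default_rng(seed)
    return rng.uniform(0.0, 1.0, size=(2, 3))

a = make_weights(42)
b = make_weights(42)
print(np.array_equal(a, b))  # True: same seed, same values
```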
## `uniform()` {#max.random.uniform} > max.random.uniform(shape=(), range=(0, 1), \*, dtype=None, device=None) Creates a tensor filled with random values from a uniform distribution. Generates a tensor with values uniformly distributed between the specified minimum and maximum bounds. This is useful for initializing weights, generating random inputs, or creating noise. Create tensors with uniform random values: ```default from max import random from max.dtype import DType from max.driver import CPU # Generate 2x3 tensor with values between 0 and 1 tensor1 = random.uniform((2, 3), dtype=DType.float32, device=CPU()) tensor2 = random.uniform((4, 4), range=(0, 1), dtype=DType.float32, device=CPU()) ```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Defaults to scalar (empty tuple). * range ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[float](https://docs.python.org/3/library/functions.html#float), [float](https://docs.python.org/3/library/functions.html#float)]) – A tuple specifying the (min, max) bounds of the uniform distribution. The minimum value is inclusive, the maximum value is exclusive. Defaults to `(0, 1)`. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type of the output tensor. If `None`, uses the default dtype for the specified device (float32 for CPU, bfloat16 for accelerators). Defaults to `None`. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If `None`, uses the default device (accelerator if available, otherwise CPU). Defaults to `None`.
**Returns:**
A [`Tensor`](tensor.md#max.tensor.Tensor) with random values sampled from the uniform distribution.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the range tuple does not contain exactly two values or if min >= max.
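The half-open bounds documented above (minimum inclusive, maximum exclusive) match NumPy's convention, which can be checked directly:

```python
import numpy as np

# Sample many values from a uniform distribution over [2.0, 5.0).
rng = np.random.default_rng(0)
samples = rng.uniform(low=2.0, high=5.0, size=10_000)

print(samples.min() >= 2.0)  # True: lower bound is inclusive
print(samples.max() < 5.0)   # True: upper bound is exclusive
```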
--- ## tensor (Python) Provides tensor operations with eager execution capabilities. This module provides the [`Tensor`](#max.tensor.Tensor) class which supports eager execution of tensor operations, complementing the graph-based execution model provided by `graph`. The tensor operations automatically compile and execute using the MAX runtime. **Key Features:** * **Eager semantics**: Operations give immediate results for quick iteration and feedback. * **High performance**: All operations use high-performance Mojo implementations compiled specifically for the available hardware. * **Automatic compilation**: Tensors are compiled and optimized automatically. Operations may be easily fused into larger graphs to take advantage of the graph compiler’s automatic fusions. * **Lazy evaluation**: Tensors may be computed lazily until their values are needed. * **Familiar API**: Supports common array operations and indexing. :::note Note Tensors use lazy evaluation and JIT compilation, which incurs compilation overhead on first execution. This can result in higher latency for initial operations compared to eager frameworks like NumPy or PyTorch. Subsequent executions reuse compiled kernels for better performance. 
::: Create and manipulate tensors with automatic compilation and optimization: ```python from max.tensor import Tensor from max.driver import CPU from max.dtype import DType x = Tensor.ones((2, 3), dtype=DType.float32, device=CPU()) y = Tensor.zeros_like(x) result = x + y # Eager execution with automatic compilation ``` Operations may be combined into a single execution graph to take advantage of automatic kernel fusion: ```python from max import functional as F @F.functional def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: return x @ weight.T + bias # Create and operate on tensors x = Tensor.ones([2, 3]) weight = Tensor.ones([6, 3]) bias = Tensor.ones([6]) # Eager execution with a single fused graph result = linear(x, weight, bias) ``` Users may opt in to lazy execution. This is primarily useful for: (1) operations that may never execute, for instance creating modules with randomly initialized weights before loading pretrained weights, and (2) combining many operations into a single execution. ```python from max.nn import Linear with F.lazy(): model = Linear(2, 3) print(model) # Lazy weights not initialized # Load pretrained weights weights = { "weight": Tensor.zeros([3, 2]), "bias": Tensor.zeros([3]), } model.load_state_dict(weights) # Or compile directly without ever initializing weights from max.graph import TensorType input_type = TensorType(DType.float32, ["batch", 2], CPU()) model = model.compile(input_type, weights=weights) ``` ## `RealizationContext` {#max.tensor.RealizationContext} > class max.tensor.RealizationContext(\*args, \*\*kwargs) Implements a way to realize unrealized tensors. Most users should never have to think about the existence of this type. It exists to facilitate optimizations around where and when tensor operations are executed. * Each tensor is either real or associated with a RealizationContext. * If a tensor is not real, i.e., “unrealized”, then it is backed by some symbolic computation.
* The RealizationContext is responsible for tracking this symbolic computation and “realizing” the tensor (executing the computation and backing the tensor with real data) if and when it is asked to do so. * A RealizationContext can only realize tensors associated with it. RealizationContext abstracts over various semantics of tensor construction. **“Eager” execution**: tensors are realized as soon as the realization context exits. This is the default behavior. This has a huge concrete advantage over eagerly executing one operation at a time: by controlling the boundary of where the eager context starts and ends, we can give advanced users a tool to enable fine-grained bounds for automatic fusion. In practice the easiest way to do this is to mark a function as F.functional. This function is then assumed to be “atomic” for the purposes of eager execution. All ops within the function execute as part of the same graph, meaning the compiler is free to fuse operations and generate fused kernels within this region. **“Lazy” execution**: tensors are realized only when code later tries to use them. This enables a class of interface design common in the ML world, in which layers are constructed with randomized weights that may never be used. Lazy execution neatly allows constructing entire models, only performing the weight initialization and allocating memory for them if and when those weights are actually used. **Graph compilation**: tensors may never be realized. This allows tensor operations to be composed with direct usage of the Graph API, for instance Module.compile, or using F.\* operations in another Graph API usage. **Async execution**: Tensors are realized as async functions, allowing clean integration in async systems like web services. ### `add_source()` {#max.tensor.RealizationContext.add_source} > add\_source(tensor) Adds a realized tensor as a “source” of the realization state, i.e., one on whose values unrealized tensors depend.
**Parameters:**
tensor ([Tensor](#max.tensor.Tensor)) – The realized tensor to add as a source to the computation.
**Returns:**
A realization state for the tensor. This may be used to compute downstream unrealized values. If it is used in any mutating operations, it should be assigned to `tensor.state` to mark the tensor as having been mutated.
**Return type:**
[RealizationState](#max.tensor.RealizationState)
### `create_unrealized()` {#max.tensor.RealizationContext.create_unrealized} > create\_unrealized(value) Registers an unrealized graph value with the realization context and returns it as an unrealized tensor.
**Parameters:**
value ([BufferValue](graph/BufferValue.md#max.graph.BufferValue) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)) – The graph value representing the result of a computation.
**Returns:**
A new tensor associated with the unrealized value.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `graph` {#max.tensor.RealizationContext.graph} > graph: [Graph](graph/Graph.md#max.graph.Graph) The graph used by the realization context. ### `realize_all()` {#max.tensor.RealizationContext.realize_all} > async realize\_all() Realizes all unrealized tensors associated with this context.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Tensor](#max.tensor.Tensor)]
## `RealizationState` {#max.tensor.RealizationState} > class max.tensor.RealizationState(value, ctx) State for an unrealized tensor. See [`RealizationContext`](#max.tensor.RealizationContext).
**Parameters:**
* value ([BufferValue](graph/BufferValue.md#max.graph.BufferValue) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)) * ctx ([RealizationContext](#max.tensor.RealizationContext))
### `ctx` {#max.tensor.RealizationState.ctx} > ctx: [RealizationContext](#max.tensor.RealizationContext) The realization context used to create this tensor. This context is responsible for realizing the tensor to a real value. ### `value` {#max.tensor.RealizationState.value} > value: [BufferValue](graph/BufferValue.md#max.graph.BufferValue) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) The symbolic value representing the computation backing this tensor. ## `Tensor` {#max.tensor.Tensor} > class max.tensor.Tensor(\*, storage=None, state=None) A multi-dimensional array with eager execution and automatic compilation. The Tensor class provides a high-level interface for numerical computations with automatic compilation and optimization via the MAX runtime. Operations on tensors execute eagerly while benefiting from lazy evaluation and graph-based optimizations behind the scenes. **Key Features:** * **Eager execution**: Operations execute immediately with automatic compilation. * **Lazy evaluation**: Computation may be deferred until results are needed. * **High performance**: Uses the Mojo compiler and optimized kernels. * **Familiar API**: Supports common array operations and indexing. * **Device flexibility**: Works seamlessly across CPU and accelerators. **Creating Tensors:** Create tensors using factory methods like [`ones()`](#max.tensor.Tensor.ones), [`zeros()`](#max.tensor.Tensor.zeros), [`constant()`](#max.tensor.Tensor.constant), [`arange()`](#max.tensor.Tensor.arange), or from other array libraries via [`from_dlpack()`](#max.tensor.Tensor.from_dlpack). 
```python from max import tensor from max.dtype import DType # Create tensors with factory methods x = tensor.Tensor.ones((2, 3), dtype=DType.float32) y = tensor.Tensor.zeros((2, 3), dtype=DType.float32) # Perform operations result = x + y # Eager execution with automatic compilation # Access values print(result.shape) # (2, 3) print(result.dtype) # DType.float32 ``` **Implementation Notes:** Tensors use lazy evaluation internally - they don’t always hold concrete data in memory. A tensor may be “unrealized” (not yet computed) until its value is actually needed (e.g., when converting to other formats or calling [`item()`](#max.tensor.Tensor.item)). This allows the runtime to optimize sequences of operations efficiently. Operations on tensors build a computation graph behind the scenes, which is compiled and executed when needed. All illegal operations fail immediately with clear error messages, ensuring a smooth development experience. :::note Note The lazy evaluation model and JIT compilation introduce compilation overhead on first execution of operations. This results in higher latency for interactive operations compared to eager frameworks like NumPy or PyTorch, particularly when materializing tensor values (e.g., printing or converting to other formats). Subsequent operations on similar shapes and dtypes reuse compiled kernels for improved performance. ::: **Interoperability:** Tensors support the DLPack protocol for zero-copy data exchange with NumPy, PyTorch, JAX, and other array libraries. Use [`from_dlpack()`](#max.tensor.Tensor.from_dlpack) to import arrays and standard DLPack conversion for export.
**Parameters:**
* storage ([Buffer](driver.md#max.driver.Buffer) | None) * state ([RealizationState](#max.tensor.RealizationState) | None)
### `T` {#max.tensor.Tensor.T} > property T: [Tensor](#max.tensor.Tensor) Returns a tensor with the last two dimensions transposed. This is equivalent to calling `transpose(-1, -2)`, which swaps the last two dimensions of the tensor. For a 2D matrix, this produces the standard matrix transpose. ```python from max.tensor import Tensor from max.dtype import DType # Create a 2x3 matrix x = Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32) print(f"Original shape: {x.shape}") # Output: Original shape: [Dim(2), Dim(3)] # Use .T property (equivalent to transpose(-1, -2)) y = x.T print(f"Transposed shape: {y.shape}") # Output: Transposed shape: [Dim(3), Dim(2)] print(y) ```
**Returns:**
A tensor with the last two dimensions transposed.
### `arange()` {#max.tensor.Tensor.arange} > classmethod arange(start=0, stop=None, step=1, \*, dtype=None, device=None) Creates a tensor with evenly spaced values within a given interval. Returns a new 1D tensor containing a sequence of values starting from `start` (inclusive) and ending before `stop` (exclusive), with values spaced by `step`. This is similar to Python’s built-in `range()` function and NumPy’s `arange()`. ```python from max import tensor from max.dtype import DType # Create a range from 0 to 10 (exclusive) x = tensor.Tensor.arange(10) # Result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # Create a range from 5 to 15 with step 2 y = tensor.Tensor.arange(5, 15, 2) # Result: [5, 7, 9, 11, 13] # Use a specific dtype z = tensor.Tensor.arange(0, 5, dtype=DType.float32) # Result: [0.0, 1.0, 2.0, 3.0, 4.0] # Create a range with float step (like numpy/pytorch) w = tensor.Tensor.arange(0.0, 1.0, 0.2, dtype=DType.float32) # Result: [0.0, 0.2, 0.4, 0.6, 0.8] # Create a descending range with negative step v = tensor.Tensor.arange(5, 0, -1, dtype=DType.float32) # Result: [5.0, 4.0, 3.0, 2.0, 1.0] ```
**Parameters:**
* start (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The starting value of the sequence. If `stop` is not provided, this becomes the `stop` value and `start` defaults to 0. * stop (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – The end value of the sequence (exclusive). If not specified, the sequence ends at `start` and begins at 0. 
* step (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The spacing between values in the sequence. Must be non-zero. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified, defaults to `DType.float32` for CPU devices and `DType.bfloat16` for accelerator devices. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A 1D tensor containing the evenly spaced values.
**Return type:**
[Tensor](#max.tensor.Tensor)
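Since `arange()` is documented as matching NumPy's `arange`, the sequences in the example above can be sanity-checked against NumPy:

```python
import numpy as np

print(np.arange(10).tolist())        # [0, 1, 2, ..., 9]
print(np.arange(5, 15, 2).tolist())  # [5, 7, 9, 11, 13]
print(np.arange(5, 0, -1).tolist())  # [5, 4, 3, 2, 1]

# Float steps accumulate rounding error, so the element count is the
# reliable property; exact values may differ in the last bit.
print(len(np.arange(0.0, 1.0, 0.2)))  # 5
```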
### `argmax()` {#max.tensor.Tensor.argmax} > argmax(axis=-1) Finds the indices of the maximum values along an axis. Returns a tensor containing the indices of the maximum values along the specified axis. This is useful for finding the position of the largest element, such as determining predicted classes in classification. ```python from max import tensor from max.dtype import DType # Create a 2x4 tensor x = tensor.Tensor.constant( [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32 ) # Find argmax along last axis (within each row) indices = x.argmax(axis=-1) # Result: [1, 2] (index 1 in first row, index 2 in second row) # Find argmax over all elements index = x.argmax(axis=None) # Result: 6 (flattened index of maximum value 4.2) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to find the maximum indices. Defaults to -1 (the last axis). If None, finds the index of the maximum value across all elements.
**Returns:**
A tensor containing the indices of the maximum values.
**Return type:**
[Tensor](#max.tensor.Tensor)
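The flattened-index behavior with `axis=None` follows row-major order, the same convention as `numpy.argmax`; the example's result of 6 works out as `1 * 4 + 2` for the element at row 1, column 2:

```python
import numpy as np

x = np.array([[1.2, 3.5, 2.1, 0.8],
              [2.3, 1.9, 4.2, 3.1]], dtype=np.float32)

print(np.argmax(x, axis=-1).tolist())  # [1, 2]: per-row winners
flat = int(np.argmax(x))               # no axis: flattens first
print(flat)                            # 6 == 1 * 4 + 2 (row-major)
print(x.flat[flat])                    # 4.2, the overall maximum
```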
### `broadcast_to()` {#max.tensor.Tensor.broadcast_to} > broadcast\_to(shape) Broadcasts the tensor to the specified shape. Returns a tensor broadcast to the target shape, following NumPy broadcasting semantics. Dimensions of size 1 in the input can be expanded to match larger dimensions in the target shape. This is equivalent to PyTorch’s `torch.broadcast_to()` and `torch.Tensor.expand()`. ```python from max import tensor from max.dtype import DType # Create a tensor with shape (3, 1) x = tensor.Tensor.ones([3, 1], dtype=DType.float32) # Broadcast to (3, 4) - expands the second dimension y = x.broadcast_to([3, 4]) print(y.shape) # (3, 4) # Add a new leading dimension w = x.broadcast_to([2, 3, 1]) print(w.shape) # (2, 3, 1) ```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The target shape. Each dimension must either match the input dimension or be broadcastable from size 1.
**Returns:**
A tensor broadcast to the specified shape.
**Return type:**
[Tensor](#max.tensor.Tensor)
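Because `broadcast_to()` is documented as following NumPy broadcasting semantics, the expansion rules (size-1 dims stretch, new leading dims may be added, anything else is an error) can be illustrated with `numpy.broadcast_to`:

```python
import numpy as np

x = np.ones((3, 1), dtype=np.float32)

print(np.broadcast_to(x, (3, 4)).shape)     # (3, 4): size-1 dim expanded
print(np.broadcast_to(x, (2, 3, 1)).shape)  # (2, 3, 1): new leading dim

try:
    np.broadcast_to(x, (4, 3))  # 3 cannot stretch to 4
except ValueError:
    print("incompatible shapes raise ValueError")
```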
### `cast()` {#max.tensor.Tensor.cast} > cast(dtype) Casts the tensor to a different data type. Returns a new tensor with the same values but a different data type. This is useful for type conversions between different numeric types, such as converting `float32` to `int32` for indexing operations or `float32` to `bfloat16` for memory-efficient computations. ```python from max import tensor from max.dtype import DType # Create a float32 tensor x = tensor.Tensor.constant([1.7, 2.3, 3.9], dtype=DType.float32) print(x.dtype) # DType.float32 # Cast to int32 (truncates decimal values) y = x.cast(DType.int32) print(y.dtype) # DType.int32 # Values: [1, 2, 3] ```
**Parameters:**
dtype ([DType](dtype.md#max.dtype.DType)) – The target data type for the tensor.
**Returns:**
A new tensor with the specified data type.
**Return type:**
[Tensor](#max.tensor.Tensor)
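The truncation described above (decimal values dropped, not rounded) matches NumPy's `astype`. The negative-value case shown here is NumPy's behavior (truncation toward zero) and is worth verifying against `cast()` for your dtype before relying on it:

```python
import numpy as np

x = np.array([1.7, 2.3, 3.9, -1.7], dtype=np.float32)

# astype to an integer dtype truncates toward zero: 1.7 -> 1, -1.7 -> -1.
print(x.astype(np.int32).tolist())  # [1, 2, 3, -1]
```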
### `clip()` {#max.tensor.Tensor.clip} > clip(\*, min=None, max=None) Clips values outside a range to the boundaries of the range. ```python from max import tensor # Create a 2x4 tensor x = tensor.Tensor.constant( [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]] ) # Clip values above 3.0 down to 3.0 clipped_above = x.clip(max=3.) # Result: [[1.2, 3., 2.1, 0.8], [2.3, 1.9, 3., 3.]] # Clip values below 3.0 up to 3.0 clipped_below = x.clip(min=3.) # Result: [[3., 3.5, 3., 3.], [3., 3., 4.2, 3.1]] ```
**Parameters:**
* min (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – The minimum value of the range. If not specified, do not clip values for being too small. * max (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – The maximum value of the range. If not specified, do not clip values for being too large.
**Returns:**
A tensor containing the values clipped to the specified range.
**Return type:**
[Tensor](#max.tensor.Tensor)
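The same clipping, written with `numpy.clip` for comparison (passing `None` for an unused bound, just as `clip()` leaves `min` or `max` unset):

```python
import numpy as np

x = np.array([[1.2, 3.5, 2.1, 0.8],
              [2.3, 1.9, 4.2, 3.1]])

# Only an upper bound: values above 3.0 are pulled down to 3.0.
print(np.clip(x, None, 3.0).tolist())
# Only a lower bound: values below 3.0 are pushed up to 3.0.
print(np.clip(x, 3.0, None).tolist())
```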
### `constant()` {#max.tensor.Tensor.constant} > classmethod constant(value, \*, dtype=None, device=None) Creates a constant tensor from a scalar, array, or nested list. Constructs a tensor with constant values that can be a scalar, a nested Python list, or a DLPack-compatible array. The shape is automatically inferred from the input data structure. ```python from max import tensor from max.dtype import DType # Create from scalar x = tensor.Tensor.constant(42, dtype=DType.int32) # Create from nested list y = tensor.Tensor.constant([[1.0, 2.0], [3.0, 4.0]]) # Create from NumPy array import numpy as np z = tensor.Tensor.constant(np.array([1, 2, 3])) ```
**Parameters:**
* value ([DLPackArray](driver.md#max.driver.DLPackArray) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Number | NestedArray]] | [float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The constant value for the tensor. Can be a scalar number, a nested Python list, or any DLPack-compatible array. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified, defaults to `DType.float32` for CPU devices and `DType.bfloat16` for accelerator devices. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor containing the constant value(s).
**Return type:**
[Tensor](#max.tensor.Tensor)
### `device` {#max.tensor.Tensor.device} > property device: [Device](driver.md#max.driver.Device) Gets the device where the tensor is stored. Returns the device (CPU or accelerator) where the tensor’s data is located.
**Returns:**
The device where the tensor is stored.
**Return type:**
[Device](driver.md#max.driver.Device)
### `driver_tensor` {#max.tensor.Tensor.driver_tensor} > property driver\_tensor: [Buffer](driver.md#max.driver.Buffer) A pointer to the underlying memory. Raises if the tensor is unrealized. ### `dtype` {#max.tensor.Tensor.dtype} > property dtype: [DType](dtype.md#max.dtype.DType) Gets the data type of the tensor elements. Returns the data type (dtype) of the elements stored in the tensor, such as `float32`, `int32`, or `bfloat16`.
**Returns:**
The data type of the tensor elements.
**Return type:**
[DType](dtype.md#max.dtype.DType)
### `from_dlpack()` {#max.tensor.Tensor.from_dlpack} > classmethod from\_dlpack(array) Creates a tensor from a DLPack array. Constructs a tensor by importing data from any object that supports the DLPack protocol (such as NumPy arrays and PyTorch tensors). This enables zero-copy interoperability with other array libraries. ```python import numpy as np from max import tensor # Create a NumPy array np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) # Convert to MAX tensor via DLPack x = tensor.Tensor.from_dlpack(np_array) ```
**Parameters:**
array ([DLPackArray](driver.md#max.driver.DLPackArray)) – Any object supporting the DLPack protocol, such as NumPy arrays, PyTorch tensors, or JAX arrays.
**Returns:**
A new tensor containing the data from the DLPack array.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `from_graph_value()` {#max.tensor.Tensor.from_graph_value} > classmethod from\_graph\_value(value) Creates a tensor from a graph value. Constructs a tensor from an existing graph value, which can be either a [`TensorValue`](graph/TensorValue.md#max.graph.TensorValue) or [`BufferValue`](graph/BufferValue.md#max.graph.BufferValue). This is used for converting graph level values into tensor objects. The new tensor is registered as unrealized, backed by the current realization context.
**Parameters:**
value ([Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The graph value to wrap. Can be either a TensorValue or BufferValue from the MAX graph API.
**Returns:**
A new tensor backed by the provided graph value.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `full()` {#max.tensor.Tensor.full} > classmethod full(shape, value, \*, dtype=None, device=None) Creates a tensor filled with a specified value. Returns a new tensor with the given shape where all elements are initialized to the specified value. This is useful for creating tensors with uniform values other than zero or one. ```python from max import tensor from max.dtype import DType # Create a 3x3 tensor filled with 7 x = tensor.Tensor.full((3, 3), value=7, dtype=DType.int32) # Create a 2x4 tensor filled with pi y = tensor.Tensor.full((2, 4), value=3.14159) ```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape. * value ([float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The scalar value to fill the tensor with. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified, defaults to `DType.float32` for CPU devices and `DType.bfloat16` for accelerator devices. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor with the specified shape filled with the given value.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `full_like()` {#max.tensor.Tensor.full_like} > classmethod full\_like(input, value) Creates a tensor filled with a value, matching a given tensor’s properties. Returns a new tensor filled with the specified value that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s `full_like` and PyTorch’s `full_like`. ```python from max import tensor from max.dtype import DType # Create a reference tensor ref = tensor.Tensor.ones([2, 3], dtype=DType.float32) # Create tensor filled with 5.0 matching the reference tensor x = tensor.Tensor.full_like(ref, value=5.0) ```
**Parameters:**
* input ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input. * value ([float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The scalar value to fill the tensor with.
**Returns:**
A new tensor filled with the specified value, matching the properties of the input.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `item()` {#max.tensor.Tensor.item} > item() Gets the scalar value from a single-element tensor. Extracts and returns the scalar value from a tensor containing exactly one element. The tensor is realized if needed and transferred to CPU before extracting the value.
**Returns:**
The scalar value from the tensor. The return type matches the tensor’s dtype (e.g., float for float32, int for int32).
**Raises:**
[TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If the tensor contains more than one element.
### `max()` {#max.tensor.Tensor.max} > max(axis=-1) Computes the maximum values along an axis. Returns a tensor containing the maximum values along the specified axis. This is useful for reduction operations and finding peak values in data. ```python from max import tensor from max.dtype import DType # Create a 2x4 tensor x = tensor.Tensor.constant( [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32 ) # Find max along last axis (within each row) row_max = x.max(axis=-1) # Result: [3.5, 4.2] # Find max along first axis (within each column) col_max = x.max(axis=0) # Result: [2.3, 3.5, 4.2, 3.1] # Find max over all elements overall_max = x.max(axis=None) # Result: 4.2 (maximum value across all elements) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the maximum. Defaults to -1 (the last axis). If None, computes the maximum across all elements.
**Returns:**
A tensor containing the maximum values along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `mean()` {#max.tensor.Tensor.mean} > mean(axis=-1) Computes the mean values along an axis. Returns a tensor containing the arithmetic mean of values along the specified axis. This is useful for computing averages, normalizing data, or aggregating statistics. ```python from max import tensor from max.dtype import DType # Create a 2x4 tensor x = tensor.Tensor.constant( [[2.0, 4.0, 6.0, 8.0], [1.0, 3.0, 5.0, 7.0]], dtype=DType.float32 ) # Compute mean along last axis (within each row) row_mean = x.mean(axis=-1) # Result: [5.0, 4.0] (mean of each row) # Compute mean along first axis (within each column) col_mean = x.mean(axis=0) # Result: [1.5, 3.5, 5.5, 7.5] (mean of each column) # Compute mean over all elements overall_mean = x.mean(axis=None) # Result: 4.5 (mean of all elements) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the mean. Defaults to -1 (the last axis). If None, computes the mean across all elements.
**Returns:**
A tensor containing the mean values along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `min()` {#max.tensor.Tensor.min} > min(axis=-1) Computes the minimum values along an axis. Returns a tensor containing the minimum values along the specified axis. This is useful for reduction operations and finding the smallest values in data. ```python from max import tensor from max.dtype import DType # Create a 2x4 tensor x = tensor.Tensor.constant( [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32 ) # Find min along last axis (within each row) row_min = x.min(axis=-1) # Result: [0.8, 1.9] # Find min along first axis (within each column) col_min = x.min(axis=0) # Result: [1.2, 1.9, 2.1, 0.8] # Find min over all elements overall_min = x.min(axis=None) # Result: 0.8 (minimum value across all elements) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the minimum. Defaults to -1 (the last axis). If None, computes the minimum across all elements.
**Returns:**
A tensor containing the minimum values along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `num_elements()` {#max.tensor.Tensor.num_elements} > num\_elements() Gets the total number of elements in the tensor. Computes the product of all dimensions in the tensor’s shape to determine the total number of elements.
**Returns:**
The total number of elements in the tensor.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `ones()` {#max.tensor.Tensor.ones} > classmethod ones(shape, \*, dtype=None, device=None) Creates a tensor filled with ones. Returns a new tensor with the specified shape where all elements are initialized to one. The tensor is created with eager execution and automatic compilation. ```python from max import tensor from max.driver import CPU from max.dtype import DType # Create a 2x3 tensor of ones x = tensor.Tensor.ones((2, 3), dtype=DType.float32, device=CPU()) # Result: [[1.0, 1.0, 1.0], # [1.0, 1.0, 1.0]] # Create a 1D tensor using default dtype and device y = tensor.Tensor.ones((5,)) ```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified, defaults to `DType.float32` for CPU devices and `DType.bfloat16` for accelerator devices. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor with the specified shape filled with ones.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `ones_like()` {#max.tensor.Tensor.ones_like} > classmethod ones\_like(input) Creates a tensor of ones matching a given tensor’s properties. Returns a new tensor filled with ones that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s `ones_like` and PyTorch’s `ones_like`. ```python from max import tensor from max.dtype import DType # Create a reference tensor ref = tensor.Tensor.zeros([3, 4], dtype=DType.float32) # Create ones tensor matching the reference tensor x = tensor.Tensor.ones_like(ref) # Result: 3x4 tensor of ones with dtype float32 ```
**Parameters:**
input ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
**Returns:**
A new tensor filled with ones matching the properties of the input.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `permute()` {#max.tensor.Tensor.permute} > permute(dims) Permutes the dimensions of the tensor. Returns a tensor with its dimensions reordered according to the specified permutation. This is useful for changing the layout of multi-dimensional data, such as converting between different tensor layout conventions (e.g., from `[batch, channels, height, width]` to `[batch, height, width, channels]`). ```python from max.tensor import Tensor from max.dtype import DType # Create a 3D tensor (batch_size=2, channels=3, length=4) x = Tensor.constant([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=DType.int32) print(f"Original shape: {x.shape}") # Output: Original shape: [Dim(2), Dim(3), Dim(4)] # Rearrange to (batch, length, channels) y = x.permute([0, 2, 1]) print(f"Permuted shape: {y.shape}") # Output: Permuted shape: [Dim(2), Dim(4), Dim(3)] ```
**Parameters:**
dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – A list specifying the new order of dimensions. For example, `[2, 0, 1]` moves dimension 2 to position 0, dimension 0 to position 1, and dimension 1 to position 2.
**Returns:**
A tensor with permuted dimensions.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `range_like()` {#max.tensor.Tensor.range_like} > classmethod range\_like(type) Creates a range tensor matching a given type’s properties. Returns a new tensor containing sequential indices along the last dimension, broadcasted to match the shape of the specified tensor type. Each row (along the last dimension) contains values from 0 to the dimension size minus one. This is useful for creating position indices or coordinate tensors. ```python from max import tensor from max.graph import TensorType from max.driver import CPU from max.dtype import DType # Create a reference tensor type with shape (2, 4) ref_type = TensorType(DType.int32, (2, 4), device=CPU()) # Create range tensor matching the reference type x = tensor.Tensor.range_like(ref_type) # Result: [[0, 1, 2, 3], # [0, 1, 2, 3]] ```
**Parameters:**
type ([TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor type to match. The returned tensor will have the same shape, dtype, and device as this type, with values representing indices along the last dimension.
**Returns:**
A new tensor with sequential indices broadcasted to match the input type’s shape.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `rank` {#max.tensor.Tensor.rank} > property rank: [int](https://docs.python.org/3/library/functions.html#int) Gets the number of dimensions in the tensor. Returns the rank (number of dimensions) of the tensor. For example, a scalar has rank 0, a vector has rank 1, and a matrix has rank 2.
**Returns:**
The number of dimensions in the tensor.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
### `real` {#max.tensor.Tensor.real} > property real: [bool](https://docs.python.org/3/library/functions.html#bool) ### `realize` {#max.tensor.Tensor.realize} > property realize: [Tensor](#max.tensor.Tensor) Force the tensor to realize if it is not already. ### `reshape()` {#max.tensor.Tensor.reshape} > reshape(shape) Reshapes the tensor to a new shape. Returns a tensor with the same data but a different shape. The total number of elements must remain the same. This is useful for changing tensor dimensions for different operations, such as flattening a multi-dimensional tensor or converting a 1D tensor into a matrix. ```python from max import tensor from max.dtype import DType # Create a 2x3 tensor x = tensor.Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32) print(x.shape) # (2, 3) # Flatten to 1D y = x.reshape((6,)) print(y.shape) # (6,) # Values: [1, 2, 3, 4, 5, 6] ```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The desired output shape. Can be a tuple or list of integers. The total number of elements must equal the original tensor’s element count.
**Returns:**
A reshaped tensor with the specified shape.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `shape` {#max.tensor.Tensor.shape} > property shape: [Shape](graph/shape.md#max.graph.shape.Shape) Gets the shape of the tensor. Returns the dimensions of the tensor as a shape object.
**Returns:**
The shape of the tensor.
**Return type:**
[Shape](graph/shape.md#max.graph.shape.Shape)
### `split()` {#max.tensor.Tensor.split} > split(split\_size\_or\_sections, axis=0) Splits the tensor into multiple tensors along a given dimension. This method supports two modes, matching PyTorch’s behavior: * If `split_size_or_sections` is an **int**, splits into chunks of that size (the last chunk may be smaller if not evenly divisible). * If `split_size_or_sections` is a **list of ints**, splits into chunks with exactly those sizes (must sum to the dimension size). ```python from max import tensor from max.dtype import DType # Create a 10x4 tensor x = tensor.Tensor.ones([10, 4], dtype=DType.float32) # Split into chunks of size 3 (last chunk is size 1) chunks = x.split(3, axis=0) # Result: 4 tensors with shapes [3,4], [3,4], [3,4], [1,4] # Split into exact sizes chunks = x.split([2, 3, 5], axis=0) # Result: 3 tensors with shapes [2,4], [3,4], [5,4] ```
**Parameters:**
* split\_size\_or\_sections ([int](https://docs.python.org/3/library/functions.html#int) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Either an int (chunk size) or a list of ints (exact sizes for each output tensor). * axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension along which to split. Defaults to 0.
**Returns:**
A list of tensors resulting from the split.
**Return type:**
[list](https://docs.python.org/3/library/stdtypes.html#list)\[[Tensor](#max.tensor.Tensor)]
### `squeeze()` {#max.tensor.Tensor.squeeze} > squeeze(axis) Removes a size-1 dimension from the tensor. Returns a tensor with the specified size-1 dimension removed. This is useful for removing singleton dimensions from tensors after operations that may have added them. ```python from max import tensor from max.dtype import DType # Create a tensor with a size-1 dimension x = tensor.Tensor.ones([4, 1, 6], dtype=DType.float32) print(x.shape) # (4, 1, 6) # Squeeze out the size-1 dimension y = x.squeeze(axis=1) print(y.shape) # (4, 6) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension to remove from the tensor’s shape. If negative, this indexes from the end of the tensor. The dimension at this axis must have size 1.
**Returns:**
A tensor with the specified dimension removed.
**Return type:**
[Tensor](#max.tensor.Tensor)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the dimension at the specified axis is not size 1.
### `state` {#max.tensor.Tensor.state} > state: [RealizationState](#max.tensor.RealizationState) | [None](https://docs.python.org/3/library/constants.html#None) State for realizing an unrealized tensor. ### `storage` {#max.tensor.Tensor.storage} > storage: [Buffer](driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) Underlying memory for a realized tensor. If the tensor is used in any mutating operations that have not been realized, this holds the state before any updates. ### `sum()` {#max.tensor.Tensor.sum} > sum(axis=-1) Computes the sum of values along an axis. Returns a tensor containing the sum of values along the specified axis. This is a fundamental reduction operation used for aggregating data, computing totals, and implementing other operations like mean. ```python from max import tensor from max.dtype import DType # Create a 2x3 tensor x = tensor.Tensor.constant( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=DType.float32 ) # Sum along last axis (within each row) row_sum = x.sum(axis=-1) # Result: [6.0, 15.0] (sum of each row) # Sum along first axis (within each column) col_sum = x.sum(axis=0) # Result: [5.0, 7.0, 9.0] (sum of each column) # Sum over all elements total = x.sum(axis=None) # Result: 21.0 (sum of all elements) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the sum. Defaults to -1 (the last axis). If None, computes the sum across all elements.
**Returns:**
A tensor containing the sum along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `to()` {#max.tensor.Tensor.to} > to(device) Transfers the tensor to a different device. Creates a new tensor with the same data on the specified device. This allows moving tensors between CPU and accelerators or between different accelerator devices. ```python from max import tensor from max.driver import CPU, Accelerator # Create a tensor on CPU x = tensor.Tensor.ones((2, 3), device=CPU()) print(x.device) # CPU # Transfer to accelerator y = x.to(Accelerator()) print(y.device) # Accelerator(0) ```
**Parameters:**
device ([Device](driver.md#max.driver.Device)) – The target device for the tensor.
**Returns:**
A new tensor with the same data on the specified device.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `transpose()` {#max.tensor.Tensor.transpose} > transpose(dim1, dim2) Returns a tensor that is a transposed version of input. The given dimensions `dim1` and `dim2` are swapped. ```python from max.tensor import Tensor from max.dtype import DType # Create a 2x3 matrix x = Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32) print(f"Original shape: {x.shape}") # Output: Original shape: [Dim(2), Dim(3)] print(x) # Transpose dimensions 0 and 1 to get a 3x2 matrix y = x.transpose(0, 1) print(f"Transposed shape: {y.shape}") # Output: Transposed shape: [Dim(3), Dim(2)] print(y) ```
**Parameters:**
* dim1 ([int](https://docs.python.org/3/library/functions.html#int)) – The first dimension to be transposed. * dim2 ([int](https://docs.python.org/3/library/functions.html#int)) – The second dimension to be transposed.
**Returns:**
A tensor with dimensions `dim1` and `dim2` swapped.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `type` {#max.tensor.Tensor.type} > property type: [TensorType](graph/type.md#max.graph.type.TensorType) Gets the tensor type information. Returns the type information for the tensor, including shape, dtype, and device. If the underlying value is a buffer type, it’s converted to a tensor type.
**Returns:**
The type information for the tensor.
**Return type:**
[TensorType](graph/type.md#max.graph.type.TensorType)
### `unsqueeze()` {#max.tensor.Tensor.unsqueeze} > unsqueeze(axis) Inserts a size-1 dimension into the tensor. Returns a tensor with a new size-1 dimension inserted at the specified position. This is the inverse of [`squeeze()`](#max.tensor.Tensor.squeeze) and is useful for adding dimensions needed for broadcasting or matrix operations. ```python from max import tensor from max.dtype import DType # Create a 1D tensor x = tensor.Tensor.constant([1.0, 2.0, 3.0], dtype=DType.float32) print(x.shape) # (3,) # Add dimension at the end y = x.unsqueeze(axis=-1) print(y.shape) # (3, 1) # Add dimension at the beginning z = x.unsqueeze(axis=0) print(z.shape) # (1, 3) ```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The index at which to insert the new dimension. If negative, indexes relative to 1 plus the rank of the tensor. For example, `axis=-1` adds a dimension at the end.
**Returns:**
A tensor with an additional size-1 dimension.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `zeros()` {#max.tensor.Tensor.zeros} > classmethod zeros(shape, \*, dtype=None, device=None) Creates a tensor filled with zeros. Returns a new tensor with the specified shape where all elements are initialized to zero. The tensor is created with eager execution and automatic compilation. ```python from max import tensor from max.driver import CPU from max.dtype import DType # Create a 2x3 tensor of zeros x = tensor.Tensor.zeros((2, 3), dtype=DType.float32, device=CPU()) # Result: [[0.0, 0.0, 0.0], # [0.0, 0.0, 0.0]] # Create a 1D tensor using default dtype and device y = tensor.Tensor.zeros((5,)) ```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape. * dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified, defaults to `DType.float32` for CPU devices and `DType.bfloat16` for accelerator devices. * device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor with the specified shape filled with zeros.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `zeros_like()` {#max.tensor.Tensor.zeros_like} > classmethod zeros\_like(input) Creates a tensor of zeros matching a given tensor’s properties. Returns a new tensor filled with zeros that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s `zeros_like` and PyTorch’s `zeros_like`. ```python from max import tensor from max.dtype import DType # Create a reference tensor ref = tensor.Tensor.ones([3, 4], dtype=DType.float32) # Create zeros tensor matching the reference tensor x = tensor.Tensor.zeros_like(ref) # Result: 3x4 tensor of zeros with dtype float32 ```
**Parameters:**
input ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
**Returns:**
A new tensor filled with zeros matching the properties of the input.
**Return type:**
[Tensor](#max.tensor.Tensor)
## `current_realization_context()` {#max.tensor.current_realization_context} > max.tensor.current\_realization\_context() Returns the value of the realization-context variable for the current context. If the variable has no value in the current context, this returns the variable's default value if it was created with one; otherwise it raises a `LookupError`. ## `default_device()` {#max.tensor.default_device} > max.tensor.default\_device(device) Context manager for setting the default device for tensor creation. Sets the default device used for tensor creation within the context. All tensors created inside the context block without an explicit device parameter will use this device. ```python from max import tensor from max.driver import CPU # Use CPU as default device in this context with tensor.default_device(CPU()): x = tensor.Tensor.ones((2, 3)) # Created on CPU y = tensor.Tensor.zeros((2, 3)) # Also on CPU ```
**Parameters:**
device ([Device](driver.md#max.driver.Device) | [DeviceRef](graph/type.md#max.graph.type.DeviceRef)) – The device to use as the default for tensor creation within the context.
**Returns:**
A context manager that sets the default device.
## `default_dtype()` {#max.tensor.default_dtype} > max.tensor.default\_dtype(dtype) Context manager for setting the default dtype for tensor creation. Sets the default data type used for tensor creation within the context. All tensors created inside the context block without an explicit dtype parameter will use this data type. ```python from max import tensor from max.dtype import DType # Use int32 as default dtype in this context with tensor.default_dtype(DType.int32): x = tensor.Tensor.ones((2, 3)) # Created with int32 y = tensor.Tensor.zeros((2, 3)) # Also int32 ```
**Parameters:**
dtype ([DType](dtype.md#max.dtype.DType)) – The data type to use as the default for tensor creation within the context.
**Returns:**
A context manager that sets the default dtype.
## `defaults()` {#max.tensor.defaults} > max.tensor.defaults(dtype=None, device=None) Gets the default dtype and device for tensor creation. Returns a tuple containing the dtype and device to use for tensor creation, applying defaults when values are not specified. If no dtype is provided, defaults to `DType.float32` for CPU and `DType.bfloat16` for accelerators. If no device is provided, defaults to an accelerator if available, otherwise CPU.
**Parameters:**
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type to use. If not specified, a default dtype based on the device is returned. * device ([Device](driver.md#max.driver.Device) | None) – The device to use. If not specified, defaults to an available accelerator or CPU.
**Returns:**
A tuple containing the resolved dtype and device.
**Return type:**
[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[DType](dtype.md#max.dtype.DType), [Device](driver.md#max.driver.Device)]
## `defaults_like()` {#max.tensor.defaults_like} > max.tensor.defaults\_like(like) Context manager setting the default dtype and device for tensor creation. Sets the default data type and device used for tensor creation within the context. All tensors created inside the context block without explicit dtypes or devices will use these parameters. ```python from max import tensor from max.driver import CPU from max.dtype import DType x = tensor.Tensor.zeros([1], dtype=DType.int32, device=CPU()) # Use x's int32 dtype and CPU device as the defaults in this context with tensor.defaults_like(x): y = tensor.Tensor.zeros((2, 3)) # int32, cpu z = tensor.Tensor.zeros((2, 3), dtype=DType.float32) # float32, cpu ```
**Parameters:**
like ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – A tensor or tensor type whose dtype and device are used as the defaults for tensor creation within the context.
**Returns:**
A context manager that sets the default dtype and device.
**Return type:**
[Generator](https://docs.python.org/3/library/collections.abc.html#collections.abc.Generator)\[None]
## `realization_context()` {#max.tensor.realization_context} > max.tensor.realization\_context(ctx) Sets the current realization context, within a context manager. New tensors created within this block will use the given realization context to execute. See [`RealizationContext`](#max.tensor.RealizationContext).
**Parameters:**
ctx ([RealizationContext](#max.tensor.RealizationContext)) – The realization context to set as the current context.
**Returns:**
A context manager. When the context manager is entered, it will set ctx as the current realization context. When exited the current realization context will be reset to its previous value.
**Return type:**
[AbstractContextManager](https://docs.python.org/3/library/contextlib.html#contextlib.AbstractContextManager)\[[RealizationContext](#max.tensor.RealizationContext)]
--- ## torch ## `CustomOpLibrary` {#max.torch.CustomOpLibrary} > class max.torch.CustomOpLibrary(kernel\_library) A PyTorch interface to custom operations implemented in Mojo. This API allows for easy passing of PyTorch data as `torch.Tensor` values to the corresponding custom op. `CustomOpLibrary` handles the compilation of the Mojo custom ops and marshalling of data between PyTorch and the executable Mojo code. For example, consider a grayscale operation implemented in Mojo: ```mojo title="my_library/grayscale.mojo" @register("grayscale") struct Grayscale: @staticmethod fn execute[ # The kind of device this is running on: "cpu" or "gpu" target: StaticString, ]( img_out: OutputTensor[dtype = DType.uint8, rank=2], img_in: InputTensor[dtype = DType.uint8, rank=3], ctx: DeviceContextPtr, ) raises: ... ``` You can then use `CustomOpLibrary` to invoke the Mojo operation like so: ```python import torch from max.torch import CustomOpLibrary op_library = CustomOpLibrary("my_library") grayscale_op = op_library.grayscale def grayscale(pic: torch.Tensor) -> torch.Tensor: result = pic.new_empty(pic.shape[:-1]) grayscale_op(result, pic) return result img = (torch.rand(64, 64, 3) * 255).to(torch.uint8) result = grayscale(img) ``` The custom operation produced by `op_library.<op_name>` will have the same interface as the backing Mojo operation. Each `InputTensor` or `OutputTensor` argument corresponds to a [`torch.Tensor`](https://docs.pytorch.org/docs/stable/tensors.html#tensor-class-reference) value in Python. Each argument corresponding to an `OutputTensor` in the Mojo operation will be modified in-place. For more information, see the [custom ops for PyTorch](/max/tutorials/custom-kernels-pytorch) tutorial.
**Parameters:**
kernel\_library (Path | [KernelLibrary](graph/KernelLibrary.md#max.graph.KernelLibrary)) – The path to a `.mojo` file or a `.mojopkg` with your custom op kernels, or the corresponding library object.
## `graph_op()` {#max.torch.graph_op} > max.torch.graph\_op(fn=None, name=None, kernel\_library=None, input\_types=None, output\_types=None, num\_outputs=None) A decorator to create PyTorch custom operations using MAX graph operations. This decorator allows you to define larger graphs using [MAX graph ops](/max/api/python/graph/ops) or the MAX `nn` modules and call them with PyTorch tensors, or integrate them into PyTorch modules. These custom ops can be called eagerly, and support compilation with `torch.compile` and the Inductor backend. The resulting custom operation uses destination-passing style, where output tensors are passed as the first arguments and modified in-place. This allows PyTorch to manage the memory and streams of the output tensors. Tensors internal to the computation are managed via MAX’s graph compiler and memory planning. The default behavior is to JIT-compile for the specific input and output shapes needed. If you are passing variable-sized inputs, such as a batch size or sequence length that may take on many different values between calls, you should declare that dimension as a symbolic dimension through `input_types` and `output_types`. Otherwise you will end up compiling a specialized graph for each possible variation of inputs, which may use a lot of memory. If neither `output_types` nor `num_outputs` is specified, the operation defaults to 1 output.
For example, to create a functional-style PyTorch op backed by MAX: ```python title="grayscale.py" import torch import numpy as np import max.torch from max.dtype import DType from max.graph import ops @max.torch.graph_op def max_grayscale(pic: max.graph.TensorValue): scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07]) grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype) # max reductions don't remove the dimension, need to squeeze return ops.squeeze(grayscaled, axis=-1) @torch.compile def grayscale(pic: torch.Tensor): output = pic.new_empty(pic.shape[:-1]) # Remove color channel dimension max_grayscale(output, pic) # Call as destination-passing style return output device = "cuda" if torch.cuda.is_available() else "cpu" img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8) result = grayscale(img) print(f"Input shape: {img.shape}") print(f"Output shape: {result.shape}") print("Grayscale conversion completed successfully!") ```
**Parameters:**
* fn ([Callable](graph/ops.md#max.graph.ops.Callable)\[\[...], [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None] | None) – The function to decorate. If None, returns a decorator. * name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional name for the custom operation. Defaults to the function name. * kernel\_library (Path | [KernelLibrary](graph/KernelLibrary.md#max.graph.KernelLibrary) | None) – Optional kernel library to use for compilation. Useful for creating graphs with custom Mojo ops. * input\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorType](graph/type.md#max.graph.type.TensorType)] | None) – Optional sequence of input tensor types for compilation. If None, types are inferred from runtime arguments. * output\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorType](graph/type.md#max.graph.type.TensorType)] | None) – Optional sequence of output tensor types for compilation. If None, types are inferred from runtime arguments. * num\_outputs ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of outputs of the graph. We need to know this ahead of time to register with PyTorch before we’ve compiled the final kernels.
**Returns:**
A PyTorch custom operation that can be called with torch.Tensor arguments.
**Return type:**
CustomOpDef | [Callable](graph/ops.md#max.graph.ops.Callable)\[\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[…], [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None]], CustomOpDef]
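The grayscale example above weights the RGB channels with the coefficients `[0.21, 0.71, 0.07]`. As a quick sanity check of that math, here is a dependency-free Python sketch of the same per-pixel reduction (purely illustrative; the real op runs as a compiled MAX graph on tensors):

```python
# Approximate luminance weights, matching the grayscale example above.
WEIGHTS = (0.21, 0.71, 0.07)

def luminance(rgb):
    """Reduce one (R, G, B) pixel to a single grayscale value."""
    return sum(channel * weight for channel, weight in zip(rgb, WEIGHTS))

# A pure white pixel maps close to full brightness
# (the weights sum to 0.99, so 255 maps to 252.45).
print(round(luminance((255, 255, 255)), 2))  # 252.45
```

This mirrors what `ops.sum(scaled, axis=-1)` computes across the color axis, before the squeeze removes the reduced dimension.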
---

## What's new

Here's everything you should know about what's changed in each release.

## Nightly: v26.2

This version is still a work in progress. See how to [install the nightly release](/max/packages#install).

### MAX models {#26-2-models}

* Add support for Qwen/Qwen3-30B-A3B-Instruct-2507, a mixture-of-experts (MoE) model.
* Add multi-GPU tensor parallelism support for Qwen3 and Qwen3-MoE models.
* Remove legacy Gemma 3 multimodal implementation and the `MODULAR_MAX_DISABLE_GEMMA3_VISION` environment variable.
* Implement multi-GPU support (tensor parallelism) for GPT-OSS.

### MAX framework {#26-2-max}

#### Inference server {#26-2-max-serve}

* Enabled overlap scheduling by default for select model architectures, such as `LlamaForCausalLM_Legacy`. This optimization reduces CPU overhead by overlapping Python host code with GPU kernel execution. It is currently incompatible with some features, such as structured outputs and CPU models, and remains highly experimental. You can forcibly disable it via `--no-enable-overlap-scheduler --force`.

#### Python API {#26-2-max-python}

* Keep a global MLIR context active and drop per-graph context plumbing so algebraic dims and graph/custom op construction work without an explicit context manager. Threadpool-backed MAX paths now scope worker-thread MLIR usage to the default context automatically.
### MAX kernels {#26-2-max-kernels}

### Mojo language {#26-2-mojo}

For all the updates to the Mojo language, standard library, and tools, including all GPU programming and `Layout`/`LayoutTensor` changes, see the [Mojo changelog](/mojo/changelog).

## v26.1 (2026-01-29)

### Highlights {#26-1-highlights}

The eager-style [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) and [`Module`](/max/api/python/nn/module#max.nn.module.Module) APIs are now the primary API for model development, providing a PyTorch-like development experience:

```python
from max import functional as F
from max.tensor import Tensor
from max.dtype import DType

x = Tensor.constant([1.0, -2.0, 3.0, -4.0, 5.0], dtype=DType.float16)
y = F.relu(x)
print(y)  # Tensor([1 0 3 0 5], dtype=DType.float16, device=Device(type=gpu,id=0))
```

If you want explicit control over the graph structure, you can still build models with the [`Graph`](/max/api/python/graph/Graph) APIs. For more details, see the [model developer guide](/max/develop/).

### Documentation {#26-1-docs}

* The fully refactored [MAX LLM book](https://llm.modular.com/) is now designed so the code you write in each exercise incrementally builds upon the last one, until you've built an executable GPT-2 model with the MAX Python API.
* New model developer guide introduces [eager-style programming](/max/develop/), [tensor APIs](/max/develop/tensors), and [data types](/max/develop/dtypes). Much more is coming soon.
* New guide to [profile MAX on GPUs with `nsys`](/max/gpu-system-profiling).
* Extended [documentation for `kbench`](https://github.com/modular/modular/tree/main/max/kernels/benchmarks/autotune#kbench-a-benchmarking-toolkit-for-mojo-kernels), a Python tool to benchmark, autotune, and analyze MAX kernel performance.

### MAX models {#26-1-models}

* [Gemma3](https://builds.modular.com/models/gemma-3-it/27B) now supports vision input (multimodal) in the 12B and 27B variants, including support for local file paths and structured output.
Learn more in the [image to text guide](/max/inference/image-to-text).
* Added `Qwen/Qwen3-VL-4B-Instruct` and `Qwen/Qwen3-VL-2B-Instruct` model architectures.
* Removed Llama 3.2 Vision (`Llama-3.2-11B-Vision-Instruct`) architecture support. Use other vision models such as Pixtral, InternVL, Qwen2.5-VL, and Gemma3.

### MAX framework {#26-1-max}

* All Python wheels are now hosted at `https://whl.modular.com/nightly/simple/`. If using `uv`, change `--index-url` to `--index`, and if using `pip`, change to `--extra-index-url`. For precise commands, see the [install guide](/max/packages#install).

#### Inference server {#26-1-max-serve}

* Improved scheduling to achieve higher KVCache utilization and batch sizes. By default, MAX now schedules a context encoding (CE) request only if KVCache memory is less than 95% full *after* allocating blocks for that request, or if no active requests exist. You can adjust this watermark value (`0.95`) with [`--kvcache-ce-watermark`](/max/cli/serve#--kvcache-ce-watermark-kvcache_ce_watermark). Beware that increasing it causes more preemptions.
* When running models with data parallelism (DP), the semantics of max batch size have changed. Previously, specifying `--data-parallel-degree 8` and `--max-batch-size 32` meant that each data-parallel replica could have at most 4 requests, for an aggregate max batch size of 32. Now the CLI flag specifies the max batch size per replica, so the aggregate max batch size for the same values is 8\*32=256 requests. This aligns with vLLM and other inference engines.
* `--max-ce-batch-size` is now deprecated. The cap on batch size is now uniform between the context encoding and token generation phases of text generation. Use `--max-batch-size` instead.
* The API server now returns chunked tokens from the model worker, reducing overhead and significantly improving throughput for small models and decode-heavy workloads.
* Server stats collection (`collect_server_stats`) is now enabled by default for serving benchmarks.

#### `max` CLI {#26-1-max-cli}

* The `max generate` command now applies the model's chat template internally when using `--prompt`. This more closely aligns with how users typically prompt a model for testing and ensures special tokens are properly filtered from output.
* Added tracing flags to `max benchmark` for `nsys` profiling:
  * `--trace`: Enable tracing of the benchmark run (currently NVIDIA GPUs only)
  * `--trace-file`: Path to save the trace file
  * `--trace-session`: Optional session name for tracing

  Requires the server to be run under `nsys launch`. Using `--gpu-profiling detailed` is recommended.

#### Python API {#26-1-max-python}

* The eager-style [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) APIs are now the primary API for model development, providing a PyTorch-like development experience. We moved the eager-style tensor APIs out of `experimental` and reorganized the `max.nn` module to make the eager module system the primary API (`nn.module_v3` is now `nn.module`). The previous [`max.nn`](/max/api/python/nn/) components are still available for backward compatibility in [`max.nn.legacy`](/max/api/python/nn/legacy/).
* Renamed `max.driver.Tensor` to [`max.driver.Buffer`](/max/api/python/driver#max.driver.Buffer) to clarify that it represents a low-level memory buffer, not a tensor. The [`max.tensor.Tensor`](/max/api/python/tensor#max.tensor.Tensor) class remains the primary tensor type.
* Added a `forward()` method to [`Module`](/max/api/python/nn/module#max.nn.module.Module) to compute the output; it behaves the same as invoking the object as a callable (the `__call__()` method).
* `accelerator_count()` now returns a non-zero value when called on an Apple silicon system, which means this code now defaults to using the available Apple silicon GPU:

  ```python
  device = CPU() if accelerator_count() == 0 else Accelerator()
  ```
As a consequence, MAX graphs should in most cases be dispatched to run on Apple silicon GPUs. Note that most MAX models do not yet work on Apple silicon GPUs due to missing hardware-specific kernel pathways and other support, but this is an important step toward enabling MAX more broadly on Apple silicon.
* Added `max.nn.module.rope` containing rotary embedding implementations, [`RotaryEmbedding`](/max/api/python/nn/rope/RotaryEmbedding) and [`TransposedRotaryEmbedding`](/max/api/python/nn/rope/TransposedRotaryEmbedding).
* Added [`ArchConfig`](/max/api/python/pipelines/interfaces#max.pipelines.lib.interfaces.ArchConfig) and `ArchConfigWithKVCache`. Going forward, models that register with the MAX architecture registry must define a config that implements this protocol.
* Added `ops.complex.mul` for multiplying complex-valued tensors.
* Added `calculate_virtual_device_count()`, `calculate_virtual_device_count_from_cli()`, and `load_max_buffer()` to [`max.driver`](/max/api/python/driver/).
* Added [`TokenBuffer`](/max/api/python/interfaces#max.interfaces.TokenBuffer) for token management.
* Renamed `prefill_chunk_size` to `max_batch_input_tokens` and `max_batch_context_length` to `max_batch_total_tokens` in the [`PipelineConfig`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig) and `TTSConfig` classes to better reflect their purpose in batch memory management. The corresponding CLI flags have also been renamed: `--prefill-chunk-size` is now `--max-batch-input-tokens` and `--max-batch-context-length` is now `--max-batch-total-tokens`.
* Fixed `max.driver.Buffer.to(stream)` to not copy (it returns a reference to the same buffer) when the stream is on the same device, even for GPU-pinned host memory.
* Removed deprecated `max.nn` convolution classes: `Conv2dV1`, `Conv1DV1`, `Conv3DV1`. Use `Conv2d`, `Conv1D`, `Conv3D` instead.
* Removed deprecated `max.nn` layer classes: `LinearV1`, `QLinearV1`, `GPTQLinearV1`, `MLPV1`, `EmbeddingV1`, `LayerNormV1`, `RMSNormV1`. Use `Linear`, `GPTQLinear`, `MLP`, `Embedding`, `LayerNorm`, `RMSNorm` instead.
* Removed `max.engine.MojoValue`.
* Removed the deprecated `custom_ops_path` parameter from [`InferenceSession.load()`](/max/api/python/engine#max.engine.InferenceSession.load). Instead use the `custom_extensions` parameter.
* Added `graph.ops.shard_and_stack()`.
* Removed unused `graph.weights.PytorchWeights`.

### MAX kernels {#26-1-max-kernels}

* Improved performance for Hopper Matmul when using skinny M shapes. In particular, when M is between 2 and 64, we see performance boosts of 10% to 40% for specific shapes.
* Added the swapAB optimization to Hopper Matmul, which performs B x A and does a transposed write to C. This helps when you need more granularity in the M dimension.
* Refined the `create_stream` API: all streams are now non-blocking (the `blocking` argument has been removed). Explicitly use `DeviceEvent` and `synchronize()` wherever necessary.

### Mojo language {#26-1-mojo}

For all the updates to the Mojo language, standard library, and tools, including all GPU programming and `Layout`/`LayoutTensor` changes, see the [Mojo changelog](/mojo/changelog).

## v25.7 (2025-11-20)

### Highlights {#25-7-highlights}

* The MAX Python API is now [fully open-sourced on GitHub](https://github.com/modular/modular/tree/main/max/python/max)! As we expand our [model repository](https://builds.modular.com/?category=models), we're making significant progress on these APIs to simplify the effort to build production-ready GenAI models in Python. Some APIs are still experimental, but you can [build an LLM with it today](https://llm.modular.com).

### Documentation {#25-7-docs}

* New online book to [build an LLM from scratch with MAX](https://llm.modular.com), using our **experimental model APIs**.
This is a guided lesson on building GPT-2 with our Python API, explaining each component of the transformer model along the way. Like the Python APIs, the book is a work in progress; please [report any issues in GitHub](https://github.com/modular/max-llm-book/issues).
* All the planned parts of [GPU Puzzles](https://puzzles.modular.com/) are now complete! Support for Apple silicon GPUs is also making [steady progress](https://puzzles.modular.com/howto.html#gpu-support-matrix).
* Tutorials on docs.modular.com are now integrated into the [Guides](/max/intro) section, indicated with a book icon in the left navigation.
* The [`max` CLI docs](/max/cli/) are now generated from [the CLI source](https://github.com/modular/modular/blob/main/max/python/max/entrypoints/pipelines.py).

### MAX models {#25-7-models}

* Gemma3 now supports logprobs.

### MAX framework {#25-7-max}

* Added support for bfloat16 models running on GPUs with ARM-based CPU hosts, such as Grace Hopper (GH200) and Grace Blackwell (GB200).
* Updated the minimum NVIDIA GPU driver requirement to 580.

#### `max` CLI {#25-7-max-cli}

* [`max benchmark`](/max/cli/benchmark) can now run LoRA benchmarking for supported models and target modules.
* `max benchmark --collect-gpu-stats` can now collect AMD GPU statistics.
* `max serve --do-penalties` was renamed to `--enable-penalties` and is enabled by default. To disable penalties, specify [`--no-enable-penalties`](/max/cli/serve#--enable-penalties---no-enable-penalties).

#### Python API {#25-7-max-python}

* Added support for Python 3.14.
* Removed support for Python 3.9.
* All MAX Python API modules are now **open-sourced**. In addition to those previously released, we've added `driver`, `dtype`, `engine`, `experimental`, `interfaces`, `kv_cache`, `mlir`, `nn`, `profiler`, `support`, `torch`, and more [in our GitHub repo](https://github.com/modular/modular/tree/main/max/python/max).
* Added the [`max.profiler`](/max/api/python/profiler) module with the [`Tracer`](/max/api/python/profiler#max.profiler.Tracer) class to create and manage profiling spans based on runtime conditions, and the `@traced()` decorator to profile a whole function.
* Added [`max.diagnostics.gpu`](/max/api/python/diagnostics/gpu) APIs to expose common GPU statistics as might be reported by `nvidia-smi` or `rocm-smi`.
* Added the [`max.kv_cache`](/max/api/python/kv_cache/) package, which provides APIs to manage key-value caches used in transformer models. Not to be confused with the existing [`max.nn.kv_cache`](/max/api/python/nn/kv_cache/) package, which includes kernels for KV caching.
* Removed the `KVCacheManager` class and combined it with the single [`PagedKVCacheManager`](/max/api/python/kv_cache/paged_cache/cache_manager#max.kv_cache.paged_cache.cache_manager.PagedKVCacheManager) implementation. During the merge, `prefetch()` was renamed to `maybe_reserve()`.
* Added [`NullKVCacheManager`](/max/api/python/kv_cache/null_cache_manager#max.kv_cache.NullKVCacheManager) for compile-only mode, which avoids GPU memory allocation when compiling models without a physical GPU present.
* Added [`ResetPrefixCacheBackend`](/max/api/python/kv_cache/paged_cache/tp_cache_manager#max.kv_cache.paged_cache.ResetPrefixCacheBackend) and [`ResetPrefixCacheFrontend`](/max/api/python/kv_cache/paged_cache/tp_cache_manager#max.kv_cache.paged_cache.ResetPrefixCacheFrontend) classes for coordinating prefix cache resets between frontend and backend components.
* Added more APIs for text-to-speech (TTS) models, such as [`AudioGenerationInputs`](/max/api/python/interfaces#max.interfaces.AudioGenerationInputs) and [`AudioGenerationOutput`](/max/api/python/interfaces#max.interfaces.AudioGenerationOutput).
* Changed the [`LoRAConfig.max_num_loras`](/max/api/python/pipelines/lora_config#max.pipelines.lib.lora_config.LoRAConfig.max_num_loras) default to `1` (was `100`).
* New [`RequestID`](/max/api/python/interfaces/#max.interfaces.RequestID) class replaces the previous type alias to provide better type safety and consistency across the API.
* Removed `InputContext` and replaced it with the modality-specific [`TextGenerationContext`](/max/api/python/interfaces/#max.interfaces.TextGenerationContext) and [`EmbeddingsContext`](/max/api/python/interfaces/#max.interfaces.EmbeddingsContext).
* Added [`ImageMetadata`](/max/api/python/interfaces/#max.interfaces.ImageMetadata) and [`VLMTextGenerationContext`](/max/api/python/interfaces/#max.interfaces.VLMTextGenerationContext).
* Added [`max.nn.comm`](/max/api/python/nn/comm/) with `Allreduce` and `Signals` for peer-to-peer communication in allreduce.
* [`ops.gather()`](/max/api/python/graph/ops#max.graph.ops.gather) no longer has a default `axis`; it must be specified explicitly (better matching PyTorch and NumPy).
* [`Graph.add_subgraph()`](/max/api/python/graph/Graph#max.graph.Graph.add_subgraph) has been updated to take a `devices` argument. This allows subgraphs to take advantage of device-aware work scheduling.

#### Mojo API {#25-7-max-mojo}

* Renamed the `tensor_internal` package to `tensor` and removed the previous `tensor` stub; the API behaves the same but the [Mojo `tensor` docs](/mojo/kernels/extensibility/tensor/) moved.

### Mojo language {#25-7-mojo}

For all the updates to the Mojo language, standard library, and tools, including all GPU programming and `Layout`/`LayoutTensor` changes, see the [Mojo changelog](/mojo/changelog).

## v25.6.1 (2025-10-10)

Fixes a latency regression due to a top-k algorithm change and a couple of other benchmarking bugs.
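For background on what the top-k change above touches: top-k sampling keeps only the `k` highest-scoring logits before sampling the next token. A minimal dependency-free sketch of that selection step (purely illustrative; this is not MAX's accelerated implementation):

```python
import heapq

def top_k(logits, k):
    """Return the indices of the k largest logits, highest first (illustrative only)."""
    return heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)

print(top_k([0.1, 2.0, -1.0, 0.5], 2))  # [1, 3]
```

In a real sampler, the surviving logits are then renormalized and sampled from; the latency-sensitive part is doing this selection on-device for large vocabularies.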
## v25.6 (2025-09-22)

* [Highlights](#25-6-highlights)
* [Documentation](#25-6-docs)
* [MAX models](#25-6-models)
* [MAX framework](#25-6-max)
* [Inference server](#25-6-max-serve)
* [`max` CLI](#25-6-max-cli)
* [Python API](#25-6-max-python)
* [MAX kernels](#25-6-kernels)
* [Mojo language](#25-6-mojo)

### Highlights {#25-6-highlights}

* MAX delivers **state-of-the-art performance on NVIDIA Blackwell** (B200)! We've been describing our Blackwell bring-up over a series of blog posts, and we recently published [Part 4: Breaking SOTA](https://www.modular.com/blog/matrix-multiplication-on-blackwell-part-4---breaking-sota), in which we share our latest matmul benchmarks compared to NVIDIA's cuBLAS library.
* MAX provides **industry-leading performance on AMD MI355X**! In a matter of weeks, we got MAX running on the brand new MI355X system and have already produced early benchmarks that go head-to-head with Blackwell. If you have access to an MI355X, you can try it yourself today by following our [quickstart guide](/max/get-started).
* Benchmarking endpoints is easier than ever with the new [`max benchmark`](/max/cli/benchmark) command, which accepts YAML configuration files so you can easily share and reproduce your benchmarks.

### Documentation {#25-6-docs}

* Our new [quickstart guide](/max/get-started) lets you pick the model architecture and size you want, and then shows you how to deploy it and run our open-source benchmarking script, all from the `max` CLI.
* We updated and simplified the [benchmarking tutorial](/max/deploy/benchmark) to use the new `max benchmark` command.

### MAX models {#25-6-models}

* Added the [gpt-oss](https://github.com/modular/modular/tree/modular/v25.6.0/max/pipelines/architectures/gpt_oss) model architecture (GPU, bfloat16). [Try GPT-OSS now](https://builds.modular.com/models/gpt-oss-20b-BF16/20B).
### MAX framework {#25-6-max}

* Added device-aware work scheduling for AsyncRT: work items can now specify a `deviceHint` to route execution to specific worker threads based on device affinity, improving multi-device performance.
* Improved code quality by enabling a large set of RUFF lints, including [flake8-annotations (ANN)](https://docs.astral.sh/ruff/rules/#flake8-annotations-ann), which now enforces Python type annotations for new contributions.

#### Inference server {#25-6-max-serve}

* Added support for data parallelism in Llama models. To enable this feature, use the `--data-parallel-degree` option:

  ```sh
  max serve --model $MODEL_ID --data-parallel-degree 2 --devices gpu:0,1
  ```

* Metrics for each context encoding and token generation batch are now logged to the console periodically. You can override the default logging frequency (3 seconds) by setting the `MAX_SERVE_SCHEDULER_STATS_LOG_INTERVAL_S` environment variable. For example, setting `MAX_SERVE_SCHEDULER_STATS_LOG_INTERVAL_S=0` logs metrics for all batches.
* Improved error messages when pulling a model that requires more RAM than what's available or when there won't be enough RAM left for the KV cache.

#### `max` CLI {#25-6-max-cli}

* Added the `max benchmark` subcommand that runs a suite of benchmarks and collects performance metrics on a model server. This command provides convenient packaging/installation for our open-source [`benchmark_serving.py`](https://github.com/modular/modular/tree/main/benchmark#benchmark-max) script and accepts all the same options.
* Added `--chat-template` to the CLI for passing a custom chat template defined in a Jinja2 template file.
* Renamed the `--allow-safetensors-weights-float32-to-bfloat16-cast` flag to `--allow-safetensors-weights-fp32-bf6-bidirectional-cast`, which supports automatic bidirectional dtype casts when needed.
* The `max generate` command now supports `--top-k`, `--temperature`, and `--seed` flags.
* Changed `--num-warmups` behavior.
Previously, it ran the model on the prompt `N` times, generating until reaching a stop condition each time. Now it runs the model for `N` steps, generating `N` new tokens as a warmup.
* Added the `--model` option as a preferred alternative to `--model-path`. They behave the same.
* Deprecated `--pad-to-multiple-of`.
* Removed the previously deprecated `--model-name`. Use `--served-model-name` instead.

#### Python API {#25-6-max-python}

* Removed the previously deprecated `KVCacheStrategy.CONTINUOUS` and all associated classes (including `ContinuousBatchingKVCacheManager`).
* Added [`ops.fence`](/max/api/python/graph/ops#max.graph.ops.fence), a pure identity operation that prevents the async runtime from reordering operations across it. This operation is essential for implementing cross-device synchronization.
* Removed `PipelineConfig.max_new_tokens`. Use [`SamplingParams.max_new_tokens`](/max/api/python/pipelines#max.pipelines.SamplingParams) instead.
* Added [`logits_processor`](/max/api/python/interfaces/#max.interfaces.SamplingParams.logits_processors) to [`SamplingParams`](/max/api/python/interfaces/#max.interfaces.SamplingParams) for updating logits in-place during each step of token generation.
* Added `generate()` to [`TextGenerationPipeline`](/max/api/python/pipelines/pipeline#max.pipelines.lib.pipeline.TextGenerationPipeline) and [`SpeculativeDecodingPipeline`](/max/api/python/pipelines#max.pipelines.SpeculativeDecodingPipeline), a convenience method for getting text generations. `generate_async()` is available for getting streamed outputs.
* Renamed the `target_num_new_tokens` configuration parameter to [`prefill_chunk_size`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig.prefill_chunk_size) in the [`PipelineConfig`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig) and `TTSConfig` classes to better reflect its role in chunked prefill operations.
* Fixed [`ops.range`](/max/api/python/graph/ops#max.graph.ops.range) to respect the `dtype` parameter when using [`Dim`](/max/api/python/graph/dim) objects as inputs. Previously, the dtype was ignored and defaulted to int64.
* Made the `devices` argument in [`InferenceSession()`](/max/api/python/engine#max.engine.InferenceSession) required. To maintain the previous default behavior, use `InferenceSession(devices=[CPU()])`.
* Added an optional `logging` argument to [`InferenceSession()`](/max/api/python/engine#max.engine.InferenceSession). When set to `"op"`, this option enables operation launch output to stderr.
* Added [`max.nn.lora`](/max/api/python/nn/lora), providing Low-Rank Adaptation (LoRA) support for parameter-efficient fine-tuning of neural network models.
* Added [`max.nn.moe`](/max/api/python/nn/moe), implementing Mixture of Experts (MoE) layers for scalable model architectures.
* Added [`max.nn.sampling`](/max/api/python/nn/sampling), containing advanced sampling methods including MinP and rejection sampling techniques.
* Added [`max.nn.hooks`](/max/api/python/nn/hooks), providing debugging and inspection hooks for neural network layers.
* Added attention submodules [`max.nn.attention.mask_config`](/max/api/python/nn/attention/mask_config), [`max.nn.attention.multihead_attention`](/max/api/python/nn/attention/multihead_attention), and [`max.nn.attention.multi_latent_attention`](/max/api/python/nn/attention/multi_latent_attention) for comprehensive attention mechanism configuration and implementation.
* Moved some Mojo-related functionality to a new top-level `mojo` Python namespace. Specifically, `max.mojo` (previously used for Mojo-Python interop), some of `max.support`, and `max.entrypoints.mojo` now live under the `mojo` namespace and are provided in the new [`mojo` package](/mojo/manual/install#whats-included).

### MAX kernels {#25-6-kernels}

* Added a leaky ReLU activation function kernel.
* Added a specialized [RMS norm](/mojo/kernels/nn/normalization/rms_norm/) function kernel for the common case of `cols=128`, `bfloat16`.

### Mojo language {#25-6-mojo}

For all the updates to the Mojo language, standard library, and tools, including all GPU programming changes, see the [Mojo changelog](/mojo/changelog).

## v25.5 (2025-08-05)

* [Highlights](#25-5-highlights)
* [Documentation](#25-5-docs)
* [MAX models](#25-5-models)
* [MAX framework](#25-5-max)
* [Inference server](#25-5-max-serve)
* [`max` CLI](#25-5-max-cli)
* [Python API](#25-5-max-python)
* [Mojo language](#25-5-mojo)

### Highlights {#25-5-highlights}

* **OpenAI-compatible batch API**: The [`/v1/batches` API](/max/api/serve#operation/createBatch) is now available with [Mammoth](/mammoth/). We recently announced a [partnership with SF Compute](https://www.modular.com/blog/sf-compute) to make this API available through their dynamic GPU pricing marketplace. Their Large Scale Inference Batch API looks different from the `/v1/batches` API in Mammoth because it's a superset.
* **New `mojo` Conda package**: For Mojo-specific projects that run on CPUs and GPUs, you can now install the bare essentials with the `mojo` Conda package that's less than 900 MB on disk. For example, this now works:

  ```sh
  pixi add mojo
  ```

  The `mojo` Python package is not available for pip/uv yet. For a complete model-development and serving toolkit, you should still install the `modular` package (which includes `mojo` as a dependency).
* **Open-source graph APIs**: We've added the `max.graph` Python APIs to our [GitHub repo](https://github.com/modular/modular/tree/modular/v25.5.0/max/graph). We've made great strides in recent months to simplify these APIs that help you build high-performance models you can [serve with MAX](/max/develop/serve-custom-model-architectures).
### Documentation {#25-5-docs}

* New [Serve custom model architectures tutorial](/max/develop/serve-custom-model-architectures), with [example code on GitHub](https://github.com/modular/modular/tree/main/max/examples/custom-models).
* New guide for [using LoRA adapters with MAX](/max/serve/lora-adapters).
* Updated the [Deploy Llama 3 on GPU tutorial](/max/tutorials/max-serve-local-to-cloud/) with instructions using AMD MI300X (on Azure).
* Added [Pixi basics](/pixi), which is where we redirect all the now-removed Magic docs (see our [announcement migrating Magic to Pixi](https://forum.modular.com/t/migrating-from-magic-to-pixi/1530)).

### MAX models {#25-5-models}

* Added support for the [Idefics3](https://github.com/modular/modular/tree/modular/v25.5.0/max/pipelines/architectures/idefics3) model.

### MAX framework {#25-5-max}

* Removed all `torch` package dependencies.
  * Reduces the total installation size of `modular` (including dependencies) from 2.2 GB for CPUs and 6.5 GB for GPUs **down to 1.5 GB**, for all Python packages. Conda packages pull additional system dependencies so sizes may vary, but one example brings the size down from 9.8 GB to 2.0 GB.
  * `pip install` no longer requires the `--extra-index-url https://download.pytorch.org/whl/cpu` option (which was to avoid installing the GPU version of `torch` that has a lot of CUDA dependencies).
  * `uv pip install` no longer requires the `--index-strategy unsafe-best-match` option (which was to avoid package resolution issues with the above `--extra-index-url` option).
* Removed the HuggingFace fallback for model pipelines not natively supported in MAX (`PipelineEngine.HUGGINGFACE`), because it's almost never used and it creates significant tech debt.

#### Inference server {#25-5-max-serve}

* Added the [`/health` endpoint](/max/api/serve/#operation/health) for service readiness checks, used by tools like lm-eval to determine when the service is ready to accept requests.
* [Prefix caching](/max/serve/prefix-caching) now uses a Mojo token hashing operation. Previously we used the `hash()` method from the Python stdlib, which resulted in noticeable CPU overhead and reduced GPU utilization. In this release, we migrated the token hashing operation to an accelerated Mojo implementation.
* Re-implemented the OpenAI API's `logprobs` and `echo` request parameters to eliminate an expensive device transfer. The `--enable-echo` flag, which previously incurred a significant performance penalty, is now 9-12x faster.
* Added support for `file://` URIs in image inputs for multimodal models. Local file access is controlled via the `MAX_SERVE_ALLOWED_IMAGE_ROOTS` environment variable, which specifies a list of allowed root directories. Files are read asynchronously using aiofiles for better performance under high load.
* Improved [function calling](/max/serve/function-calling) (tool use) to more reliably extract JSON tool calling responses for Llama models in an OpenAI-compatible format.
* Switched from XGrammar to [llguidance](https://github.com/guidance-ai/llguidance) for generating structured output (constrained decoding).

#### `max` CLI {#25-5-max-cli}

* Added the `--vision-config-overrides` CLI option to override vision model configuration parameters. For example, to decrease InternVL's maximum dynamic patches from 12 to 6:

  ```bash
  max serve --model-path OpenGVLab/InternVL3-38B-Instruct \
    --vision-config-overrides '{"max_dynamic_patch": 6}'
  ```

* Removed the `--ignore-eos` CLI argument. The full set of OpenAI chat and completion sampling parameters is now supported in HTTP requests, so this parameter can simply be set via the HTTP payload.

#### Python API {#25-5-max-python}

* Added the [`max.interfaces`](/max/api/python/interfaces) module. This module is intended to be relatively import-free and to hold all shared interfaces across the MAX stack. Common interfaces will gradually move into it.
So far, we've moved the following from `max.pipelines.core`:
  * Moved `TextGenerationStatus`, `TextResponse`, `TextGenerationResponse`, `InputContext`, and `PipelineTask` into `max.interfaces`.
  * Moved all `TokenGeneratorRequest`-prefixed objects into `max.interfaces` and renamed with the `TextGenerationRequest` prefix.
  * Moved `TextGenerationStatus` to [`GenerationStatus`](/max/api/python/interfaces/#max.interfaces.GenerationStatus).
  * Moved `TextResponse` and `TextGenerationResponse` to [`TextGenerationOutput`](/max/api/python/interfaces/#max.interfaces.TextGenerationOutput).
  * Moved `EmbeddingsResponse` to [`EmbeddingsOutput`](/max/api/python/interfaces#max.interfaces.EmbeddingsOutput).
* Added the [`ops.scatter_nd`](/max/api/python/graph/ops/#max.graph.ops.scatter_nd) operation for scattering updates into a tensor at specified indices.
* Added [`ops.avg_pool2d`](/max/api/python/graph/ops/#max.graph.ops.avg_pool2d) and [`ops.max_pool2d`](/max/api/python/graph/ops/#max.graph.ops.max_pool2d).
* Added the [`max.torch.graph_op`](/max/api/python/torch#max.torch.graph_op) interface to make it simple to embed larger MAX computations and models inside PyTorch. These can use `max.nn` modules internally and may be used within `torch.nn` modules, allowing the use of MAX subcomponents for access to our high performance graph compiler and Mojo kernel library.
  ```python
  import torch
  import numpy as np

  import max
  from max.dtype import DType
  from max.graph import ops


  @max.torch.graph_op
  def max_grayscale(pic: max.graph.TensorValue):
      scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07])
      grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype)
      # max reductions don't remove the dimension, need to squeeze
      return ops.squeeze(grayscaled, axis=-1)


  @torch.compile
  def grayscale(pic: torch.Tensor):
      output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
      max_grayscale(output, pic)  # Call as destination-passing style
      return output


  device = "cuda" if torch.cuda.is_available() else "cpu"
  img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
  result = grayscale(img)
  ```

* Moved `AlgebraicDim`, `Dim`, `StaticDim`, and `SymbolicDim` out of `max.type` and into [`max.graph.dim`](/max/api/python/graph/dim). You can still import them directly from `max.graph`.

* Moved `Shape` out of `max.type` and into [`max.graph.shape`](/max/api/python/graph/shape). You can still import it directly from `max.graph`.

* Removed the ability to pass Python objects into models and have them returned as Mojo `PythonObject` types in the kernels.

* Removed `RandomWeights`.

* Removed `Model.execute_legacy()`. Instead, use the standard [`execute()`](/max/api/python/engine#max.engine.Model.execute) or [`__call__()`](/max/api/python/engine#max.engine.Model.__call) methods.

* Removed TorchScript-related helper functions and APIs, including support for `.pt` TorchScript files in custom extensions.

### Mojo language {#25-5-mojo}

For all the updates to the Mojo language, standard library, and tools, including all GPU programming changes, see the [Mojo changelog](/mojo/changelog).
## v25.4 (2025-06-18)

* [Highlights](#25-4-highlights)
* [Documentation](#25-4-docs)
* [MAX models](#25-4-models)
* [MAX framework](#25-4-max)
  * [Inference server](#25-4-max-serve)
  * [`max` CLI](#25-4-max-cli)
  * [Python API](#25-4-max-python)
  * [Mojo API](#25-4-max-mojo)
  * [Custom ops](#25-4-custom-ops)
  * [GPU programming](#25-4-gpu-programming)
* [Mojo language](#25-4-mojo)

### ✨ Highlights {#25-4-highlights}

* **AMD GPUs are officially supported!** You can now deploy MAX with acceleration on AMD MI300X and MI325X GPUs, using the same code and container that works on NVIDIA GPUs. For the first time, you can build portable, high-performance GenAI deployments that run on any platform without vendor lock-in or platform-specific optimizations. For more details, including benchmarks, see our [Modular + AMD blog post](https://www.modular.com/blog/modular-x-amd-unleashing-ai-performance-on-amd-gpus).

* **Now accepting GPU kernel contributions** Last month, we open-sourced the code for the CPU and GPU kernels that power the MAX framework, and now we're accepting contributions! For information about how to contribute and the sort of kernels most interesting to us, see the [MAX AI kernels contributing guide](https://github.com/modular/modular/blob/main/max/kernels/CONTRIBUTING.md).

* **Preview: Mojo interoperability from Python** This release includes an early version of a new Python-to-Mojo interoperability API. You can now write just the performance-critical parts of your code in Mojo and call them from Python just like you're importing another Python library. Check out our docs to [call Mojo from Python](/mojo/manual/python/mojo-from-python).

### Documentation {#25-4-docs}

We've redesigned [builds.modular.com](https://builds.modular.com) and [docs.modular.com](https://docs.modular.com) with a unified top navigation bar so that you can more easily discover all the available docs and code resources.
New docs:

* [GPU Puzzles](https://builds.modular.com/puzzles/introduction.html): Several new puzzles, including: 1D convolution op, softmax op, attention op, embedding op, kernel fusion, custom backward pass, GPU functional programming patterns, and warp fundamentals.
* [Using AI coding assistants guide](/max/coding-assistants): Learn how to use large language models (LLMs) and coding assistants (such as Cursor and Claude Code) to accelerate your development with Modular.
* [Build an MLP block as a graph module tutorial](/max/develop/build-an-mlp-block): Learn how to create reusable `Module` components in your MAX graphs.
* [Write custom ops for PyTorch tutorial](/max/develop/custom-kernels-pytorch) (Beta feature): Learn to write high-performance GPU kernels for your PyTorch models with Mojo.
* [Profile MAX kernel performance](https://github.com/modular/modular/blob/main/max/docs/kernel-profiling.md): Learn how to set up Nsight Compute to profile your Mojo-based kernels on NVIDIA GPUs.

Major updates:

* [Build custom ops for GPUs tutorial](/max/develop/build-custom-ops): Now includes how to write hardware-specific functions for CPUs and GPUs.
* [Optimize a matrix multiply custom op tutorial](/max/develop/custom-ops-matmul): Migrated from a Recipe with revisions to help you improve the performance of your GPU custom ops.

### MAX models {#25-4-models}

* Added the OLMo 2 model architecture ([`olmo2`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/olmo2)). [Try OLMo 2 now](https://builds.modular.com/models/OLMo-2-1124/7B).

* Added Google's Gemma 3 multimodal model architecture ([`gemma3multimodal`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/gemma3)). [Try Gemma3 now](https://builds.modular.com/models/gemma-3-it/1B).

* Added the Qwen 3 model architecture ([`qwen3`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/qwen3)).
  [Try Qwen3 now](https://builds.modular.com/models/Qwen3/1.7B).

* Added the InternVL3 model architecture ([`internvl`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/internvl)). This is still a work in progress.

* GGUF-quantized Llamas (q4\_0, q4\_k, and q6\_k) are now supported with the paged KVCache strategy.

### MAX framework {#25-4-max}

#### Inference server {#25-4-max-serve}

* In-flight batching no longer requires chunked prefill.
* Expanded token sampling logic, including `top_k`, `min_p`, `min_new_tokens`, and `temperature`.
* Extended sampling configuration to be per-request; for example, different requests can ask for different sampling hyperparameters.
* Removed support for TorchScript and torch MLIR models.

#### `max` CLI {#25-4-max-cli}

* Added the `--use-subgraphs` flag to `max generate` to allow the use of subgraphs in the model.
* Added the `--port` option to specify the port number with the `max serve` command.

#### Python API {#25-4-max-python}

* Lots of new APIs in the [`max.nn`](/max/api/python/nn/) package.
* Added the `max.mojo.importer` module to import Mojo code into Python. See the docs for [calling Mojo from Python](/mojo/manual/python/mojo-from-python).
* Added [`Graph.add_subgraph()`](/max/api/python/graph/Graph#max.graph.Graph.add_subgraph) to allow the addition of a subgraph to a graph.
* Added [`Module.build_subgraph()`](/max/api/python/nn/module#max.nn.module.Module.build_subgraph) to allow the creation of a subgraph for a layer that inherits from `Module`.
* Added the [`call`](/max/api/python/graph/ops#max.graph.ops.call) op, which allows the execution of a subgraph.
* Added the [`fold`](/max/api/python/graph/ops#max.graph.ops.fold) op for combining sliding blocks into a larger tensor.
* Added [`KernelLibrary`](/max/api/python/graph/KernelLibrary) as an argument type for the [`Graph`](/max/api/python/graph/Graph) constructor.
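The per-request sampling knobs mentioned under the inference server above (`top_k`, `min_p`, `temperature`) compose in a standard way. The following is a plain-Python sketch of that common composition, purely for illustration — it is not MAX's implementation, and the function name and defaults are assumptions:

```python
import math
import random

def sample_token(logits, temperature=1.0, top_k=0, min_p=0.0, seed=None):
    """Illustrative sketch of common sampling parameters (not MAX's code)."""
    # Temperature scaling: lower values sharpen the distribution.
    scaled = [l / max(temperature, 1e-6) for l in logits]
    # Numerically stable softmax.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # top_k: keep only the k most likely tokens (0 means "disabled").
    candidates = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        candidates = candidates[:top_k]
    # min_p: drop tokens whose probability is below min_p * max probability.
    p_max = probs[candidates[0]]
    candidates = [i for i in candidates if probs[i] >= min_p * p_max]
    # Sample from the surviving candidates, weighted by probability.
    rng = random.Random(seed)
    return rng.choices(candidates, weights=[probs[i] for i in candidates], k=1)[0]

# With top_k=1 sampling degenerates to greedy decoding:
print(sample_token([0.1, 2.0, 0.5], top_k=1))  # → 1
```

Because these are now per-request settings, two concurrent requests can call the equivalent of this function with different `temperature` or `top_k` values against the same batch.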
* Added [`QuantizationConfig`](/max/api/python/graph/quantization#max.graph.quantization.QuantizationConfig) to specify quantization parameters for ops such as [`qmatmul()`](/max/api/python/graph/ops#max.graph.ops.qmatmul).

* Added the `strict` argument to the [`Module.load_state_dict()`](/max/api/python/nn/module#max.nn.module.Module.load_state_dict) method. When `strict=True` (the default), an error is raised if the `state_dict` contains unused keys. When `strict=False`, extra keys are ignored. This helps model developers identify missing implementations in their models.

* Added audio generator APIs for text-to-speech models (such as [`AudioGenerator`](/max/api/python/pipelines/core#max.pipelines.core.AudioGenerator), [`PipelineAudioTokenizer`](/max/api/python/pipelines/core#max.pipelines.core.PipelineAudioTokenizer), [`TTSContext`](/max/api/python/pipelines/core#max.pipelines.core.TTSContext), and others). This is still a work in progress.

* The [`ops.masked_scatter()`](/max/api/python/graph/ops#max.graph.ops.masked_scatter) function now requires naming the `out_dim` explicitly because it is data-dependent. For example:

  ```python
  ops.masked_scatter(
      inputs_embeds, video_mask, video_embeds, out_dim="unmasked_inputs"
  )
  ```

* Deprecated the `CONTINUOUS` KVCache strategy ([`KVCacheStrategy`](/max/api/python/nn/kv_cache/cache_params/#max.nn.kv_cache.cache_params.KVCacheStrategy)). Please use the `PAGED` KVCache strategy instead.

* Removed the `Settings` argument from the [`LLM`](/max/api/python/entrypoints#max.entrypoints.llm.LLM) constructor. The server is now automatically configured in the background without consuming an HTTP port.

* Removed `Graph.unique_symbolic_dim()`.

* Removed `max_to_torch_type()` and `torch_to_max_type()` and replaced them with [`DType.to_torch()`](/max/api/python/dtype#max.dtype.DType.to_torch) and [`DType.from_torch()`](/max/api/python/dtype#max.dtype.DType.from_torch), respectively. This aligns with the corresponding NumPy methods.
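The `strict` behavior described above can be pictured with a small plain-Python model of the check. This is purely illustrative — `known_keys` stands in for a module's parameter names, and none of these names are MAX APIs:

```python
def load_state_dict(known_keys, state_dict, strict=True):
    """Illustrates the `strict` semantics: when strict, any key the module
    doesn't recognize is an error, surfacing layers you forgot to implement."""
    unused = set(state_dict) - set(known_keys)
    if strict and unused:
        raise ValueError(f"unused keys in state_dict: {sorted(unused)}")
    # Non-strict mode silently drops the extra keys.
    return {k: v for k, v in state_dict.items() if k in known_keys}

weights = {"linear.weight": [1.0], "linear.bias": [0.0], "extra.scale": [2.0]}
loaded = load_state_dict({"linear.weight", "linear.bias"}, weights, strict=False)
print(sorted(loaded))  # → ['linear.bias', 'linear.weight']
```

With `strict=True` the same call would raise, because the model defines no parameter named `extra.scale`.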
* Removed the `stats_report` property and `reset_stats_report` method from [`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession). This functionality was primarily used for internal PyTorch debugging and is no longer needed.

* Removed the naive KVCache (`nn.kv_cache.naive_cache`).

* Removed `nn.attention` and `nn.naive_attention_with_rope`.

* Renamed `ops.select` to [`ops.where`](/max/api/python/graph/ops#max.graph.ops.where). This matches the name of the similar operation in torch and numpy.

#### Mojo API {#25-4-max-mojo}

* [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor/) now has a `size` method to get the total number of elements.

* Following our [previous deprecation](#25-3-engine-mojo-api) of the Mojo `max.driver`, `max.graph`, and `max.engine` APIs, we've removed them from the package and API docs. As a result, we've also removed the Mojo `max.tensor` APIs (including `Tensor`, `TensorShape`, and `TensorSpec`). You can replace any use with [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor/).

#### Custom ops {#25-4-custom-ops}

* Improved error messages when custom op parameters are provided with values that don't have the proper type.

* The [`ops.custom()`](/max/api/python/graph/ops#max.graph.ops.custom) function now requires a `device` argument to specify where the operation should execute. This avoids the need for custom ops to infer their execution device, which can be error-prone.

* Added the [`max.torch`](/max/api/python/torch) module with the `CustomOpLibrary` class for using custom Mojo kernels from PyTorch. For example, with a custom `grayscale` operation written in Mojo:

  ```mojo
  @register("grayscale")
  struct Grayscale:
      @staticmethod
      fn execute[
          # The kind of device this is running on: "cpu" or "gpu"
          target: StaticString,
      ](
          img_out: OutputTensor[dtype = DType.uint8, rank=2],
          img_in: InputTensor[dtype = DType.uint8, rank=3],
          ctx: DeviceContextPtr,
      ) raises:
          ...
  ```

  You can load it with PyTorch like so:

  ```python
  from max.torch import CustomOpLibrary

  op_library = CustomOpLibrary("path/to/custom.mojopkg")

  @torch.compile(backend=backend)
  def grayscale(pic):
      result = pic.new_empty(pic.shape[:-1])
      op_library.grayscale(result, pic)
      return result

  img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
  result = grayscale(img)
  ```

  See our [tutorial to write custom ops for PyTorch](/max/develop/custom-kernels-pytorch), and our [PyTorch custom operation examples](https://github.com/modular/modular/tree/main/max/examples/pytorch_custom_ops), which range from a very basic "hello world" to the replacement of a layer in a full model.

#### GPU programming {#25-4-gpu-programming}

* Full support for AMD CDNA3 datacenter GPUs is now available! Specifically, MI300X and MI325X.

* Added initial support for programming on AMD RDNA3 consumer GPUs. Basic tuning parameters have been specified for AMD Radeon 780M integrated GPUs. (AMD RDNA3 support is for GPU programming only; AI models are still missing some GPU kernels for this architecture.) For details, see the [GPU requirements](/max/packages#gpu-compatibility).

* Now accepting CPU and GPU kernel contributions. See the [MAX AI kernels contributing guide](https://github.com/modular/modular/blob/main/max/kernels/CONTRIBUTING.md).

### Mojo language {#25-4-mojo}

For all the updates to the Mojo language, standard library, and tools, see the [Mojo changelog](/mojo/changelog).
## v25.3 (2025-05-06)

* [Highlights](#25-3-highlights)
* [Documentation](#25-3-docs)
* [`max` CLI](#25-3-max-cli)
* [MAX models](#25-3-models)
* [MAX Serve](#25-3-serve)
* [MAX Engine & Graph](#25-3-engine)
  * [Python API](#25-3-engine-python-api)
  * [Mojo API](#25-3-engine-mojo-api)
  * [Custom ops](#25-3-custom-ops)
  * [Kernels](#25-3-kernels)
* [GPU programming](#25-3-gpu-programming)
* [Mojo language](#25-3-mojo)

### ✨ Highlights {#25-3-highlights}

* You can now **install Modular APIs and tools with pip**:

  ```sh
  pip install modular \
    --index-url https://download.pytorch.org/whl/cpu
  ```

  This installs the `max` CLI, `max` Python library, `mojo` CLI, and Mojo libraries. However, the Mojo LSP and debugger are currently not included.

  We use the `--index-url` argument to ensure that `torch` installs its CPU dependencies only, thus avoiding a lot of unnecessary GPU packages. This is a temporary workaround until we can remove our dependency on `torch`.

* We **open-sourced the MAX AI kernels** and the rest of the **Mojo standard library**! The [MAX AI kernels library](/mojo/lib#max-ai-kernels-library) is a new Mojo API for writing high-performance and portable programs across CPU and GPU, but it's also [the source code for our CPU/GPU kernels](https://github.com/modular/modular/tree/main/max/kernels/src). You can now see the Mojo code we use in MAX to power GenAI workloads on CPUs and GPUs. Just like the Mojo standard library, these kernels are open source under the Apache 2.0 License with LLVM exceptions. Plus, the rest of the Mojo standard library is also [now open source on GitHub](https://github.com/modular/modular/tree/main/mojo/std/src).

* **Learn to program GPUs** with [Mojo GPU Puzzles](https://builds.modular.com/puzzles)! This is a brand-new site that offers a hands-on guide to mastering GPU programming with Mojo. Starting from basic concepts, you'll learn step-by-step how to program for GPUs by solving increasingly challenging puzzles.
### Documentation {#25-3-docs}

We've restructured the documentation to unify MAX and Mojo documentation under the Modular Platform. We believe this improves content discovery with simplified navigation and helps unify the platform story as a whole.

We've also added the following new docs:

* [REST API reference](/max/api/serve): Although it's not a new API (our serving library has supported OpenAI APIs for the last few versions), this now shows precisely which endpoints and body parameters we support.
* [Speculative decoding](/max/serve/speculative-decoding): An introduction to using speculative decoding to reduce latency for LLMs. This feature is still in development.
* [Offline inference](/max/serve/offline-inference): An introduction to our Python API for running inference with an LLM locally (without sending requests to a serving endpoint).
* [Introduction to layouts](/mojo/manual/layout/layouts): A guide to working with dense multidimensional arrays on CPUs and GPUs, using new Mojo `layout` types that abstract away complex memory layout patterns.

### `max` CLI {#25-3-max-cli}

* Renamed the `max-pipelines` CLI tool to `max`. We recommend re-installing it as shown in the [`max` CLI docs](/max/cli/).
* Removed the previously deprecated `--use-gpu`, `--serialized_model_path`, `--save_to_serialized_model_path`, `--max_cache_batch_size`, and `--huggingface-repo-id` options.
* Moved `InputContext`, `TextContext`, and `TextAndVisionContext` from `max.pipelines` to `max.pipelines.context`.

### MAX models {#25-3-models}

* Added `Llama4ForConditionalGeneration` support, featuring new MoE layers. Currently, it is limited to text inputs. Run the model by calling:

  ```sh
  max generate --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --devices 0,1,2,3
  ```

* Added support for running text generation using the Mistral 3 24B model.
  Run the model with:

  ```sh
  max generate --model-path mistralai/Mistral-Small-3.1-24B-Instruct-2503 --devices 0
  ```

* Fixed empty textual outputs for certain Mistral models ([MAX issue 4193](https://github.com/modular/modular/issues/4193)).

* Added support for loading a custom pipeline architecture by module. Passing `--custom-architectures=folder/path/to/import:my_module` loads architectures from that file. The architectures must be exposed via an `ARCHITECTURES` variable in the file. Once loaded, a model can be run using the new architectures. The flag can be specified multiple times to load more modules.

### MAX Serve {#25-3-serve}

* Moved from a radix-trie to a hash-based prefix caching implementation, which has lower CPU overhead. This improves performance, particularly in workloads with high cache-reuse rates.

* Added experimental support for offloading the KVCache to host memory via the `--enable-kvcache-swapping-to-host` and `--host-kvcache-swap-space-gb` flags. This allows for superior KVCache reuse through prefix caching in workloads where the reusable KVCache amount exceeds GPU VRAM.

* Fixed the `usage.prompt_tokens` field in the OpenAI API usage info response. Previously this field was always set to null, but now it correctly contains the number of prompt tokens in the request.

* Switched from a Python multiprocessing queue to ZeroMQ. This reduces networking-related latency between the frontend server process and the model worker process.

* Stray model workers on Linux now terminate more reliably when the parent process is killed.

### MAX Engine & Graph {#25-3-engine}

#### Python API {#25-3-engine-python-api}

* We now raise an error if there's a mismatch between the expected device of a weight on a graph and the device of the actual tensor data specified in [`InferenceSession.load()`](/max/api/python/engine#max.engine.InferenceSession.load).

* Removed the `output_device` argument from [`Model.execute()`](/max/api/python/engine#max.engine.Model.execute).
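The hash-based prefix caching mentioned above replaces a trie walk with flat dictionary probes: each fixed-size block of tokens is keyed by the hash of its own tokens chained with its parent block's hash, so equal blocks at different prefixes get distinct keys. Here is a toy illustration of the idea — the block size, class, and structure are illustrative assumptions, not MAX's actual implementation:

```python
BLOCK_SIZE = 4  # tokens per KV block (illustrative; real block sizes are larger)

class PrefixCache:
    """Toy hash-based prefix cache: maps chained block hashes to block IDs."""

    def __init__(self):
        self.blocks = {}        # chained hash -> cached block id
        self.next_block_id = 0

    def match_and_insert(self, tokens):
        """Return (cached_block_ids, new_block_ids) for a token sequence."""
        cached, new = [], []
        parent_hash = 0
        # Only full blocks are cacheable; the tail remainder is skipped.
        for i in range(0, len(tokens) - len(tokens) % BLOCK_SIZE, BLOCK_SIZE):
            block = tuple(tokens[i:i + BLOCK_SIZE])
            # Chain the parent hash so a block's key depends on its prefix.
            h = hash((parent_hash, block))
            if h in self.blocks:
                cached.append(self.blocks[h])
            else:
                self.blocks[h] = self.next_block_id
                new.append(self.next_block_id)
                self.next_block_id += 1
            parent_hash = h
        return cached, new

cache = PrefixCache()
cache.match_and_insert([1, 2, 3, 4, 5, 6, 7, 8])   # all blocks are new
hit, miss = cache.match_and_insert([1, 2, 3, 4, 9, 9, 9, 9])
print(len(hit), len(miss))  # → 1 1  (the shared 4-token prefix is reused)
```

Each lookup is O(1) per block regardless of how many sequences are cached, which is where the lower CPU overhead relative to a radix trie comes from.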
* Removed the `copy_inputs_to_device` argument in [`Model.execute`](/max/api/python/engine#max.engine.Model.execute) to improve the predictability of the API. Now `execute()` raises a `TypeError` if arguments are passed whose devices don't match the model.

* Swapped the order of the `dtype` and `shape` fields of [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor). Previously, the arguments were ordered as `(shape, dtype)`. They are now swapped to `(dtype, shape)` to be in line with other tensor-like types.

* Replaced some instances of [`Tensor.zeros`](/max/api/python/driver#max.driver.Tensor.zeros) with `Tensor.__init__` when the engine did not depend on the tensor being zero-initialized. This elides the unnecessary memset to provide a minor performance improvement.

* Added a new experimental [`Tensor.inplace_copy_from()`](/max/api/python/driver#max.driver.Tensor.inplace_copy_from). This allows users to copy the contents of one `Tensor` into another.

* Changed the default behavior of [`Weight`](/max/api/python/graph/Weight) to expect the initial allocation on the host. A transfer to the target device is then inserted, and this value is returned when weights generate an MLIR value. This is done due to current conservative ownership around external weights.

* Added the [`irfft`](/max/api/python/graph/ops/#max.graph.ops.irfft) op, which computes the inverse real fast Fourier transform (FFT).

* Added the [`argmax`](/max/api/python/graph/ops#max.graph.ops.argmax) op, which returns the index of the maximum value in an array or sequence.

* Added the [`GroupNorm`](/max/api/python/nn/norm/group_norm) layer.

* Switched layer names so that `max.nn` layers that are implemented with the deprecated `Layer` class are marked as "V1", and layers that are implemented with the new [`max.nn.Module`](/max/api/python/nn/module#max.nn.module.Module) are the default.
  That is, `max.nn.LinearV2` is now [`max.nn.Linear`](/max/api/python/nn/Linear), and the previous `max.nn.Linear` is now `max.nn.LinearV1`.

* `DeviceRef`s in types and layers are now generally expected to be specified explicitly rather than inferred implicitly.

#### Mojo API {#25-3-engine-mojo-api}

* Removed some functionality from [`tensor.Tensor`](/mojo/kernels/extensibility/tensor/tensor/Tensor):
  * Serializing `Tensor` to disk (`Tensor.tofile(path)` and `Tensor.save(path)`).
  * Reading the serialized data back from disk (`Tensor.load(path)` and `Tensor.fromfile(path)`).
  * The `rand` and `randn` methods have been removed. Use the ones in the Mojo standard library if you still need to construct a new `Tensor` with random elements based on a particular `TensorShape`.

* **Deprecated the Mojo Driver, Graph, and Engine APIs** These APIs are not currently used internally. Instead, we build graphs using the Python APIs, and our engineering efforts have been focused on making that experience as robust and user-friendly as possible. As a result, the Mojo versions of these APIs have not kept pace with new features and language improvements. These APIs will be open sourced for the community before being removed.

#### Custom ops API {#25-3-custom-ops}

* You can now pass Mojo source package paths as [`Graph`](/max/api/python/graph/Graph) custom extensions. The Mojo code will be compiled automatically; there's no need to run `mojo package` manually as a prior step. Previously, only pre-compiled `.mojopkg` paths were accepted, requiring the Mojo code to be built as a prerequisite step before running a `Graph` with a custom op.
  Given a project structure like:

  ```text
  project
  |-- main.py
  \-- kernels
      |-- __init__.mojo
      \-- my_custom_op.mojo
  ```

  you can construct a `Graph` in `main.py` using Mojo custom op kernels simply by using:

  ```python
  g = Graph(
      ...,
      custom_extensions=[Path(__file__).parent / "kernels"],
  )
  ```

  A change to your Mojo source code defining a custom op will be reflected immediately the next time the `Graph` is constructed.

* New [image\_pipeline example](https://github.com/modular/modular/tree/main/max/examples/custom_ops) that demonstrates sequencing custom ops that modify an image, leaving the data on the GPU for each op before writing it back to CPU and disk.

### Kernels {#25-3-kernels}

* More compute overlap is now enabled for Hopper GPUs. This allows finer-grained scheduling of kernel operations by analyzing producer-consumer patterns within a compute kernel. As a result, there is more kernel compute overlap, especially for compute-heavy kernels with data-dependent execution paths.

### GPU programming {#25-3-gpu-programming}

* The CUDA driver requirement has been reduced to version 12.4 and the NVIDIA driver requirement to version 550. Requiring these earlier driver versions allows MAX to be more easily deployed on AWS and GCP, since these are the default versions used by those cloud providers.

* Added support for programming NVIDIA Jetson Orin GPUs (`sm_87`).

Also see the [Mojo changelog of GPU changes](/mojo/changelog#gpu-changes).

### Mojo language {#25-3-mojo}

* We recently open-sourced the rest of the Mojo standard library, including the `algorithm`, `benchmark`, `buffer`, `compile`, `complex`, `gpu`, and `layout` packages. [See it all in GitHub](https://github.com/modular/modular/tree/main/mojo/std/src).

* We've also open sourced [all our MAX AI kernels](https://github.com/modular/modular/tree/main/max/kernels/src). This new library includes `kv_cache`, `layout`, `linalg`, `nn`, `nvml`, and `quantization`.
For all the updates to the Mojo language, standard library, and tools, see the [Mojo changelog](/mojo/changelog).

## v25.2 (2025-03-25)

* [Highlights](#25-2-highlights)
* [MAX Serve](#25-2-serve)
* [MAX models](#25-2-models)
* [`max-pipelines` CLI](#25-2-pipelines-cli)
* [MAX Engine](#25-2-engine)
  * [Driver APIs](#25-2-driver)
  * [Graph APIs](#25-2-graph)
  * [Custom ops](#25-2-custom-ops)
  * [Hopper Kernels](#25-2-hopper-kernels)
* [GPU programming](#25-2-gpu-programming)
* [Mojo](#25-2-mojo)
* [Documentation](#25-2-documentation)

### ✨ Highlights {#25-2-highlights}

* **Support for NVIDIA Hopper GPUs** MAX has been optimized to run on Hopper GPUs. For more information on MAX and NVIDIA's hardware, see the [MAX container](/max/container#recommended-cloud-instances) documentation.

* **Multi-GPU support** MAX uses tensor parallelism to distribute work across multiple GPUs so you can run LLMs like [`Llama-3.3-70B-Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct), even with long context windows.

* **Expanded library of MAX models** We're rapidly growing our library of base model architectures that MAX can accelerate with MAX Serve (including `Phi3ForCausalLM`, `OlmoForCausalLM`, and `GraniteForCausalLM`). We also now support GPTQ for the Llama models. For more information, check out our [MAX model repository](https://builds.modular.com/?category=models).

* **Advanced E2E optimizations for long context windows** In-flight batching, chunked prefill, and copy-on-write optimize execution for prefix-heavy and long-context-window scenarios.

* **GPU programming with Mojo** Lots of new APIs are now available to enable both low-level GPU programming and abstracted programming patterns that simplify the code required to write GPU kernels for your AI models.

### MAX Serve {#25-2-serve}

* Extended MAX Serve batch scheduling to account for the prefix cache.
  The scheduler can now create larger batches when many prompt tokens are already cached, improving throughput by up to 10% in some benchmarks.

* Added support for in-flight batching, allowing token generation requests to be scheduled alongside context encoding requests to reduce inter-token latency. This behavior can be controlled with the `--enable-in-flight-batch` CLI argument.

* Added support for copy-on-write on KV blocks when using PagedAttention with prefix caching. This improves the prefix cache hit rate and prefill performance in some scenarios.

* MAX Serve now supports `transformers` v4.49.0, with a patch to avoid graph breaks when using `torch.compile()` on Llama models.

* Added support for recording HTTP traffic out to a file for diagnostics or later replay.

### MAX models {#25-2-models}

* Added support for executing `LlamaForCausalLM` architecture models on multiple GPUs. The model uses tensor parallelism automatically when passing multiple device IDs to the `--devices` CLI argument. Try running [`meta-llama/Llama-3.3-70B-Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) on 4 GPUs with the following example:

  ```sh
  max-pipelines generate --model-path=meta-llama/Llama-3.3-70B-Instruct \
    --quantization-encoding bfloat16 \
    --devices gpu:0,1,2,3 \
    --prompt="Design a self-sustaining colony on Neptune's moon Triton with a myth/science fusion name, three quantum tech breakthroughs, one ethical debate, a neon-lit cultural ritual, and a hidden flaw—presented in bullet points."
  ```

* Added support for the `Phi3ForCausalLM` model architecture (such as [`microsoft/phi-4`](https://huggingface.co/microsoft/phi-4)). For example:

  ```sh
  max-pipelines generate \
    --model-path microsoft/phi-4 \
    --prompt "Write bubble sort in mojo"
  ```

* Added support for the `OlmoForCausalLM` model architecture (such as [`allenai/OLMo-1B-0724-hf`](https://huggingface.co/allenai/OLMo-1B-0724-hf)).
  For example:

  ```sh
  max-pipelines generate \
    --model-path allenai/OLMo-1B-0724-hf \
    --prompt "Write bubble sort in mojo"
  ```

* Added support for the `GraniteForCausalLM` model architecture (such as [`ibm-granite/granite-3.1-8b-instruct`](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)). For example:

  ```sh
  max-pipelines generate \
    --model-path ibm-granite/granite-3.1-8b-instruct \
    --prompt "Write bubble sort in mojo"
  ```

* Added support for:
  * [`microsoft/Phi-3.5-mini-instruct`](https://huggingface.co/microsoft/Phi-3.5-mini-instruct)
  * [`microsoft/phi-4`](https://huggingface.co/microsoft/phi-4)
  * [`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
  * [`LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct`](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct)

* We now support GPTQ quantization for models that run on the GPU. This is handled transparently when the model weights are specified. For example, this runs Llama 3.1 8B using int4-quantized GPTQ weights:

  ```sh
  max-pipelines generate \
    --model-path hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \
    --prompt "Why is the sky blue?" \
    --max-batch-size 1 \
    --max-length 10000
  ```

  This reduces the total memory consumption of this model from \~16 GB to \~5 GB, allowing the model to fit in the RAM of smaller GPUs.

* Model weights are now downloaded in parallel.

* Added constraints on whitespace during [structured output](/max/serve/structured-output). This reduces token counts and improves model adherence.

* Added jump-ahead decoding during structured output. This auto-completes tokens when a singular path forward is identified, improving single completion times by up to \~20% for long prompts.

* In the event of an unhandled exception, we now use the standard Python traceback format instead of pretty-printed Rich tracebacks.
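The ~16 GB to ~5 GB reduction follows from bits-per-weight arithmetic. A rough back-of-the-envelope sketch (ignoring activations and the KVCache; the extra gigabyte or so on top of the int4 weights covers quantization scales and any layers kept at higher precision):

```python
params = 8e9  # Llama 3.1 8B has roughly 8 billion parameters

bf16_gb = params * 2 / 1e9    # bfloat16: 2 bytes per weight
int4_gb = params * 0.5 / 1e9  # int4 GPTQ: 4 bits = 0.5 bytes per weight

print(f"bf16: ~{bf16_gb:.0f} GB, int4: ~{int4_gb:.0f} GB (+ ~1 GB overhead)")
# → bf16: ~16 GB, int4: ~4 GB (+ ~1 GB overhead)
```

The same arithmetic explains why int4 weights fit on GPUs with far less VRAM than the bfloat16 checkpoint requires.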
* You now need to explicitly import `LLM` from [`max.entrypoints.llm`](/max/api/python/entrypoints) rather than the previous `max.entrypoints` import.

* The `max.pipelines.dataprocessing.tokenizer` and `max.pipelines.dataprocessing.gguf_utils` modules have been removed.

* The previously deprecated `PipelineConfig.architecture` field and its corresponding `--architecture` CLI argument have been removed.

### `max-pipelines` CLI {#25-2-pipelines-cli}

* The `--devices` CLI argument now supports a comma-separated list of GPU IDs prefixed with `gpu:`, like `--devices=gpu:0,1,2,3`. We no longer support the previous `--devices=gpu-` format.

  ```sh
  max-pipelines generate --model-path=meta-llama/Llama-3.3-70B-Instruct \
    --quantization-encoding bfloat16 \
    --devices gpu:0,1,2,3 \
    --prompt="Design a self-sustaining colony on Neptune's moon Triton with a myth/science fusion name, three quantum tech breakthroughs, one ethical debate, a neon-lit cultural ritual, and a hidden flaw—presented in bullet points."
  ```

* Removed the `--huggingface-repo-id` [PipelineConfig](/max/api/python/pipelines/config/#max.pipelines.config.PipelineConfig) option and CLI argument in favor of `--model-path`.

* We consolidated `--model-path` and `--weight-path`. Valid `--weight-path` values now override `--model-path`, which handles both local and remote (Hugging Face) cases. If we cannot derive the weights from the `--weight-path`, we fall back to the `--model-path`, which you must set explicitly.

* Added the `--huggingface-revision` option, to allow selecting a non-default branch or a specific commit in a Hugging Face model repository.

### MAX Engine {#25-2-engine}

* The MAX graph compiler now has kernel caching. This is a significant improvement to our compilation pipeline.
  Here are some of the highlights:

  * Up to 28% faster compilation times when making iterative changes to models
  * Improved caching between different but similar models (up to 27% faster)
  * Lays the foundation for future caching optimizations

  What does this mean for you? Faster development cycles! When you're working on model pipelines and making changes to the graph, the graph compiler will now intelligently reuse kernels that haven't changed, significantly reducing compilation times. The improvements are particularly noticeable during iterative development, with compilation times dropping from \~80s to \~57s in some cases when compiling Llama3.1-8B for 4 GPUs. Even when compiling different models from the same family (like Llama/Granite variants), you'll see significant speedups on subsequent compilations.

### Driver APIs {#25-2-driver}

* Added the `Accelerator.can_access(other: Device) -> bool` method to check if one device can directly access memory of another device.

* Fixed a bug in `max.driver.tensor.load_max_tensor()` for the `bfloat16` dtype, which would cause an error about the mmap size being too large.

* `max.driver.Tensor.item()` now works on any single-element tensor (previously it was restricted to rank-0 tensors).

* Added [`Device.synchronize()`](/max/api/python/driver#max.driver.Device.synchronize), which ensures all operations on the device complete before returning.

* Removed `MojoCallContextPtr` in favor of `DeviceContextPtr`. `MojoCallContextPtr` only contained a `DeviceContextPtr`, so this change directly exposes the `DeviceContextPtr`.
  Custom ops using `MojoCallContextPtr` now directly take a `DeviceContextPtr`
  argument:

  ```mojo
  @staticmethod
  fn execute[
      type: DType, rank: Int
  ](
      output: OutputTensor[type=type, rank=rank],
      input: InputTensor[type=type, rank=rank],
      ctx: MojoCallContextPtr,
  ):
  ```

  becomes

  ```mojo
  @staticmethod
  fn execute[
      type: DType, rank: Int
  ](
      output: OutputTensor[type=type, rank=rank],
      input: InputTensor[type=type, rank=rank],
      ctx: DeviceContextPtr,
  ):
  ```

* You can now skip compiling a GPU kernel first before enqueueing it, and pass
  a function directly to `ctx.enqueue_function[func](...)`:

  ```mojo
  fn func():
      print("Hello from GPU")

  @register("custom_op")
  struct CustomOp:
      @staticmethod
      fn execute(ctx: DeviceContextPtr) raises:
          var dev_ctx = ctx.get_device_context()
          dev_ctx.enqueue_function[func](grid_dim=1, block_dim=1)
  ```

  However, if you're reusing the same function and parameters multiple times,
  this incurs some overhead of around 50-500 nanoseconds per enqueue. So you
  can still compile the function first and pass it to `ctx.enqueue_function`
  in this scenario:

  ```mojo
  var compiled_func = ctx.compile_function[func]()
  # Multiple kernel launches with the same function/parameters
  ctx.enqueue_function(compiled_func, grid_dim=1, block_dim=1)
  ctx.enqueue_function(compiled_func, grid_dim=1, block_dim=1)
  ```

* Changed `Accelerator` and `CPU` from factory methods that created `Device`
  objects in Python (which were accelerators and CPUs in the C++
  implementation) to actual Python types. This change elevates the
  `Accelerator` and `CPU` type concepts to Python, making them types rather
  than methods, which allows type annotations in Python. For example, a list
  of accelerators used to be defined like this:

  ```python
  graph_devices: list[DeviceRef]
  ```

  Now it can be defined like this:

  ```python
  graph_devices: list[Accelerator]
  ```

* Elementwise operations (e.g. `__add__`) have been removed from `Tensor`
  (that is, `tensor_internal.Tensor`).
  This `Tensor` type is being phased out; please reduce usage in favor of
  `LayoutTensor`.

### Graph APIs {#25-2-graph}

* The `nn` package is now [`max.nn`](/max/api/python/nn/).
* Added [`ops.chunk`](/max/api/python/graph#max.graph.ops.chunk) to support
  chunking tensors along an axis.
* Added support for while loops with
  [`ops.while_loop`](/max/api/python/graph#max.graph.ops.while_loop).
* Added support for conditional execution with
  [`ops.cond`](/max/api/python/graph#max.graph.ops.cond).
* Added axis reduction overloads for
  [`ops.min`](/max/api/python/graph/ops#max.graph.ops.min) and
  [`ops.max`](/max/api/python/graph/ops#max.graph.ops.max). For example,
  `ops.min(tensor, axis=-1)`.
* The [`gelu()`](/max/api/python/graph/ops#max.graph.ops.gelu) function now
  accepts an `approximate` keyword, which controls the `gelu` approximation;
  the `none`, `tanh`, and `fast` approximations are accepted.
* Removed the `roundeven()` operation from the Python API. The
  [`round()`](/max/api/python/graph/ops#max.graph.ops.round) operation now has
  the same behavior as `roundeven()`, so there is no need for both to exist.
* Added helpers to create analogous tensors from buffer types and vice versa.
* Added `max.nn.Module`, a base class for writing layers and constructing
  networks of layers (e.g. using `max.nn.Sequential`). Currently, this class
  supports graph building by ensuring that all weight names are unique and
  systematically generated. This class also supports managing the weight
  values with the `module.state_dict()` and `module.load_state_dict()`
  methods. More functionality and documentation will be added in future
  releases.

### Custom ops {#25-2-custom-ops}

* Changes have been made to the way that custom ops are registered: rather
  than using the `num_dps_outputs` attribute on `@compiler.register` to
  specify the number of outputs, that number is now inferred from the
  signature of the custom operation.
  Inputs to the operation now use the `InputTensor` type and outputs use
  `OutputTensor`, instead of the previous `ManagedTensorSlice` for both. This
  eliminates the need for a manual `num_dps_outputs` attribute, and makes it
  safer to work with these inputs and outputs by preventing accidental writes
  to input tensors. The new interface looks something like the following:

  ```mojo
  @compiler.register("add_one_custom")
  struct AddOneCustom:
      @staticmethod
      fn execute[
          target: StringLiteral,
      ](
          out: OutputTensor,
          x: InputTensor[type = out.type, rank = out.rank],
          ctx: DeviceContextPtr,
      ) raises:
          @parameter
          @always_inline
          fn elementwise_add_one[
              width: Int
          ](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
              return x.load[width](idx) + 1

          foreach[elementwise_add_one, target=target](out, ctx)
  ```

* The `foreach` function now `raises` so it can handle errors within an
  elementwise calculation.

### Hopper kernels {#25-2-hopper-kernels}

State-of-the-art kernels in Mojo for H100/H200 GPUs:

* **Hopper Architecture Matrix Multiplication Kernels**: The implementation
  achieved performance comparable to NVIDIA's highly optimized cuBLAS library.
  These kernels take full advantage of the Tensor Cores in Hopper architecture
  GPUs to accelerate the fundamental matrix multiplication operations that
  underpin deep learning workloads.
* **Multi-GPU AllReduce Implementation**: The AllReduce operation is critical
  for distributed inference across multiple GPUs, as it efficiently aggregates
  partial results across devices. The Mojo implementation surpassed NVIDIA's
  NCCL library in performance benchmarks. This improvement reduces
  communication overhead during distributed inference.
* **MAX Attention Kernel with Flash Attention 3**: This implementation
  incorporates and extends the latest Flash Attention 3 algorithm, which
  significantly accelerates the computation of attention mechanisms in
  transformer models.
  The MAX attention kernel optimizes memory access patterns and computational
  steps, reducing both the memory footprint and execution time of attention
  operations. This is particularly important for LLMs, where attention
  calculations represent a substantial portion of the computational workload.

### GPU programming {#25-2-gpu-programming}

* Added the Mojo `max.driver` API to enable dispatching GPU functions from
  Mojo. Check out the
  [examples for GPU programming in Mojo](https://github.com/modular/modular/tree/main/mojo/examples/gpu-functions),
  which use this new API.

### Mojo {#25-2-mojo}

Mojo is a crucial component of the MAX stack that enables all of MAX's
performance-oriented code across hardware. For all the updates to the Mojo
language, standard library, and tools, see the [Mojo changelog](/mojo/changelog).

### Documentation {#25-2-documentation}

New examples for writing custom ops:

* [`fused_attention`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo)
  demonstrates complex GPU programming using MAX abstractions for a practical
  use in AI model development.
* [`matrix_multiplication`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/matrix_multiplication.mojo)
  includes a series of progressive optimizations for matrix multiplications on
  GPUs.
* [`histogram`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/histogram.mojo)
  shows how to implement the histogram pattern as a custom op.
* New [examples for GPU programming in Mojo](https://github.com/modular/modular/tree/main/mojo/examples/gpu-functions)
  using the new MAX Driver API. These use a Mojo programming model that should
  look familiar to CUDA C programmers, showing how to define and dispatch GPU
  functions within a single Mojo file.
  These examples recreate the first three samples from the popular textbook
  ["Programming Massively Parallel Processors"](https://www.amazon.com/Programming-Massively-Parallel-Processors-Hands/dp/0323912311),
  showing how basic concepts translate from CUDA into Mojo. Additionally,
  there's a Mandelbrot set calculation example that parallels a similar one in
  the existing custom ops examples.

* New [MAX containers](/max/container/) are available. For more information on
  the base and full MAX containers, see
  [Container contents](/max/container/#container-contents).

## v25.1.1 (2025-02-19)

Fixed performance issues in autoregressive models with paged attention by
setting sensible, platform-specific default values for `--max-num-steps`.

## v25.1 (2025-02-13)

* [Highlights](#25-1-highlights)
* [Documentation](#25-1-docs)
* [MAX Serve](#25-1-serve)
* [MAX models](#25-1-max-models)
* [MAX Engine](#25-1-engine)
* [Graph APIs](#25-1-graph)
* [Pipeline APIs](#25-1-pipelines)
* [GPU programming](#25-1-gpus)
* [Mojo](#25-1-mojo)

### ✨ Highlights {#25-1-highlights}

* **Custom ops for GPUs**

  Our new custom op API allows you to extend MAX Engine with new graph
  operations written in Mojo that execute on either CPU or GPU, providing full
  composability and extensibility for your models. See more in the section
  about [GPU programming](#25-1-gpus).

* **Enhanced support for agentic workflows**

  MAX Serve now supports function calling, which allows you to instruct your
  model to interact with other systems, such as retrieving data and executing
  external tasks.
  [Learn more about function calling and tool use](/max/serve/function-calling).

  MAX Serve now supports structured output (also known as constrained
  decoding) for MAX models on GPU. This allows you to enforce the output
  format from a model using an input schema that defines the output structure.
  [Learn more about structured output](/max/serve/structured-output).
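To make the request shape concrete, here's a minimal sketch of a structured-output request body built with only Python's standard library. The model name and schema are illustrative, and the field layout follows the OpenAI API convention that the serving endpoint mirrors; see the structured-output docs for the authoritative format.

```python
import json

# Illustrative OpenAI-style request body for structured output.
# The schema name "city_info" and its fields are made up for this sketch.
request_body = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "messages": [{"role": "user", "content": "Name a city and its country."}],
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "city_info",
            "schema": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "country": {"type": "string"},
                },
                "required": ["city", "country"],
            },
        },
    },
}

# Serialize for an HTTP POST to the server's OpenAI-compatible endpoint.
payload = json.dumps(request_body)
```

With the schema attached, the server constrains decoding so the completion is guaranteed to parse against `city_info` (when the server is started with structured output enabled).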
* **Extended model architecture support**

  * MAX Serve now supports multimodal models that take both text and image
    inputs. For example, see
    [how to deploy Llama 3.2 Vision](/max/tutorials/deploy-llama-vision).
  * MAX Serve now supports text embedding models. Learn how to
    [deploy a text embedding model](/max/tutorials/run-embeddings-with-max-serve).

* **New `max-pipelines` CLI tool**

  Instead of cloning our GitHub repo to access our latest GenAI models, you
  can install the `max-pipelines` CLI tool and quickly run an inference or
  deploy an endpoint.

### Documentation {#25-1-docs}

New tutorials:

* [Build custom ops for GPUs](/max/develop/build-custom-ops)
* [Serverless GPU inference on Google Cloud Run](/max/tutorials/deploy-serverless-cloud-run)
* [Generate image descriptions with Llama 3.2 Vision](/max/tutorials/deploy-llama-vision)
* [Deploy a text embedding model](/max/tutorials/run-embeddings-with-max-serve)

Other docs:

* [Function calling and tool use](/max/serve/function-calling)
* [Structured output](/max/serve/structured-output)
* [Prefix caching with PagedAttention](/max/serve/prefix-caching)
* `max-pipelines` CLI

### MAX Serve {#25-1-serve}

* The `/v1/completions` REST endpoint now supports:

  * Pre-tokenized prompts.
  * Image inputs for multimodal models such as `Llama-3.2-11B-Vision-Instruct`.
    For an example, see
    [how to generate image descriptions with Llama 3.2 Vision](/max/tutorials/deploy-llama-vision).

    **Known issue:** You might receive faulty results because some parts of
    the text prompt get ignored for certain input combinations. We've
    identified the problem and will have a fix in a subsequent nightly
    release.

* Function calling and tool use, which allows you to instruct your model to
  interact with other systems, such as retrieving data and executing external
  tasks.
  [Learn more about function calling and tool use](/max/serve/function-calling).
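A function-calling request attaches tool definitions to the request body. Below is a hedged sketch using only the standard library: the `get_weather` tool and its parameters are hypothetical, and the field layout follows the OpenAI tools convention that the endpoint mirrors rather than anything MAX-specific.

```python
import json

# Hypothetical tool definition in the OpenAI "tools" format.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# The model can respond with a tool call (name + JSON arguments) instead of
# plain text, which your application then executes and feeds back.
request_body = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "messages": [{"role": "user", "content": "What's the weather in Oslo?"}],
    "tools": tools,
}
payload = json.dumps(request_body)
```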
* Structured output (also known as constrained decoding), which allows you to
  enforce the output format from a model using a JSON schema and the
  `response_format` field. To enable constrained decoding, pass
  `--enable-structured-output` when running the server. However, this feature
  currently works for MAX models on GPU only (support for PyTorch models and
  CPU is in progress).
  [Learn more about structured output](/max/serve/structured-output).
* Added support for the `/v1/embeddings` API endpoint, allowing you to
  generate vector representations using embedding models. See how to
  [deploy a text embedding model](/max/tutorials/run-embeddings-with-max-serve).
* MAX Serve can evict requests when the number of available pages in the
  PagedAttention KV cache is limited. Previously, the KV manager would throw
  an OOM error when a batch that couldn't fit in the cache was scheduled.

### MAX models {#25-1-max-models}

* Added the `max-pipelines` CLI tool, which simplifies the process of running
  inference with GenAI models (specified with a Hugging Face repo ID) and
  deploying them to a local endpoint with MAX Serve.

  Previously, running or serving these models required cloning the
  [modular/max](https://github.com/modular/max) GitHub repo and then running
  commands such as `magic run llama3`. Model-specific commands like `llama3`
  and `replit` have been removed; they're now standardized and subsumed by
  flags like `--model-path` in the `max-pipelines` tool. Arguments such as
  `--max-length` and `--weight-path` are also still supported by
  `max-pipelines`. To view a list of supported model architectures from
  Hugging Face, run `max-pipelines list`.
* Added support for PagedAttention, which improves memory efficiency by
  partitioning the KV cache into smaller blocks, reducing fragmentation and
  enabling larger inference batches. You can enable it with
  `--cache-strategy=paged` and `--kv-cache-page-size` with a value that's a
  multiple of 128.
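The fragmentation benefit of paged allocation is easy to see with some back-of-envelope arithmetic. The helpers below are a pure-Python illustration, not a MAX API: with fixed-size pages, over-allocation is bounded by less than one page per sequence, instead of padding every sequence in a batch out to the longest length under a contiguous allocation.

```python
# Illustrative arithmetic for a paged KV cache. The 128-token granularity
# mirrors the --kv-cache-page-size constraint; the sequence lengths below
# are hypothetical examples, not MAX defaults.
def pages_needed(seq_len: int, page_size: int = 128) -> int:
    """Number of fixed-size pages needed to hold seq_len tokens of KV state."""
    assert page_size % 128 == 0, "page size must be a multiple of 128"
    return -(-seq_len // page_size)  # ceiling division

def wasted_slots(seq_len: int, page_size: int = 128) -> int:
    """Internal fragmentation: allocated-but-unused token slots in the last page."""
    return pages_needed(seq_len, page_size) * page_size - seq_len

# A 1000-token sequence with 128-token pages needs 8 pages and wastes only
# 24 slots, regardless of how long the other sequences in the batch are.
print(pages_needed(1000), wasted_slots(1000))  # → 8 24
```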
* Added support for prefix caching in all cases where PagedAttention is
  supported. This allows for more efficient usage of the KV cache and improved
  prefill performance for workloads with common prefixes. You can enable it by
  setting `--enable-prefix-caching`. For more information, see
  [Prefix caching with PagedAttention](/max/serve/prefix-caching).
* Batch size and max length are now inferred from available memory and the
  Hugging Face model's default value for max length, respectively. If a
  configuration leads to an OOM, we provide recommendations (to the best of
  our ability) to help you fit the model into memory.
* Added support for heterogeneous KV caches for multi-modal models, such as
  Llama Vision, which cache different KV states for self and cross attention
  layers.
* Added support for embedding models, starting with MPNet. For example:

  ```shell
  max-pipelines generate \
    --model-path=sentence-transformers/all-mpnet-base-v2 \
    --prompt="Encode this sentence."
  ```

  Also see [how to deploy a text embedding model](/max/tutorials/run-embeddings-with-max-serve).
* Added support for image and text multimodal models:

  * `max-pipelines generate` now accepts image input with `--image_url`.
  * Added an experimental Pixtral pipeline you can run as follows:

    ```shell
    max-pipelines generate \
      --model-path=mistral-community/pixtral-12b \
      --prompt="What is in this image? [IMG]" \
      --image_url=http://picsum.photos/1024/1024
    ```

    The pipeline is automatically used for all models implementing the
    `LlavaForConditionalGeneration` architecture. The implementation currently
    has a limit of one image. We plan to support an arbitrary number of images
    of mixed sizes soon.
  * Added an experimental Llama Vision pipeline you can run as follows:

    ```shell
    max-pipelines generate \
      --model-path=meta-llama/Llama-3.2-11B-Vision-Instruct \
      --prompt="<|image|><|begin_of_text|>What is in this image?" \
      --image_url=http://picsum.photos/1024/1024
    ```

    The pipeline is automatically used for all models implementing the
    `MllamaForConditionalGeneration` architecture. Note: This model is gated
    and requires that you set the
    [`HF_TOKEN`](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hftoken)
    environment variable. See
    [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct).
  * See [how to generate image descriptions with Llama 3.2 Vision](/max/tutorials/deploy-llama-vision).
* Added support for the `Qwen2ForCausalLM` model architecture (such as
  `Qwen/Qwen2.5-7B-Instruct`). For example:

  ```shell
  max-pipelines generate \
    --model-path=Qwen/Qwen2.5-7B-Instruct \
    --prompt="Write bubble sort in python" \
    --quantization-encoding bfloat16
  ```

* Added support for offline batched inference for text-based LLMs, allowing
  you to load a model and run inference with a batch of inputs directly from
  Python, instead of relying on an HTTP interface. For an example, see
  [`examples/offline-inference/basic.py`](https://github.com/modular/modular/blob/main/examples/offline-inference/basic.py).
* The `--max-cache-batch-size` flag has been deprecated in favor of
  `--max-batch-size`. Using `--max-cache-batch-size` now emits a deprecation
  warning and will stop working in a future release.
* The `--use-gpu` flag has been deprecated in favor of `--devices=cpu`,
  `--devices=gpu`, or `--devices=gpu-0,gpu-1,...`. If the device isn't
  specified, the model runs on the first available GPU, or on CPU if no GPUs
  are available.

### MAX Engine {#25-1-engine}

* Improved internal kernel compilation speed 1.5-4x across different models.
  We've revamped our GPU compilation process so that all kernels in a program
  are compiled together into a single LLVM module, then split into separate
  kernels afterward. This ensures shared code between kernel entry points is
  only compiled once.
  For example, we observed a 3.7x speedup in Llama3.1-8b GPU startup time.

* Improved initial model execution speed on NVIDIA GPUs. Instead of compiling
  to PTX and performing just-in-time compilation during runtime, we now
  generate CUBIN binaries directly. While this increases initial compilation
  time, it significantly improves execution speed.
* The kernels have been further tuned for performance on NVIDIA A100 GPUs.

#### Graph APIs {#25-1-graph}

* You can now write custom operations (ops) in Mojo and add them to a graph
  constructed in Python, using
  [`custom()`](/max/api/python/graph/ops#max.graph.ops.custom) and
  [`inplace_custom()`](/max/api/python/graph/ops#max.graph.ops.inplace_custom).
  For more detail, see the section below about [GPU programming](#25-1-gpus).
* Cached compiled MAX graphs that make use of custom operations now get
  invalidated when the implementation of the custom operations changes.
* [`Graph.add_weight()`](/max/api/python/graph/Graph#max.graph.Graph.add_weight)
  now takes an explicit `device` argument. This enables explicitly passing
  GPU-resident weights to
  [`session.load()`](/max/api/python/engine#max.engine.InferenceSession.load)
  via the weights registry to initialize the model.
* [`max.graph.Weight`](/max/api/python/graph/Weight) now inherits from
  `TensorValue`, allowing you to call `weight.cast()` or `weight.T`. As such,
  [`TensorValue`](/max/api/python/graph/TensorValue#max.graph.TensorValue) no
  longer accepts `Weight` for the `value` argument.

#### Pipeline APIs {#25-1-pipelines}

* [`TextTokenizer.new_context()`](/max/api/python/pipelines/tokenizer#max.pipelines.tokenizer.TextTokenizer.new_context)
  now supports tool definitions passed through its `request` argument (via
  `TokenGeneratorRequest.tools`). It also now supports JSON schemas passed
  through its `request` argument (via
  [`TokenGeneratorRequest.response_format`](/max/api/python/pipelines/interfaces/#max.pipelines.interfaces.TokenGeneratorRequest.response_format)).
* Removed the default `num_steps` value for
  [`TokenGenerator.next_token()`](/max/api/python/pipelines/interfaces/#max.pipelines.interfaces.TokenGenerator.next_token),
  ensuring users pass a value and reducing the potential for silent errors.
* [`KVCacheStrategy`](/max/api/python/pipelines/kv_cache/cache_params#max.pipelines.kv_cache.cache_params.KVCacheStrategy)
  now defaults to `MODEL_DEFAULT`. Instead of always using the "continuous"
  caching strategy as before, the KV caching strategy now defaults on an
  architecture-specific basis to ensure the most optimized caching strategy is
  used.
* The [`Linear`](/max/api/python/nn/Linear) layer now has a `create()` class
  method that automatically creates specializations of `Linear` for
  non-quantized, k-quant, or GPTQ layers.
* Added [`nn.Conv1D`](/max/api/python/nn/conv#max.nn.conv.Conv1D) for audio
  models like Whisper.

#### GPU programming {#25-1-gpus}

This release includes all-new APIs to program on GPUs. The way to write code
for GPUs is to create custom operations with GPU functions that you can load
into a MAX graph. This foundational API includes a few key components:

* Mojo APIs to write custom op functions:

  * The [`@compiler.register`](/max/api/mojo-decorators/compiler-register)
    decorator is applied to a Mojo struct that implements a custom op in an
    `execute()` function—for either CPU or GPU—and a `shape()` function that
    defines the custom op's output tensor.
  * The [`max.tensor`](/mojo/kernels/extensibility/tensor/) package adds
    essential Mojo APIs for writing custom ops, such as:

    * The [`foreach()`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/foreach)
      function, which efficiently executes an element-wise computation in
      parallel on either a GPU or CPU.
    * The [`ManagedTensorSlice`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/ManagedTensorSlice)
      type defines the input and output tensors for the custom op.
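To build intuition for the `foreach()` pattern, here's a pure-Python analogue. This is not the MAX API: real `foreach()` runs the per-element function in parallel over tensor elements on CPU or GPU, while this sketch models tensors as flat lists and applies the "kernel" serially.

```python
# Pure-Python analogue of the elementwise foreach() pattern (NOT the MAX API).
# The "kernel" is a plain function of one element; MAX would vectorize and
# parallelize this map across CPU or GPU lanes.
def foreach(fn, inputs):
    """Apply fn to every element of a flat "tensor" and return the result."""
    return [fn(x) for x in inputs]

# Mirrors the add-one custom op described in these release notes.
out = foreach(lambda x: x + 1, [1.0, 2.0, 3.0])
```

The value of the pattern is that the author writes only the per-element function, and the runtime owns the iteration order and parallelism.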
* Python APIs to load custom ops into a model:

  * The [`custom()`](/max/api/python/graph/ops#max.graph.ops.custom) and
    `inplace_custom()` functions allow you to add the previously defined Mojo
    custom op to a MAX graph written in Python.
  * The [`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession)
    constructor accepts the custom op implementation as a
    [Mojo package](/mojo/manual/packages#mojo-packages) in the
    `custom_extensions` argument.

For more detail, see the
[tutorial to build custom ops for GPUs](/max/develop/build-custom-ops), or
check out this
[simple example of a custom op](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/add_custom.mojo).

Additionally, we've added a new [`gpu` package](/mojo/std/gpu/) to the Mojo
standard library that provides low-level programming constructs for working
with GPUs. These APIs let you do things that you can't currently do with the
high-level `foreach()` abstraction above. The Mojo `gpu` APIs allow you to
manually manage interaction between the CPU host and GPU device, manage memory
between devices, synchronize threads, and more. For some examples, see
[`vector_addition.mojo`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/vector_addition.mojo)
and [`top_k.mojo`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/top_k.mojo).

### Mojo {#25-1-mojo}

Mojo is a crucial component of the MAX stack that enables all of MAX's
performance-oriented code across hardware. For all the updates to the Mojo
language, standard library, and tools, see the [Mojo changelog](/mojo/changelog).

## v24.6 (2024-12-17)

This is a huge update that offers a first look at our serving library for MAX
on GPUs!
* [Highlights](#24-6-highlights)
* [Documentation](#24-6-docs)
* [MAX Serve](#24-6-serve)
* [MAX models](#24-6-models)
* [MAX Engine](#24-6-engine)
* [Driver APIs](#24-6-driver-api)
* [Graph compiler](#24-6-graph-compiler)
* [Graph APIs](#24-6-graph-api)
* [Custom op registration](#24-6-custom-ops)
* [Numeric kernels](#24-6-kernels)
* [Mojo](#24-6-mojo)

Also check out our
[blog post introducing MAX 24.6](https://www.modular.com/blog/introducing-max-24-6-a-gpu-native-generative-ai-platform).

### ✨ Highlights {#24-6-highlights}

* **MAX Engine on GPUs preview**

  We're excited to share a preview of MAX Engine on GPUs. We've created a few
  tutorials that demonstrate MAX's ability to run GenAI models with our
  next-generation MAX graph compiler on NVIDIA GPU architectures (including
  A100, A10, L4, and L40 GPUs). You can experience it today by
  [deploying Llama 3 on an A100 GPU](/max/tutorials/max-serve-local-to-cloud).

* **MAX Serve preview**

  This release also includes an all-new serving interface called MAX Serve.
  It's a Python-based serving layer that supports both native MAX models when
  you want a high-performance deployment, and off-the-shelf PyTorch LLMs from
  Hugging Face when you want to explore and experiment—all with GPU support.
  It provides an OpenAI-compatible REST endpoint for inference requests, and a
  Prometheus-compatible metrics endpoint. You can use a `magic` command to
  start a local server, or use our ready-to-deploy MAX container to start an
  endpoint in the cloud. Try it now
  [with an LLM from Hugging Face](/max/tutorials/max-serve-local-to-cloud).

* **Upgraded MAX models**

  As we continue to build our Python-based MAX Graph API that allows you to
  build high-performance GenAI models, we've made a ton of performance
  improvements to the existing models and added a few new models to our GitHub
  repo. All the Python-based MAX models now support GPUs and broad model
  architectures.
  For example,
  [`llama3`](https://github.com/modular/modular/tree/main/max/pipelines/architectures/llama3)
  adds compatibility for the `LlamaForCausalLM` family, which includes over
  20,000 model variants and weights on Hugging Face.

### Documentation {#24-6-docs}

New tutorials:

* [Deploy Llama 3 on GPU with MAX Serve](/max/tutorials/max-serve-local-to-cloud)
* [Deploy Llama 3.1 on GPU-powered Kubernetes clusters](/max/tutorials/deploy-max-serve-on-kubernetes)
* [Get started with MAX Graph in Python](/max/tutorials/get-started-with-max-graph-in-python)

Other new docs:

* [MAX container](/max/container)
* [Benchmark MAX Serve](https://github.com/modular/modular/tree/main/benchmark)

Also, our documentation is now available for **MAX nightly builds**! If you're
building with a nightly release, you can switch to the nightly docs using a
toggle to the right of the search bar.

### MAX Serve {#24-6-serve}

This release includes a preview of our Python-based serving library called MAX
Serve. It simplifies the process of deploying your own inference server with
consistent and reliable performance. MAX Serve currently includes the
following features:

* Deploys locally and to the cloud with our
  [MAX container image](/max/container), or with the `magic` CLI.
* An OpenAI-compatible server with streaming `/chat/completion` and
  `/completion` endpoints for LLM inference requests.
* Prometheus-compatible [metrics endpoint](/max/container#metrics) with LLM
  KPIs (TTFT and ITL) for monitoring and evaluating performance.
* Supports most `TextGeneration` Hugging Face Hub models.
* Multiprocess HTTP/model worker architecture to maximize CPU core utilization
  by distributing multiple incoming requests across multiple processes,
  ensuring both high throughput and responsiveness.
* Continuous heterogeneous batching to combine multiple incoming requests into
  a single inference (no waiting to fill a batch size) and improve total
  throughput.
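The batching behavior can be illustrated with a toy scheduler. This is a simplification for intuition only, not MAX's implementation: requests are admitted into the running batch the moment a slot opens, and finished sequences leave between decode steps, so no request waits for a full batch to assemble.

```python
from collections import deque

# Toy continuous-batching sketch (NOT MAX's scheduler). Each request is
# (request_id, tokens_to_generate); max_batch_size is the GPU batch limit.
def serve(requests, max_batch_size=2):
    pending = deque(requests)
    active, step_log = [], []
    while pending or active:
        # Admit new requests immediately whenever a batch slot is free.
        while pending and len(active) < max_batch_size:
            active.append(list(pending.popleft()))
        # One fused decode step generates a token for every active request.
        step_log.append([rid for rid, _ in active])
        for req in active:
            req[1] -= 1
        # Finished sequences exit between steps, freeing their slots.
        active = [req for req in active if req[1] > 0]
    return step_log

# "c" joins as soon as "a" finishes, without waiting for a fresh batch.
log = serve([("a", 1), ("b", 3), ("c", 2)])
print(log)  # → [['a', 'b'], ['b', 'c'], ['b', 'c']]
```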
There's much more still in the works for MAX Serve, but you can try it today
with our tutorial to
[Deploy Llama 3 on GPU with MAX Serve](/max/tutorials/max-serve-local-to-cloud).

**Known issues:**

* While this release is enough to support typical chatbot applications, it
  does not yet support the function-calling portion of the OpenAI API
  specification needed to enable robust agentic workflows.
* Sampling is still limited and doesn't currently respect temperature or other
  sampling-related API request input.
* Structured generation is not supported.
* Support for multi-modal models is still nascent.

### MAX models {#24-6-models}

All of our Python-based GenAI
[models on GitHub](https://github.com/modular/modular/tree/main/max/pipelines/architectures)
now support GPUs! As we add more models, we're also building a robust set of
libraries and infrastructure that make it easier to build and deploy a growing
library of LLMs. Some of this is available in a new
[`max.pipelines`](/max/api/python/pipelines/) package, and some of it lives
alongside the
[models on GitHub](https://github.com/modular/modular/tree/main/max/pipelines/architectures).
Here are just some of the highlights:

* Deep integration with the Hugging Face ecosystem for a quick-to-deploy
  experience, such as using HF Model Hub tools to fetch config files, support
  for weights in [safetensor](https://github.com/huggingface/safetensors)
  format, support for HF tokenizers, and more. (We also support GGUF weight
  formats.)
* Expanded set of model abstractions for use by different LLM architectures:

  * Attention layers (including highly optimized implementations with
    configurable masking, like
    [`AttentionWithRope`](https://github.com/modular/modular/tree/main/max/nn/attention/attention_with_rope.py)).
    The optimized attention layers include variants that accept an attention
    mask.
    More memory-efficient variants don't take a mask; instead they take a
    "mask functor" argument to the kernel, which implements masking without
    materializing a mask by computing a mask value from input coordinates on
    the fly.
  * Transformers such as
    [`Transformer` and `TransformerBlock`](https://github.com/modular/modular/tree/main/max/nn/transformer/transformer.py).
    These include an initial implementation of ragged tensors—tensors for
    which each dimension can have a different size, avoiding the use of
    padding tokens by flattening a batch of sequences of differing lengths.
  * Common layers such as
    [`RMSNorm`](https://github.com/modular/modular/tree/main/max/nn/norm/rms_norm.py),
    [`Embedding`](https://github.com/modular/modular/tree/main/max/nn/embedding.py),
    and [`Sequential`](https://github.com/modular/modular/tree/main/max/nn/sequential.py).
  * KV cache management helpers, like
    [`ContinuousBatchingKVCacheManager`](/max/api/python/pipelines/kv_cache/continuous_batching_cache#max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager).
  * Low-level wrappers over optimized kernels like
    [`fused_qk_ragged_rope`](https://github.com/modular/modular/tree/main/max/nn/kernels.py).
    These are custom fused kernels that update the KV cache in place. Although
    they are custom, they reuse the underlying kernel implementation by
    passing in lambda functions used to retrieve inputs and write to outputs
    in place.
* Added generalized interfaces for text generation such as
  [`TokenGenerator`](/max/api/python/pipelines/interfaces#max.pipelines.interfaces.TokenGenerator)
  and [`PipelineModel`](/max/api/python/pipelines/pipeline#max.pipelines.pipeline.PipelineModel),
  which provide modularity within the models and serving infrastructure. Also
  added a plug-in mechanism
  ([`PipelineRegistry`](/max/api/python/pipelines/registry#max.pipelines.registry.PipelineRegistry))
  to more quickly define new models, tokenizers, and other reusable
  components.
  For example, anything that conforms to
  [`TokenGenerator`](/max/api/python/pipelines/interfaces#max.pipelines.interfaces.TokenGenerator)
  can be served using the LLM infrastructure within MAX Serve. We then used
  this interface to create the following:

  * An optimized
    [`TextGenerationPipeline`](/max/api/python/pipelines/pipeline#max.pipelines.pipeline.TextGenerationPipeline)
    that can be combined with any compatible graph and has powerful
    performance features like graph-based multi-step scheduling, sampling, KV
    cache management, ragged tensor support, and more.
  * A generic
    [`HFTextGenerationPipeline`](/max/api/python/pipelines/hf_pipeline#max.pipelines.hf_pipeline.HFTextGenerationPipeline)
    that can run, in eager mode, any Hugging Face model for which we don't yet
    have an optimized implementation.

* Models now accept weights via a weights registry, which is passed to the
  [`session.load()`](/max/api/python/engine#max.engine.InferenceSession.load)
  method's `weights_registry` argument. Decoupling weights from the model
  architecture allows implementing all of the different fine-tunes for a given
  model with the same graph. Furthermore, because the underlying design is
  decoupled, we can later expose the ability to compile a model once and swap
  weights out on the fly, without re-compiling the model.
* Added generic implementations of common kernels, which allow you to plug in
  different batching strategies (ragged or padded), KV cache management
  approaches (continuous batching), masking (causal, sliding window, etc.),
  and position encoding (RoPE or ALiBi) without having to rewrite any kernel
  code. (More about this in a future release.)
* Multi-step scheduling to run multiple token-generation steps on GPU before
  synchronizing to the CPU.

**Updated models:**

* Significant performance upgrades for
  [Llama 3](https://github.com/modular/modular/tree/main/max/pipelines/architectures/llama3),
  and expanded compatibility with the `LlamaForCausalLM` model family.
For example, it also supports the Llama 3.2 1B and 3B text models.

**New models:**

* [Mistral NeMo](https://github.com/modular/modular/tree/main/max/pipelines/architectures/mistral) (and other `MistralForCausalLM` models)
* [Replit Code V1.5 3B](https://github.com/modular/modular/tree/main/max/pipelines/architectures/replit)

**Known issues:**

* The Q4 quantized models currently work on CPU only.
* Using a large `top-k` setting with the Llama 3.1 model may lead to segmentation faults for certain workloads when run on NVIDIA GPUs. This should be resolved in the latest nightly MAX builds.
* The models currently use a smaller default context window than the `max_seq_len` specified in a model's Hugging Face configuration files. You can adjust this manually by setting the `--max-length` parameter to the desired context length when serving a model.
* Some variants of the supported core models (like `LlamaForCausalLM` with different numbers of heads, head sizes, and so on) might not be fully optimized yet. We plan to fully generalize our implementations in a future release.

### MAX Engine {#24-6-engine}

MAX Engine includes much of the core infrastructure that enables MAX to accelerate AI models on any hardware, such as the graph compiler, runtime, kernels, and the APIs to interact with them all, and it works without external dependencies such as PyTorch or CUDA. This release includes many performance upgrades to our graph compiler and runtime. We've added support for NVIDIA GPU architectures (including A100, A10, L4, and L40 GPUs), and built new infrastructure so we can quickly add support for other GPU hardware.

**Engine API changes:**

* [`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession) now accepts a `custom_extensions` constructor argument, same as `load()`, to specify model extension libraries.
* The [`Model`](/max/api/python/engine#max.engine.Model) object is now callable to run an inference.
**Breaking changes:**

* The `Model.execute()` signature changed to support GPUs.
* The [`execute()`](/max/api/python/engine#max.engine.Model.execute) function currently doesn't accept keyword arguments. Instead, you can pass tensors as a [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor), `int`, `float`, `bool`, [`np.generic`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic), or [`DLPackArray`](/max/api/python/driver#max.driver.DLPackArray) ([DLPack](https://github.com/dmlc/dlpack)). Note that both PyTorch and NumPy arrays implement the DLPack protocol, which means you can also pass either of those types to `execute()`.
* [`execute_legacy()`](/max/api/python/engine#max.engine.Model.execute_legacy) preserves the semantics of `execute()` with support for keyword arguments to help with migration, but will be removed in a future release. `execute_legacy()` doesn't support GPUs.
* Calling `execute()` with positional arguments still works the same.

#### Driver APIs {#24-6-driver-api}

MAX Driver (the [`max.driver`](/max/api/python/driver) module) is a new component of MAX Engine that's still a work in progress. It provides primitives for working with heterogeneous hardware systems (GPUs and CPUs), such as allocating on-device memory, transferring data between host and device, and querying device stats. It's a foundation on which other components of MAX Engine operate (for example, `InferenceSession` now uses [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor) to handle model inputs and outputs).

**Driver API changes:**

* Added the `CUDA()` device to open an NVIDIA GPU.
* Added support for fp16 and bfloat16 dtypes.
* Expanded functionality for `max.driver.Device`, with new class methods and properties. We are still working on building this out to support more accelerator features.
* [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor) (and the `InferenceSession.load()` argument `weights_registry`) now supports zero-copy interoperability with NumPy arrays and PyTorch tensors, using [DLPack](https://github.com/dmlc/dlpack) / [`DLPackArray`](/max/api/python/driver#max.driver.DLPackArray).
* [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor) has new methods, such as `from_dlpack()`, `element_size()`, `to()`, `to_numpy()`, `view()`, `zeros()`, and more.

The MAX Driver APIs are still changing rapidly and not yet ready for general use. We'll publish more documentation in a future release.

**Known issues:**

* MAX Driver is currently limited to managing one NVIDIA GPU at a time (it does not yet support multi-GPU). It also does not yet support remote devices.
* DLPack support is not complete. For example, streams are not yet supported.

#### Graph compiler {#24-6-graph-compiler}

When you load a model into MAX Engine, the graph compiler is the component that inspects and optimizes all graph operations (ops) to deliver the best runtime performance on each device. This release includes various graph compiler improvements:

* Major extensions to support NVIDIA GPUs (and other devices in the future), including async copies and caching of JIT-compiled kernels.
* The runtime now performs scheduling to enable GPU compute to overlap with the CPU.
* New transformations to the Mojo kernels that enable a number of optimizations, including specialization on tensor dimensions, specialization on target hardware, specialization on non-tensor-dimension inputs to kernels, automatic kernel fusion between operators, and more.
* New algebraic simplifications and algorithms for ops, such as horizontal fusion of matrix multiplications.
* New CPU-side primitives for device management that are automatically transformed and optimized to reduce overhead (MAX does not need mechanisms like CUDA Graphs).
* Updated memory planning to preallocate device memory (hoisting computation from inference time to initialization time) and reduce per-inference overhead.

#### Graph APIs {#24-6-graph-api}

The graph compiler is also exposed through the MAX Graph APIs (the [`max.graph`](/max/api/python/graph/) package), which allow you to build high-performance GenAI models in Python.

**Graph API changes:**

* Python stack traces from model execution failures now include a trace to the original op creation, allowing easier debugging during development.
* The [`max.graph`](/max/api/python/graph/) APIs now include preliminary support for symbolic algebraic expressions using [`AlgebraicDim`](/max/api/python/graph/type#max.graph.type.AlgebraicDim), enabling more powerful support for checked dynamic shapes. This allows expressions like `-Dim("x") - 4`. Furthermore, algebraic expressions simplify to a canonical form, so that, for example, `-Dim("x") - 4 == -(Dim("x") + 4)` holds.
* More advanced dtype promotion now allows [`TensorValue`](/max/api/python/graph/TensorValue) math operators to just work when used with NumPy arrays and Python primitives.
* [`TensorValue`](/max/api/python/graph/TensorValue) has new methods, such as `broadcast_to()`, `cast()`, `flatten()`, `permute()`, and more.
* Added [`BufferValue`](/max/api/python/graph/BufferValue), which allows for device-resident tensors that are read and mutated within the graph.
* [`DType`](/max/api/python/dtype#max.dtype.DType) has new methods and properties: `align`, `size_in_bytes`, and `is_float()`.
* The [`Value`](/max/api/python/graph/Value) constructor accepts more types for `value`.
* The [`TensorValue`](/max/api/python/graph/TensorValue) constructor accepts more types for `value`.
* [`TensorValue.rebind()`](/max/api/python/graph/TensorValue#max.graph.TensorValue.rebind) accepts a new `message` argument.
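The canonicalization behavior described for `AlgebraicDim` can be illustrated with a toy model of affine symbolic dimensions (a sketch of the idea only, not the `max.graph` implementation): keeping every expression in the normalized form `coeff * name + offset` makes syntactically different spellings compare equal.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyDim:
    """Toy symbolic dimension kept in the canonical form coeff*name + offset."""
    name: str
    coeff: int = 1
    offset: int = 0

    def __add__(self, k: int) -> "ToyDim":
        return ToyDim(self.name, self.coeff, self.offset + k)

    def __sub__(self, k: int) -> "ToyDim":
        return ToyDim(self.name, self.coeff, self.offset - k)

    def __neg__(self) -> "ToyDim":
        return ToyDim(self.name, -self.coeff, -self.offset)

x = ToyDim("x")
# Both spellings canonicalize to coeff=-1, offset=-4, so they compare equal:
assert -x - 4 == -(x + 4)
```

Because equality is checked on the canonical form rather than the expression tree, shape checking can prove that two differently written dimensions are the same.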
**Breaking changes:**

* [`Graph.add_weight()`](/max/api/python/graph/Graph#max.graph.Graph.add_weight) now accepts [`Weight`](/max/api/python/graph/Weight#max.graph.Weight) and returns [`TensorValue`](/max/api/python/graph/TensorValue). [`Weight`](/max/api/python/graph/Weight#max.graph.Weight) is essentially a named placeholder for a tensor that knows its name, dtype, shape, and optionally device and quantization encoding. `Graph.add_weight()` stages an op in the graph that is populated by a named weight in the weights registry passed to `session.load()`.
* The [`Weight`](/max/api/python/graph/Weight#max.graph.Weight) constructor arguments changed: added `align`, `dtype`, and `shape`; removed `assign`, `filepath`, `offset`, and `value`.
* The `ops.scalar()` method was removed, along with the `is_static()` and `is_symbolic()` methods on all `graph.type` objects.
  * Instead of `ops.scalar()`, use [`ops.constant()`](/max/api/python/graph/ops#max.graph.ops.constant).
  * Instead of `is_static()` and `is_symbolic()`, use `isinstance(dim, StaticDim)` and `isinstance(dim, SymbolicDim)`.

The MAX Graph APIs are not ready for general use, but you can [experiment with them now by following this tutorial](/max/tutorials/get-started-with-max-graph-in-python). We'll add more documentation when we finish some API redesigns.

#### Custom op registration {#24-6-custom-ops}

Although the APIs for writing custom operators (ops) aren't ready for general use, this release includes a significant redesign that lays the groundwork. You might notice some associated APIs in this release and more in the nightlies, so here's a little about the work in progress:

* The custom op APIs will allow you to extend MAX Engine with new ops written in Mojo, providing full composability and extensibility for your models. It's the exact same API we use to write MAX Engine's built-in ops such as `matmul`.
That means your custom ops can benefit from all our compiler optimization features, such as kernel fusion—your ops are treated the same as all the ops included "in the box."

* The new API requires far less adornment at the definition site to enable the MAX model compiler to optimize custom ops along with the rest of the graph (compared to our previous version, which used `NDBuffer`).
* Custom ops support "destination passing style" for tensors.
* The design composes on top of Mojo's powerful metaprogramming, as well as the kernel library's abstractions for composable kernels.

We'll publish more documentation when the custom op API is ready for general use. Check out the MAX repo's `nightly` branch to see the latest [custom op examples](https://github.com/modular/modular/tree/main/max/examples/custom_ops).

**Known issues:**

* Custom ops don't have type or lifetime checking, and they don't reason about mutability. Expect lots of sharp corners and segfaults if you hold them wrong while we improve this!

#### Numeric kernels {#24-6-kernels}

The GPU kernels for MAX Engine are built from the ground up in Mojo, with no dependencies on external vendor code or libraries. This release includes the following kernel improvements:

* AttenGen: a novel way to express attention patterns that can express different attention masks, score functions, and caching strategies.
* State-of-the-art matrix multiplication algorithms, with optimizations such as the following:
  * Pipelining and double buffering to overlap data transfer with computation and hide memory access latency (for both global and shared memory).
  * Thread swizzling to avoid the shared memory bank conflicts associated with tensor core layouts.
  * Block swizzling to increase L2 cache locality.
  * SplitK/StreamK GEMM algorithms: these divide the computation along the shared K dimension into smaller matrices that can then be executed independently on streaming multiprocessors (SMs).
These algorithms are ideal for matrices with a large K dimension but small M dimension.
  * Large-context-length MHA: uses SplitK/StreamK to implement the attention mechanism and eliminate the need for a huge score matrix, which drastically reduces memory usage and traffic to enable large context lengths.
  * DualGemm: accelerates multi-layer perceptron (MLP) layers where the left-hand side (LHS) is shared between two matrix multiplications.

**Known issues:**

* The MAX kernels are optimized for bfloat16 on GPUs.
* Convolution on GPU is not yet performance-optimized.
* Although v24.6 technically runs on H100, it doesn't yet include performance-optimized kernels for that device, so H100 isn't recommended.

### Mojo {#24-6-mojo}

Mojo is a crucial component of the MAX stack that enables all of MAX's performance-oriented code across hardware. For all the updates to the Mojo language, standard library, and tools, see the [Mojo changelog](/mojo/changelog#v246-2024-12-17).

## v24.5 (2024-09-13)

### ✨ Highlights

* Mojo and MAX are magical! We've created a new package and virtual environment manager, `magic`, for MAX and Mojo.
* New [Llama 3.1 pipeline](https://github.com/modular/modular/tree/main/max/pipelines/architectures) built with the new MAX Graph Python API.
* We have not one, but two new Python APIs in this release:
  * [MAX Graph Python API](#max-graph-python-api)
  * [MAX Driver Python API](#max-driver-python-api)

### ⭐️ New

* Added the `repeat_interleave` graph op.
* Added caching for MAX Graph models: graph compilation is cached, and the executable model is retrieved from the cache on the second and subsequent runs. Note that the model cache is architecture-specific and isn't portable across different targets.
* Support for Python 3.12.

#### MAX Graph Python API

This Python API will ultimately provide the same low-level programming interface for high-performance inference graphs as the Mojo API.
As with the Mojo API, it's an API for graph building only; it does not implement support for training. You can see how the API works in the [MAX Graph Python API reference](/max/api/python/graph/).

#### MAX Driver Python API

The MAX Driver API allows you to interact with devices (such as CPUs and GPUs) and allocate memory directly onto them. With this API, you interact with this memory as tensors. Note that this API is still under development, with support for non-host devices, such as GPUs, planned for a future release. To learn more, check out the [MAX Driver Python API reference](/max/api/python/driver).

#### MAX C API

New APIs for adding torch metadata libraries:

* `M_setTorchMetadataLibraryPath`
* `M_setTorchMetadataLibraryPtr`

### 🦋 Changed

#### MAX Engine performance

* Compared to v24.4, MAX Engine v24.5 generates tokens for Llama an average of 15%-48% faster.

#### MAX C API

Simplified the API for adding torch library paths; it now takes only one path per API call, but can be called multiple times to add paths to the config:

* `M_setTorchLibraries` -> `M_setTorchLibraryPath`

### ⚠️ Deprecated

* The `max` command-line tool is no longer supported and will be removed in a future release.

### ❌ Removed

* Dropped support for Ubuntu 20.04. If you're using Ubuntu, we currently support Ubuntu 22.04 LTS only.
* Dropped support for Python 3.8.
* Removed the built-in PyTorch libraries from the max package. See the [FAQ](/max/faq) for information on supported torch versions.

## v24.4 (2024-06-07)

### 🔥 Legendary

* MAX is now available on macOS! [Try it now](/max).
* New quantization APIs for MAX Graph. You can now build high-performance graphs in Mojo that use the latest quantization techniques, enabling even faster performance and broader system compatibility for large models. Learn more in the guide to [quantize your graph weights](/max/graph/quantize).
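To give a feel for what block-quantized encodings like the Q4 family do, here's a simplified block-quantization sketch in Python with NumPy (illustrative only; the real Q4_0 storage format and the Mojo quantization APIs differ in details): weights are split into fixed-size blocks, and each block stores small integers plus one float scale.

```python
import numpy as np

BLOCK = 32  # weights per quantization block

def quantize_blocks(w: np.ndarray):
    """Quantize float32 weights to 4-bit-range ints with a per-block scale."""
    blocks = w.reshape(-1, BLOCK)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scale[scale == 0] = 1.0  # avoid division by zero for all-zero blocks
    q = np.round(blocks / scale).astype(np.int8)  # values land in [-7, 7]
    return q, scale

def dequantize_blocks(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return (q.astype(np.float32) * scale).reshape(-1)

rng = np.random.default_rng(0)
w = rng.standard_normal(256).astype(np.float32)
q, scale = quantize_blocks(w)
w_hat = dequantize_blocks(q, scale)

# Rounding error is bounded by half a quantization step per block.
assert np.abs(w - w_hat).max() <= scale.max() / 2 + 1e-6
```

The per-block scale is why quantized matmuls like `qmatmul()` must dequantize (or fold the scale in) block by block rather than elementwise.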
### ⭐️ New

#### MAX Mojo APIs

* Added AI pipeline examples in the `max` repo, with Mojo implementations of common transformer layers, including quantization support.
* New Llama 3 pipeline built with MAX Graph.
* New Replit Code pipeline built with MAX Graph.
* New TinyStories pipeline (based on TinyLlama) that offers a simple demo of the MAX Graph quantization API.
* Added the `max.graph.checkpoint` package to save and load model weights. All weights are stored in a `TensorDict`. You can save and load a `TensorDict` to disk with the `save()` and `load()` functions.
* Added MAX Graph quantization APIs:
  * Added the quantization encodings `BFloat16Encoding`, `Q4_0Encoding`, `Q4_KEncoding`, and `Q6_KEncoding`.
  * Added the `QuantizationEncoding` trait so you can build custom quantization encodings.
  * Added `Graph.quantize()` to create a quantized tensor node.
  * Added `qmatmul()` to perform matrix multiplication between a float32 matrix and a quantized matrix.
* Added some MAX Graph ops:
  * `avg_pool()`
  * `max_pool()`
  * `conv2d()`
  * `conv3d()`
  * `layer_norm()`
  * `tile()`
  * `select()`
* Added a `layer()` context manager and `current_layer()` function to aid debugging during graph construction. For example:

  ```mojo
  with graph.layer("foo"):
      with graph.layer("bar"):
          print(graph.current_layer())  # prints "foo.bar"
          x = graph.constant[DType.int64](1)
          graph.output(x)
  ```

  This adds the path `foo.bar` to the added nodes, which will be reported during errors.
* Added a `format_system_stack()` function to format the stack trace, which we use to print better error messages from `error()`.
* Added `TensorMap.keys()` to get all the tensor key names.

#### MAX C API

Miscellaneous new APIs:

* `M_cloneCompileConfig()`
* `M_copyAsyncTensorMap()`
* `M_tensorMapKeys()` and `M_deleteTensorMapKeys()`
* `M_setTorchLibraries()`

### 🦋 Changed

#### MAX Mojo API

* The `EngineNumpyView.data()` and `EngineTensorView.data()` functions that return a type-erased pointer were renamed to `unsafe_ptr()`.
* `TensorMap` now conforms to the `CollectionElement` trait, making it copyable and movable.
* `custom_nv()` was removed and its functionality moved into `custom()` as a function overload, so it can now output a list of tensor symbols.

## v24.3 (2024-05-02)

### 🔥 Legendary

* You can now write custom ops for your models with Mojo! Learn more about [MAX extensibility](/max/develop/custom-ops).

### 🦋 Changed

* Added support for named dynamic dimensions. This means you can specify that two or more dimensions in your model's input are dynamic but must match each other's size at run time. By specifying each of these dimension sizes with a name (instead of using `None` to indicate a dynamic size), the MAX Engine compiler can perform additional optimizations. See the notes below for the corresponding API changes that support named dimensions.
* Simplified all the APIs for loading input specs for models, making them more consistent.

#### MAX Engine performance

* Compared to v24.2, MAX Engine v24.3 shows an average speedup of 10% on PyTorch models and an average speedup of 20% on dynamically quantized ONNX transformers.

#### MAX Graph API

The `max.graph` APIs are still changing rapidly, but starting to stabilize.

* `AnyMoType` was renamed to `Type`, `MOTensor` was renamed to `TensorType`, and `MOList` was renamed to `ListType`.
* Removed `ElementType` in favor of `DType`.
* Removed `TypeTuple` in favor of `List[Type]`.
* Removed the `Module` type, so you now start building a graph by directly instantiating a `Graph`.
* Some new ops in `max.ops`, including support for custom ops. See how to [create a custom op in MAX Graph](/max/develop/build-custom-ops).

#### MAX Engine Python API

* Redesigned [`InferenceSession.load()`](/max/api/python/engine#max.engine.InferenceSession.load) to replace the confusing `options` argument with a `custom_ops_path` argument. As a result, `CommonLoadOptions`, `TorchLoadOptions`, and `TensorFlowLoadOptions` have all been removed.
* [`TorchInputSpec`](/max/api/python/engine#max.engine.TorchInputSpec) now supports named dynamic dimensions (previously, dynamic dimension sizes could be specified only as `None`). This lets you tell MAX which dynamic dimensions must have the same size, which helps MAX better optimize your model.

#### MAX Engine Mojo API

* `InferenceSession.load_model()` was renamed to `load()`.
* Redesigned `InferenceSession.load()` to replace the confusing `config` argument with a `custom_ops_path` argument for use when [loading a custom op](/max/develop/build-custom-ops), and an `input_specs` argument for use when loading TorchScript models. This removed `LoadOptions` and introduced the new `InputSpec` type to define the input shape/type of a model (instead of `LoadOptions`).
* New `ShapeElement` type to allow for named dynamic dimensions (in `InputSpec`).
* The `max.engine.engine` module was renamed to `max.engine.info`.

#### MAX Engine C API

* [`M_newTorchInputSpec()`](/max/api/c/pytorch/config#m_newtorchinputspec) now supports named dynamic dimensions (via the new `dimNames` argument).

### ❌ Removed

* Removed TensorFlow support from the MAX SDK, so you can no longer load a TensorFlow SavedModel for inference. However, TensorFlow is still available for enterprise customers. We removed TensorFlow because industry-wide TensorFlow usage has declined significantly, especially for the latest AI innovations. Removing TensorFlow also cuts our package size by over 50% and accelerates the development of other customer-requested features. If you have a production use case for a TensorFlow model, please [contact us](https://www.modular.com/request-demo).
* Removed the Python `CommonLoadOptions`, `TorchLoadOptions`, and `TensorFlowLoadOptions` classes. See the note above about `InferenceSession.load()` changes.
* Removed the Mojo `LoadOptions` type. See the note above about `InferenceSession.load()` changes.
## v24.2.1 (2024-04-11)

* You can now import more MAX Graph functions from `max.graph.ops` instead of using `max.graph.ops.elementwise`. For example:

  ```mojo
  from max.graph import ops

  var relu = ops.relu(matmul)
  ```

## v24.2 (2024-03-28)

* MAX Engine now supports TorchScript models with dynamic input shapes. No matter what the input shapes are, you still need to [specify the input specs](/max/model-formats#specify-torchscript-input-specs) for all TorchScript models.
* The Mojo standard library is now open source! Read more about it in [this blog post](https://www.modular.com/blog/the-next-big-step-in-mojo-open-source).
* And, of course, lots of Mojo updates, including implicit traits, support for keyword arguments in Python calls, a new `List` type (previously `DynamicVector`), some refactoring that might break your code, and much more. For details, see the [Mojo changelog](/mojo/changelog#v242-2024-03-28).

## v24.1.1 (2024-03-18)

This is a minor release that improves error reports.

## v24.1 (2024-02-29)

The first release of the MAX platform is here! 🚀

This is a **preview version** of the MAX platform. That means it is not ready for production deployment and is designed only for local development and evaluation. Because this is a preview, some API libraries are still in development and subject to change, and some features that we previously announced are not quite ready yet. But there is a lot that you can do in this release!

This release includes our flagship developer tools, currently for **Linux only**:

* **MAX Engine**: Our state-of-the-art graph compiler and runtime library that executes models from PyTorch and ONNX, with incredible inference speed on a wide range of hardware.
  * API libraries in Python, C, and Mojo to run inference with your existing models. [See the API references](/max/api).
  * The `max benchmark` tool, which runs MLPerf benchmarks on any compatible model without writing any code.
  * The `max visualize` tool, which allows you to visualize your model in Netron after partial lowering in MAX Engine.
  * An early look at the [MAX Graph API](/max/model-formats#max-graph), our low-level library for building high-performance inference graphs.
* **MAX Serving**: A preview of our serving wrapper for MAX Engine that provides full interoperability with existing AI serving systems (such as Triton) and seamlessly deploys within existing container infrastructure (such as Kubernetes).
  * A Docker image that runs MAX Engine as a backend for NVIDIA Triton Inference Server.
* **Mojo**: The world's first programming language built from the ground up for AI developers, with cutting-edge compiler technology that delivers unparalleled performance and programmability for any hardware.
  * The latest version of Mojo, the standard library, and the `mojo` command-line tool. These are always included in MAX, so you don't need to download any separate packages.
  * The Mojo changes in each release are often quite long, so we'll continue sharing those in the existing [Mojo changelog](/mojo/changelog).

Additionally, we've started a new [GitHub repo for MAX](https://github.com/modular/max), where we currently share code examples for our API libraries, including some large model pipelines. You can also use this repo to [report issues with MAX](https://github.com/modular/modular/issues/new/choose).

### Model Architecture Support

* Added support for the following model architectures:
  * `OlmoForCausalLM` (such as `allenai/OLMo-1B-0724-hf`)
  * `GraniteForCausalLM` (such as `ibm-granite/granite-3.1-8b-instruct`)
  * `Phi3ForCausalLM` (for Microsoft Phi-3 models)
  * `Qwen2ForCausalLM` (such as Qwen2 models)

  Example usage:

  ```sh
  max-pipelines generate \
    --model-path allenai/OLMo-1B-0724-hf \
    --prompt "Write bubble sort in mojo"
  ```

* The `max.pipelines.dataprocessing.tokenizer` and `max.pipelines.dataprocessing.gguf_utils` modules have been removed.
* The previously deprecated `PipelineConfig.architecture` field and its corresponding `--architecture` CLI argument have been removed.

### `max-pipelines` CLI

* The `--devices` CLI argument now supports a comma-separated list of GPU IDs prefixed with `gpu:`, like `--devices=gpu:0,1,2,3`. We no longer support the previous `--devices=gpu-` format.

  ```sh
  max-pipelines generate --model-path=meta-llama/Llama-3.3-70B-Instruct \
    --quantization-encoding bfloat16 \
    --devices gpu:0,1,2,3 \
    --prompt="Design a self-sustaining colony on Neptune's moon Triton with a myth/science fusion name, three quantum tech breakthroughs, one ethical debate, a neon-lit cultural ritual, and a hidden flaw—presented in bullet points."
  ```

* Removed the `--huggingface-repo-id` PipelineConfig option and CLI argument in favor of `--model-path`.
* Consolidated `--model-path` and `--weight-path`. If valid `--weight-path`(s) are provided, they now override `--model-path`, which in turn handles both local and remote (Hugging Face) cases. If the weights can't be derived from the `--weight-path`(s), we fall back to the `--model-path`, which must be set explicitly by the user.
* Added the `--huggingface-revision` option, to allow selecting a non-default branch or a specific commit in a Hugging Face model repository.

---

## max benchmark

Runs comprehensive benchmark tests against a running model server to measure performance metrics including throughput, latency, and resource utilization. For a complete walkthrough, see the tutorial to [benchmark MAX on a GPU](/max/deploy/benchmark).

Before running this command, make sure the model server is running, via [`max serve`](/max/cli/serve).
For example, here's how to benchmark the `google/gemma-3-27b-it` model already running on localhost:

```sh
max benchmark \
  --model google/gemma-3-27b-it \
  --backend modular \
  --endpoint /v1/chat/completions \
  --num-prompts 50 \
  --dataset-name arxiv-summarization \
  --arxiv-summarization-input-len 12000 \
  --max-output-len 1200
```

When it's done, you'll see the results printed to the terminal. By default, it sends inference requests to `localhost:8000`, but you can change that with the `--host` and `--port` arguments.

If you want to save the results, add the `--save-result` option, which creates a JSON file in the local path with the following naming convention:

```bash
{backend}-{request_rate}qps-{model_name}-{timestamp}.json
```

You can instead specify the file name with `--result-filename` and change the directory with `--result-dir`.

Instead of passing all these benchmark options individually, you can pass a configuration file. See [Configuration file](#benchmark-configuration-file) below.

:::note
The `max benchmark` command is a convenient packaging of our open-source [`benchmark_serving.py`](https://github.com/modular/modular/tree/main/max/python/max/benchmark#benchmark-max) script and accepts all the same options.
:::

## Usage

```sh
max benchmark [OPTIONS]
```

## Options

This list of options is not exhaustive. For more information, run `max benchmark --help` or see the [benchmarking script source code](https://github.com/modular/modular/tree/main/max/python/max/benchmark).
* Backend configuration:
  * `--backend`: Choose from `modular` (MAX `v1/completions` endpoint), `modular-chat` (MAX `v1/chat/completions` endpoint), or `vllm` (vLLM)
  * `--model`: Hugging Face model ID or local path
* Load generation:
  * `--num-prompts`: Number of prompts to process (`int`, default: `500`)
  * `--request-rate`: Request rate in requests/second (`int`, default: `inf`)
  * `--seed`: The random seed used to sample the dataset (`int`, default: `0`)
* Serving options:
  * `--base-url`: Base URL of the API service
  * `--endpoint`: Specific API endpoint (`/v1/completions` or `/v1/chat/completions`)
  * `--tokenizer`: Hugging Face tokenizer to use (can be different from the model)
  * `--dataset-name`: (Required; default: `sharegpt`) Specifies which type of benchmark dataset to use. This determines the dataset class and processing logic. See [Datasets](#datasets) below.
  * `--dataset-path`: Path to a local dataset file that overrides the default dataset source for the specified `dataset-name`. The file format must match the expected format for the specified `dataset-name` (such as JSON for `axolotl`, JSONL for `obfuscated-conversations`, or plain text for `sonnet`).
* Additional options:
  * `--collect-gpu-stats`: Report GPU utilization and memory consumption for both NVIDIA and AMD GPUs. Only works when running `max benchmark` on the same instance as the server.
  * `--save-results`: Saves results to a local JSON file.
* LoRA benchmarking options. The benchmark script supports testing LoRA adapter performance for supported models and target modules:
  * `--num-loras`: Number of LoRA adapters to test. If > 0, test LoRA adapters will be generated.
  * `--lora-rank`: LoRA rank (the `r` parameter) for generated adapters. Controls the dimension of the low-rank decomposition.
  * `--lora-output-dir`: Directory to save generated LoRA adapters. Defaults to `/tmp/loras`.
  * `--lora-paths`: Paths to existing LoRA adapters to use instead of generating new ones.
  * `--lora-request-ratio`: Ratio of requests to send with LoRA adapters (0.0-1.0). For example, 0.5 means 50% of requests use LoRA.
  * `--max-num-loras`: Maximum number of LoRA adapters cached on the GPU. This should match the server configuration.
  * `--lora-target-modules`: List of module names to apply LoRA to when generating random test adapters (e.g., `q_proj`, `k_proj`, `v_proj`, `o_proj`). Only used when `--num-loras` > 0 and generating adapters (not when using existing `--lora-paths`).
* `--config-file`: Path to a YAML file containing key-value pairs for all your benchmark configurations, as a replacement for individual command-line options. See [Configuration file](#benchmark-configuration-file) below.

### Datasets

The `--dataset-name` option supports several dataset names/formats you can use for benchmarking:

* `arxiv-summarization` - Research-paper summarization dataset containing academic papers with abstracts for training summarization models, from Hugging Face Datasets.
* `axolotl` - Local dataset in Axolotl format with conversation segments labeled as human/assistant text, from Hugging Face Datasets.
* `code_debug` - Long-context code debugging dataset containing code with multiple-choice debugging questions for testing long-context understanding, from Hugging Face Datasets.
* `obfuscated-conversations` - Local dataset with obfuscated conversation data. You must pair this with the `--dataset-path` option to specify the local JSONL file.
* `random` - Synthetically generated random dataset that creates random token sequences with configurable input/output lengths and distributions.
* `sharegpt` - Conversational dataset containing human-AI conversations for chat model evaluation, from Hugging Face Datasets.
* `sonnet` - Poetry dataset using local text files containing poem lines, from Hugging Face Datasets.
* `vision-arena` - Vision-language benchmark dataset containing images with associated questions for multimodal model evaluation, from Hugging Face Datasets. You can override the default dataset source for any of these using the `--dataset-path` option (except for generated datasets like `random`), but you must always specify a `--dataset-name` so the tool knows how to process the dataset format. ### Configuration file {#benchmark-configuration-file} The `--config-file` option allows you to specify a YAML file containing all your benchmark configurations, as a replacement for individual command line options. Simply define all the configuration options (corresponding to the `max benchmark` command line options) in a YAML file, all nested under the `benchmark_config` key. :::caution In the YAML file, the properties **must use `snake_case` names** instead of using the hyphenated names from the command line options. For example, `--num-prompts` becomes `num_prompts`. ::: For instance, instead of specifying all configurations in the command line like this: ```sh max benchmark \ --model google/gemma-3-27b-it \ --backend modular \ --endpoint /v1/chat/completions \ --host localhost \ --port 8000 \ --num-prompts 50 \ --dataset-name arxiv-summarization \ --arxiv-summarization-input-len 12000 \ --max-output-len 1200 ``` Create this configuration file: ```yaml title="gemma-benchmark.yaml" benchmark_config: model: google/gemma-3-27b-it backend: modular endpoint: /v1/chat/completions host: localhost port: 8000 num_prompts: 50 dataset_name: arxiv-summarization arxiv_summarization_input_len: 12000 max_output_len: 1200 ``` And then run the benchmark by passing that file: ```sh max benchmark --config-file gemma-benchmark.yaml ``` For more config file examples, see our [benchmark configs on GitHub](https://github.com/modular/modular/tree/main/max/python/max/benchmark/configs). 
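The benchmark run ends by reporting latency metrics such as TTFT, TPOT, and ITL (defined under Output below). As a rough sketch of the arithmetic behind such metrics, here is a hypothetical helper, not the tool's actual implementation, that derives them from per-token arrival timestamps:

```python
# Rough sketch (not the benchmark tool's code) of how latency metrics
# like TTFT, TPOT, and ITL derive from per-token timestamps.
# `request_start` and `token_times` are hypothetical inputs: the wall-clock
# time the request was sent and the arrival time of each generated token.

def latency_metrics(request_start: float, token_times: list[float]) -> dict:
    ttft = token_times[0] - request_start  # time to first token
    gaps = [b - a for a, b in zip(token_times, token_times[1:])]
    itl = sum(gaps) / len(gaps) if gaps else 0.0  # avg inter-token latency
    # time per output token: decode time spread over tokens after the first
    n_decode = len(token_times) - 1
    tpot = (token_times[-1] - token_times[0]) / n_decode if n_decode else 0.0
    return {"ttft": ttft, "tpot": tpot, "itl": itl}

print(latency_metrics(0.0, [0.5, 0.6, 0.7, 0.8]))
```

Throughput metrics (requests and tokens per second) are then simply the corresponding counts divided by total wall-clock duration.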
For a walkthrough of setting up an endpoint and running a benchmark, see the [quickstart guide](/max/get-started). ## Output Here's an explanation of the most important metrics printed upon completion: * **Request throughput**: Number of complete requests processed per second * **Input token throughput**: Number of input tokens processed per second * **Output token throughput**: Number of tokens generated per second * **TTFT**: Time to first token—the time from request start to first token generation * **TPOT**: Time per output token—the average time taken to generate each output token * **ITL**: Inter-token latency—the average time between consecutive token or token-chunk generations If `--collect-gpu-stats` is set, you'll also see these: * **GPU utilization**: Percentage of time during which at least one GPU kernel is being executed * **Peak GPU memory used**: Peak memory usage during benchmark run --- ## max encode Converts input text into embeddings for semantic search, text similarity, and NLP applications. For example: ```bash max encode \ --model sentence-transformers/all-MiniLM-L6-v2 \ --prompt "Convert this text into embeddings" ``` ## Usage ```shell max encode [OPTIONS] ``` ## Options ### `--allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. ### `--cache-strategy ` The cache strategy to use. This defaults to model\_default, which selects the default strategy for the requested architecture. You can also force a specific strategy: continuous or paged.
**Options:**
KVCacheStrategy.MODEL\_DEFAULT | KVCacheStrategy.PAGED
### `--ce-delay-ms ` Duration of scheduler sleep prior to starting a prefill batch. Experimental for the TTS scheduler. ### `--chat-template ` Optional custom chat template to override the one shipped with the Hugging Face model config. If a path is provided, the file is read during config resolution and the content stored as a string. If None, the model's default chat template is used. ### `--config-file ` ### `--custom-architectures ` Custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name. Each module must expose an ARCHITECTURES list of architectures to register. ### `--data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--defer-resolve, --no-defer-resolve` Whether to defer resolving the pipeline config. ### `--device-graph-capture, --no-device-graph-capture` Enable device graph capture/replay for graph execution. ### `--device-memory-utilization ` The fraction of available device memory that the process should consume. This informs the KVCache workspace size: kv\_cache\_workspace = (total\_free\_memory \* device\_memory\_utilization) - model\_weights\_size. ### `--devices ` Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`). An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`). ### `--draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. 
### `--draft-config-file ` ### `--draft-data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--draft-devices ` Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`). An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`). ### `--draft-force-download, --no-draft-force-download` Whether to force download a given file if it's already present in the local cache. ### `--draft-huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-model-path ` The repository ID of a Hugging Face model to use. The `--model` option also works as an alias. ### `--draft-quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--draft-rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--draft-section-name ` ### `--draft-served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--draft-trust-remote-code, --no-draft-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--draft-use-subgraphs, --no-draft-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--draft-vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--draft-weight-path ` Optional path or URL of the model weights to use. ### `--enable-chunked-prefill, --no-enable-chunked-prefill` Enable chunked prefill to split context encoding requests into multiple chunks based on max\_batch\_input\_tokens. ### `--enable-echo, --no-enable-echo` Whether the model should be built with echo capabilities. ### `--enable-in-flight-batching, --no-enable-in-flight-batching` When enabled, prioritizes token generation by batching it with context encoding requests. ### `--enable-kvcache-swapping-to-host, --no-enable-kvcache-swapping-to-host` Whether to swap paged KVCache blocks to host memory when device blocks are evicted. ### `--enable-lora, --no-enable-lora` Enables LoRA on the server. ### `--enable-min-tokens, --no-enable-min-tokens` Whether to enable min\_tokens, which blocks the model from generating stopping tokens before the min\_tokens count is reached. ### `--enable-overlap-scheduler, --no-enable-overlap-scheduler` Whether to enable the overlap scheduler. This feature allows the scheduler to run alongside GPU execution. This helps improve GPU utilization. This is an experimental feature which may crash and burn. This feature will be enabled by default for some selected architectures. You can forcibly disable this by setting `--no-enable-overlap-scheduler --force`. 
### `--enable-penalties, --no-enable-penalties` Whether to apply frequency and presence penalties to the model’s output. ### `--enable-prefix-caching, --no-enable-prefix-caching` Whether to enable prefix caching for the paged KVCache. ### `--enable-prioritize-first-decode, --no-enable-prioritize-first-decode` When enabled, the scheduler always runs a TG batch immediately after a CE batch with the same requests. This may reduce time-to-first-chunk latency. Experimental for the TTS scheduler. ### `--enable-structured-output, --no-enable-structured-output` Enable structured generation/guided decoding for the server. This allows the user to pass a json schema in the response\_format field, which the LLM will adhere to. ### `--enable-variable-logits, --no-enable-variable-logits` Enable the sampling graph to accept a ragged tensor of different sequences as inputs, along with their associated logit\_offsets. This is needed to produce additional logits for echo and speculative decoding purposes. ### `--ep-size ` The expert parallelism size. Needs to be 1 (no expert parallelism) or the total number of GPUs across nodes. ### `--execute-empty-batches, --no-execute-empty-batches` Whether the scheduler should execute empty batches. ### `--force, --no-force` Skip validation of user provided flags against the architecture’s required arguments. ### `--force-download, --no-force-download` Whether to force download a given file if it’s already present in the local cache. ### `--gpu-profiling ` Whether to enable GPU profiling of the model.
**Options:**
GPUProfilingMode.OFF | GPUProfilingMode.ON | GPUProfilingMode.DETAILED
### `--host-kvcache-swap-space-gb ` The amount of host memory to use for the host KVCache in GiB. This space is only allocated when kvcache\_swapping\_to\_host is enabled. ### `--huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--kv-cache-format ` Override the default data type for the KV cache. Supported values: float32, bfloat16, float8\_e4m3fn. ### `--kv-cache-page-size ` The number of tokens in a single page in the paged KVCache. ### `--kvcache-ce-watermark ` Projected cache usage threshold for scheduling CE requests, considering current and incoming requests. CE is scheduled if either projected usage stays below this threshold or no active requests exist. Higher values can cause more preemptions. ### `--lora-paths ` List of statically defined LoRA paths. ### `--max-batch-input-tokens ` The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. ### `--max-batch-size ` Maximum batch size to execute with the model. When not specified (None), this value is determined dynamically. For server launches, set this higher based on server capacity. ### `--max-batch-total-tokens ` Ensures that the sum of the context length in a batch does not exceed max\_batch\_total\_tokens. If None, the sum is not limited. ### `--max-length ` Maximum sequence length of the model. ### `--max-lora-rank ` Maximum rank of all possible LoRAs. ### `--max-num-loras ` The maximum number of active LoRAs in a batch. This controls how many LoRA adapters can be active simultaneously during inference. Lower values reduce memory usage but limit concurrent adapter usage. ### `--max-num-steps ` The number of steps to run for multi-step scheduling. -1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models). 
### `--max-queue-size-tg ` Maximum number of requests in decode queue. By default, this is max\_batch\_size. ### `--min-batch-size-tg ` Soft floor on the decode batch size. If the TG batch size is larger, the scheduler continues TG batches; if it falls below, the scheduler prioritizes CE. This is not a strict minimum. By default, this is max\_queue\_size\_tg. Experimental for the TTS scheduler. ### `--model-path ` The repository ID of a Hugging Face model to use. The `--model` option also works as an alias. ### `--num-speculative-tokens ` The number of speculative tokens. ### `--num-warmups ` Number of warmup iterations to run before the final timed run.
**Default:**
`0`
### `--pipeline-role ` Whether the pipeline should serve a prefill role, a decode role, or both.
**Options:**
PipelineRole.PrefillAndDecode | PipelineRole.PrefillOnly | PipelineRole.DecodeOnly
### `--pool-embeddings, --no-pool-embeddings` Whether to pool embedding outputs. ### `--prompt ` The text prompt to use for further generation. ### `--quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--section-name ` ### `--served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--speculative-method ` The speculative decoding method to use.
**Options:**
SpeculativeMethod.STANDALONE | SpeculativeMethod.EAGLE | SpeculativeMethod.MTP
### `--trust-remote-code, --no-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--use-experimental-kernels ` Enables using experimental Mojo kernels with max serve. The kernels could be unstable or incorrect. ### `--use-legacy-module, --no-use-legacy-module` Whether to use the legacy Module architecture (default=True for backward compatibility). Set to False to use the new Module-based architecture when available. ### `--use-subgraphs, --no-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--use-vendor-blas ` Enables using vendor BLAS libraries (cublas/hipblas/etc) with max serve. Currently, this just replaces matmul calls. ### `--vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--weight-path ` Optional path or URL of the model weights to use. ### `--zmq-endpoint-base ` Prefix for ZMQ endpoints used for IPC. This ensures unique endpoints across MAX Serve instances on the same host. Example: lora\_request\_zmq\_endpoint = f"{zmq\_endpoint\_base}-lora\_request". --- ## max generate Generates output from a given model and prompt, without using an endpoint—primarily for debugging and testing purposes. For example: ```bash max generate \ --model google/gemma-3-12b-it \ --max-length 1024 \ --max-new-tokens 500 \ --top-k 40 \ --temperature 0.7 \ --seed 42 \ --prompt "Explain quantum computing" ``` :::note You can adjust parameters like `--max-batch-size` and `--max-length` depending on your system's available resources such as GPU memory. ::: For more information on how to use the `generate` command with vision models, see [Image to text](/max/inference/image-to-text). 
## Usage ```shell max generate [OPTIONS] ``` ## Options ### `--allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. ### `--cache-strategy ` The cache strategy to use. This defaults to model\_default, which selects the default strategy for the requested architecture. You can also force a specific strategy: continuous or paged.
**Options:**
KVCacheStrategy.MODEL\_DEFAULT | KVCacheStrategy.PAGED
### `--ce-delay-ms ` Duration of scheduler sleep prior to starting a prefill batch. Experimental for the TTS scheduler. ### `--chat-template ` Optional custom chat template to override the one shipped with the Hugging Face model config. If a path is provided, the file is read during config resolution and the content stored as a string. If None, the model's default chat template is used. ### `--config-file ` ### `--custom-architectures ` Custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name. Each module must expose an ARCHITECTURES list of architectures to register. ### `--data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--defer-resolve, --no-defer-resolve` Whether to defer resolving the pipeline config. ### `--detokenize, --no-detokenize` Whether to detokenize the output tokens into text. ### `--device-graph-capture, --no-device-graph-capture` Enable device graph capture/replay for graph execution. ### `--device-memory-utilization ` The fraction of available device memory that the process should consume. This informs the KVCache workspace size: kv\_cache\_workspace = (total\_free\_memory \* device\_memory\_utilization) - model\_weights\_size. ### `--devices ` Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`). An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`). 
### `--draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. ### `--draft-config-file ` ### `--draft-data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--draft-devices ` Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`). An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`). ### `--draft-force-download, --no-draft-force-download` Whether to force download a given file if it's already present in the local cache. ### `--draft-huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-model-path ` The repository ID of a Hugging Face model to use. The `--model` option also works as an alias. ### `--draft-quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--draft-rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--draft-section-name ` ### `--draft-served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--draft-trust-remote-code, --no-draft-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--draft-use-subgraphs, --no-draft-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--draft-vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--draft-weight-path ` Optional path or URL of the model weights to use. ### `--enable-chunked-prefill, --no-enable-chunked-prefill` Enable chunked prefill to split context encoding requests into multiple chunks based on max\_batch\_input\_tokens. ### `--enable-echo, --no-enable-echo` Whether the model should be built with echo capabilities. ### `--enable-in-flight-batching, --no-enable-in-flight-batching` When enabled, prioritizes token generation by batching it with context encoding requests. ### `--enable-kvcache-swapping-to-host, --no-enable-kvcache-swapping-to-host` Whether to swap paged KVCache blocks to host memory when device blocks are evicted. ### `--enable-lora, --no-enable-lora` Enables LoRA on the server. ### `--enable-min-tokens, --no-enable-min-tokens` Whether to enable min\_tokens, which blocks the model from generating stopping tokens before the min\_tokens count is reached. ### `--enable-overlap-scheduler, --no-enable-overlap-scheduler` Whether to enable the overlap scheduler. This feature allows the scheduler to run alongside GPU execution. This helps improve GPU utilization. This is an experimental feature which may crash and burn. This feature will be enabled by default for some selected architectures. You can forcibly disable this by setting `--no-enable-overlap-scheduler --force`. 
### `--enable-penalties, --no-enable-penalties` Whether to apply frequency and presence penalties to the model’s output. ### `--enable-prefix-caching, --no-enable-prefix-caching` Whether to enable prefix caching for the paged KVCache. ### `--enable-prioritize-first-decode, --no-enable-prioritize-first-decode` When enabled, the scheduler always runs a TG batch immediately after a CE batch with the same requests. This may reduce time-to-first-chunk latency. Experimental for the TTS scheduler. ### `--enable-structured-output, --no-enable-structured-output` Enable structured generation/guided decoding for the server. This allows the user to pass a json schema in the response\_format field, which the LLM will adhere to. ### `--enable-variable-logits, --no-enable-variable-logits` Enable the sampling graph to accept a ragged tensor of different sequences as inputs, along with their associated logit\_offsets. This is needed to produce additional logits for echo and speculative decoding purposes. ### `--ep-size ` The expert parallelism size. Needs to be 1 (no expert parallelism) or the total number of GPUs across nodes. ### `--execute-empty-batches, --no-execute-empty-batches` Whether the scheduler should execute empty batches. ### `--force, --no-force` Skip validation of user provided flags against the architecture’s required arguments. ### `--force-download, --no-force-download` Whether to force download a given file if it’s already present in the local cache. ### `--frequency-penalty ` The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text. ### `--gpu-profiling ` Whether to enable GPU profiling of the model.
**Options:**
GPUProfilingMode.OFF | GPUProfilingMode.ON | GPUProfilingMode.DETAILED
### `--host-kvcache-swap-space-gb ` The amount of host memory to use for the host KVCache in GiB. This space is only allocated when kvcache\_swapping\_to\_host is enabled. ### `--huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--ignore-eos` If True, the response will ignore the EOS token, and continue to generate until the max tokens or a stop string is hit. ### `--image_url ` Images to include along with prompt, specified as URLs. The images are ignored if the model does not support image inputs. ### `--kv-cache-format ` Override the default data type for the KV cache. Supported values: float32, bfloat16, float8\_e4m3fn. ### `--kv-cache-page-size ` The number of tokens in a single page in the paged KVCache. ### `--kvcache-ce-watermark ` Projected cache usage threshold for scheduling CE requests, considering current and incoming requests. CE is scheduled if either projected usage stays below this threshold or no active requests exist. Higher values can cause more preemptions. ### `--lora-paths ` List of statically defined LoRA paths. ### `--max-batch-input-tokens ` The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. ### `--max-batch-size ` Maximum batch size to execute with the model. When not specified (None), this value is determined dynamically. For server launches, set this higher based on server capacity. ### `--max-batch-total-tokens ` Ensures that the sum of the context length in a batch does not exceed max\_batch\_total\_tokens. If None, the sum is not limited. ### `--max-length ` Maximum sequence length of the model. ### `--max-lora-rank ` Maximum rank of all possible LoRAs. ### `--max-new-tokens ` Maximum number of new tokens to generate during a single inference pass of the model. 
### `--max-num-loras ` The maximum number of active LoRAs in a batch. This controls how many LoRA adapters can be active simultaneously during inference. Lower values reduce memory usage but limit concurrent adapter usage. ### `--max-num-steps ` The number of steps to run for multi-step scheduling. -1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models). ### `--max-queue-size-tg ` Maximum number of requests in decode queue. By default, this is max\_batch\_size. ### `--min-batch-size-tg ` Soft floor on the decode batch size. If the TG batch size is larger, the scheduler continues TG batches; if it falls below, the scheduler prioritizes CE. This is not a strict minimum. By default, this is max\_queue\_size\_tg. Experimental for the TTS scheduler. ### `--min-new-tokens ` Minimum number of tokens to generate in the response. ### `--min-p ` Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in \[0, 1]. Set to 0 to disable this. ### `--model-path ` The repository ID of a Hugging Face model to use. The `--model` option also works as an alias. ### `--num-speculative-tokens ` The number of speculative tokens. ### `--num-warmups ` Number of warmup iterations to run before the final timed run.
**Default:**
`0`
### `--pipeline-role ` Whether the pipeline should serve a prefill role, a decode role, or both.
**Options:**
PipelineRole.PrefillAndDecode | PipelineRole.PrefillOnly | PipelineRole.DecodeOnly
### `--pool-embeddings, --no-pool-embeddings` Whether to pool embedding outputs. ### `--presence-penalty ` The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once. ### `--prompt ` The text prompt to use for further generation. ### `--quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--repetition-penalty ` The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in the generated text at least once. ### `--rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--section-name ` ### `--seed ` Seed for the random number generator. ### `--served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--speculative-method ` The speculative decoding method to use.
**Options:**
SpeculativeMethod.STANDALONE | SpeculativeMethod.EAGLE | SpeculativeMethod.MTP
### `--stop ` A list of detokenized sequences that can be used as stop criteria when generating a new sequence. Can be specified multiple times. ### `--stop-token-ids ` A list of token ids that are used as stopping criteria when generating a new sequence. Comma-separated integers. ### `--temperature ` Controls the randomness of the model's output; higher values produce more diverse responses. ### `--top-k ` Limits the sampling to the K most probable tokens. This defaults to 255. For greedy sampling, set to 1. ### `--top-p ` Only use the tokens whose cumulative probability is within the top\_p threshold. This applies to the top\_k tokens. ### `--trust-remote-code, --no-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--use-experimental-kernels ` Enables using experimental Mojo kernels with max serve. The kernels could be unstable or incorrect. ### `--use-legacy-module, --no-use-legacy-module` Whether to use the legacy Module architecture (default=True for backward compatibility). Set to False to use the new Module-based architecture when available. ### `--use-subgraphs, --no-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--use-vendor-blas ` Enables using vendor BLAS libraries (cublas/hipblas/etc) with max serve. Currently, this just replaces matmul calls. ### `--vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--weight-path ` Optional path or URL of the model weights to use. ### `--zmq-endpoint-base ` Prefix for ZMQ endpoints used for IPC. This ensures unique endpoints across MAX Serve instances on the same host. Example: lora\_request\_zmq\_endpoint = f"{zmq\_endpoint\_base}-lora\_request". 
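The sampling options above (`--temperature`, `--top-k`, `--top-p`, `--min-p`) compose in the usual way for LLM samplers. The following is an illustrative sketch of that standard composition, not MAX's actual sampler; the function name and dict-based "vocabulary" are invented for clarity:

```python
import math

# Illustrative sketch of how temperature, top-k, top-p, and min-p
# typically interact in an LLM sampler (not MAX's implementation).
def filter_logits(logits: dict[str, float], temperature: float = 1.0,
                  top_k: int = 255, top_p: float = 1.0, min_p: float = 0.0):
    # Temperature scales logits before softmax; higher => flatter distribution.
    scaled = {t: l / temperature for t, l in logits.items()}
    z = sum(math.exp(v) for v in scaled.values())
    probs = {t: math.exp(v) / z for t, v in scaled.items()}
    ranked = sorted(probs.items(), key=lambda kv: -kv[1])
    ranked = ranked[:top_k]                 # keep the K most probable tokens
    kept, cum = [], 0.0
    for tok, p in ranked:                   # nucleus (top-p) cutoff
        kept.append((tok, p))
        cum += p
        if cum >= top_p:
            break
    p_max = kept[0][1]                      # min-p: threshold relative to best
    kept = [(t, p) for t, p in kept if p >= min_p * p_max]
    total = sum(p for _, p in kept)
    return {t: p / total for t, p in kept}  # renormalize survivors

# Greedy sampling: top_k=1 keeps only the most likely token.
print(filter_logits({"a": 2.0, "b": 1.0, "c": 0.1}, top_k=1))
```

This is why `--top-k 1` corresponds to greedy decoding: after filtering, all probability mass sits on the single most likely token.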
--- ## max (CLI) The `max` command line tool allows you to create an OpenAI-compatible endpoint with a simple `max serve` command. It also includes a command to benchmark your endpoint using built-in datasets or your own dataset. To install the `max` CLI, install the `modular` package as shown in the [install guide](/max/packages#install). ## Usage ```shell max [OPTIONS] COMMAND [ARGS]... ``` ## Options ### `--log-level ` Set logging level explicitly (ignored if `--verbose` or `--quiet` is used).
**Options:**
DEBUG | INFO | WARNING | ERROR
### `--version` Show the MAX version and exit. ## Commands [`benchmark`](benchmark.md): Run benchmark tests on a serving model. [`encode`](encode.md): Encode text input into model embeddings. [`generate`](generate.md): Generate text using the specified model. [`list`](list.md): List available pipeline configurations and… [`serve`](serve.md): Start a model serving endpoint for inference. [`warm-cache`](warm-cache.md): Load and compile the model to prepare caches. --- ## max list List available pipeline configurations and models. This command displays information about all registered pipelines and their configurations. Output can be formatted as human-readable text or JSON. ## Usage ```shell max list [OPTIONS] ``` ## Options ### `--json` Print the list of pipelines options in JSON format.
**Default:**
`False`
--- ## max serve Launches a model server with an OpenAI-compatible endpoint. Just specify the model as a Hugging Face model ID or a local path. For example: ```bash max serve \ --model google/gemma-3-12b-it \ --devices gpu:0 \ --max-batch-size 8 \ --device-memory-utilization 0.9 ``` For details about the endpoint APIs provided by the server, see [the MAX REST API reference](/max/api/serve). The `max` CLI also supports loading custom model architectures through the `--custom-architectures` flag. This allows you to extend MAX’s capabilities with your own model implementations: ```bash max serve \ --model google/gemma-3-12b-it \ --custom-architectures path/to/module1:module1 \ --custom-architectures path/to/module2:module2 ``` :::note Custom architectures The `--custom-architectures` flag allows you to load custom pipeline architectures from your own Python modules. You can set the `ARCHITECTURES` variable containing the architecture definitions. Each entry in `--custom-architectures` can be specified in two formats: * A raw module name; for example: `my_module`. * An import path followed by a colon and the module name; for example: `folder/path/to/import:my_module`. The `ARCHITECTURES` variable in your module should be a list of implementations that conform to the [SupportedArchitecture](/max/api/python/pipelines/registry#max.pipelines.lib.registry.SupportedArchitecture) interface. These will be registered with the MAX pipeline registry automatically. ::: :::note Quantization encoding When using GGUF models, quantization encoding formats are automatically detected. If no `--quantization-encoding` is specified, MAX Serve automatically detects and uses the first encoding option from the repository. If quantization encoding is provided, it must align with the available encoding options in the repository. If the repository contains multiple quantization formats, specify which encoding type you want to use with the `--quantization-encoding` parameter. 
::: ## Usage ```shell max serve [OPTIONS] ``` ## Options ### `--allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. ### `--cache-strategy ` The cache strategy to use. This defaults to model\_default, which selects the default strategy for the requested architecture. You can also force a specific strategy: continuous or paged.
**Options:**
KVCacheStrategy.MODEL\_DEFAULT | KVCacheStrategy.PAGED
### `--ce-delay-ms ` Duration of scheduler sleep prior to starting a prefill batch. Experimental for the TTS scheduler. ### `--chat-template ` Optional custom chat template to override the one shipped with the Hugging Face model config. If a path is provided, the file is read during config resolution and the content stored as a string. If None, the model's default chat template is used. ### `--config-file ` ### `--custom-architectures ` Custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name. Each module must expose an ARCHITECTURES list of architectures to register. ### `--data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--defer-resolve, --no-defer-resolve` Whether to defer resolving the pipeline config. ### `--device-graph-capture, --no-device-graph-capture` Enable device graph capture/replay for graph execution. ### `--device-memory-utilization ` The fraction of available device memory that the process should consume. This informs the KVCache workspace size: kv\_cache\_workspace = (total\_free\_memory \* device\_memory\_utilization) - model\_weights\_size. ### `--devices ` Whether to run the model on CPU (--devices=cpu), GPU (--devices=gpu) or a list of GPUs (--devices=gpu:0,1) etc. An ID value can be provided optionally to indicate the device ID to target. If not provided, the model will run on the first available GPU (--devices=gpu), or CPU if no GPUs are available (--devices=cpu). ### `--draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. 
### `--draft-config-file ` ### `--draft-data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--draft-devices ` Whether to run the model on CPU (--devices=cpu), GPU (--devices=gpu) or a list of GPUs (--devices=gpu:0,1) etc. An ID value can be provided optionally to indicate the device ID to target. If not provided, the model will run on the first available GPU (--devices=gpu), or CPU if no GPUs are available (--devices=cpu). ### `--draft-force-download, --no-draft-force-download` Whether to force download a given file if it's already present in the local cache. ### `--draft-huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-model-path ` The repository ID of a Hugging Face model to use. The --model option also works as an alias. ### `--draft-quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--draft-rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--draft-section-name ` ### `--draft-served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--draft-trust-remote-code, --no-draft-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--draft-use-subgraphs, --no-draft-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--draft-vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--draft-weight-path ` Optional path or url of the model weights to use. ### `--enable-chunked-prefill, --no-enable-chunked-prefill` Enable chunked prefill to split context encoding requests into multiple chunks based on max\_batch\_input\_tokens. ### `--enable-echo, --no-enable-echo` Whether the model should be built with echo capabilities. ### `--enable-in-flight-batching, --no-enable-in-flight-batching` When enabled, prioritizes token generation by batching it with context encoding requests. ### `--enable-kvcache-swapping-to-host, --no-enable-kvcache-swapping-to-host` Whether to swap paged KVCache blocks to host memory when device blocks are evicted. ### `--enable-lora, --no-enable-lora` Enables LoRA on the server. ### `--enable-min-tokens, --no-enable-min-tokens` Whether to enable min\_tokens, which blocks the model from generating stopping tokens before the min\_tokens count is reached. ### `--enable-overlap-scheduler, --no-enable-overlap-scheduler` Whether to enable the overlap scheduler. This feature allows the scheduler to run alongside GPU execution. This helps improve GPU utilization. This is an experimental feature which may crash and burn. This feature will be enabled by default for some selected architectures. You can forcibly disable this by setting --no-enable-overlap-scheduler --force. 
### `--enable-penalties, --no-enable-penalties` Whether to apply frequency and presence penalties to the model’s output. ### `--enable-prefix-caching, --no-enable-prefix-caching` Whether to enable prefix caching for the paged KVCache. ### `--enable-prioritize-first-decode, --no-enable-prioritize-first-decode` When enabled, the scheduler always runs a TG batch immediately after a CE batch with the same requests. This may reduce time-to-first-chunk latency. Experimental for the TTS scheduler. ### `--enable-structured-output, --no-enable-structured-output` Enable structured generation/guided decoding for the server. This allows the user to pass a json schema in the response\_format field, which the LLM will adhere to. ### `--enable-variable-logits, --no-enable-variable-logits` Enable the sampling graph to accept a ragged tensor of different sequences as inputs, along with their associated logit\_offsets. This is needed to produce additional logits for echo and speculative decoding purposes. ### `--ep-size ` The expert parallelism size. Needs to be 1 (no expert parallelism) or the total number of GPUs across nodes. ### `--execute-empty-batches, --no-execute-empty-batches` Whether the scheduler should execute empty batches. ### `--force, --no-force` Skip validation of user provided flags against the architecture’s required arguments. ### `--force-download, --no-force-download` Whether to force download a given file if it’s already present in the local cache. ### `--gpu-profiling ` Whether to enable GPU profiling of the model.
**Options:**
GPUProfilingMode.OFF | GPUProfilingMode.ON | GPUProfilingMode.DETAILED
### `--headless` Run only the dispatcher service and model worker without the API server.
**Default:**
`False`
### `--host-kvcache-swap-space-gb ` The amount of host memory to use for the host KVCache in GiB. This space is only allocated when kvcache\_swapping\_to\_host is enabled. ### `--huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--kv-cache-format ` Override the default data type for the KV cache. Supported values: float32, bfloat16, float8\_e4m3fn. ### `--kv-cache-page-size ` The number of tokens in a single page in the paged KVCache. ### `--kvcache-ce-watermark ` Projected cache usage threshold for scheduling CE requests, considering current and incoming requests. CE is scheduled if either projected usage stays below this threshold or no active requests exist. Higher values can cause more preemptions. ### `--log-prefix ` Optional prefix to add to all log messages for this server instance. ### `--lora-paths ` List of statically defined LoRA paths. ### `--max-batch-input-tokens ` The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. ### `--max-batch-size ` Maximum batch size to execute with the model. When not specified (None), this value is determined dynamically. For server launches, set this higher based on server capacity. ### `--max-batch-total-tokens ` Ensures that the sum of the context length in a batch does not exceed max\_batch\_total\_tokens. If None, the sum is not limited. ### `--max-length ` Maximum sequence length of the model. ### `--max-lora-rank ` Maximum rank of all possible LoRAs. ### `--max-num-loras ` The maximum number of active LoRAs in a batch. This controls how many LoRA adapters can be active simultaneously during inference. Lower values reduce memory usage but limit concurrent adapter usage. ### `--max-num-steps ` The number of steps to run for multi-step scheduling. 
-1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models). ### `--max-queue-size-tg ` Maximum number of requests in decode queue. By default, this is max\_batch\_size. ### `--min-batch-size-tg ` Soft floor on the decode batch size. If the TG batch size is larger, the scheduler continues TG batches; if it falls below, the scheduler prioritizes CE. This is not a strict minimum. By default, this is max\_queue\_size\_tg. Experimental for the TTS scheduler. ### `--model-path ` The repository ID of a Hugging Face model to use. The --model option also works as an alias. ### `--num-speculative-tokens ` The number of speculative tokens. ### `--pipeline-role ` Whether the pipeline should serve a prefill role, a decode role, or both.
**Options:**
PipelineRole.PrefillAndDecode | PipelineRole.PrefillOnly | PipelineRole.DecodeOnly
### `--pool-embeddings, --no-pool-embeddings` Whether to pool embedding outputs. ### `--port ` Port to run the server on. ### `--pretty-print-config` Pretty-print the entire config. ### `--profile-serve` Whether to enable pyinstrument profiling on the serving endpoint.
**Default:**
`False`
### `--quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--section-name ` ### `--served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--sim-failure ` Simulate fake-perf with failure percentage ### `--speculative-method ` The speculative decoding method to use.
**Options:**
SpeculativeMethod.STANDALONE | SpeculativeMethod.EAGLE | SpeculativeMethod.MTP
### `--task ` The task to run. ### `--task-arg ` Task-specific arguments to pass to the underlying model (can be used multiple times). ### `--trust-remote-code, --no-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--use-experimental-kernels ` Enables using experimental mojo kernels with max serve. The kernels could be unstable or incorrect. ### `--use-legacy-module, --no-use-legacy-module` Whether to use the legacy Module architecture (default=True for backward compatibility). Set to False to use the new Module-based architecture when available. ### `--use-subgraphs, --no-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--use-vendor-blas ` Enables using vendor BLAS libraries (cublas/hipblas/etc) with max serve. Currently, this just replaces matmul calls. ### `--vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--weight-path ` Optional path or url of the model weights to use. ### `--zmq-endpoint-base ` Prefix for ZMQ endpoints used for IPC. This ensures unique endpoints across MAX Serve instances on the same host. Example: lora\_request\_zmq\_endpoint = f"{zmq\_endpoint\_base}-lora\_request". --- ## max warm-cache Preloads and compiles the model to optimize initialization time by: * Pre-compiling models before deployment * Warming up the Hugging Face cache This command is useful to run before serving a model. For example: ```bash max warm-cache \ --model google/gemma-3-12b-it ``` :::note The Modular Executable Format (MEF) is platform independent, but the serialized cache (MEF files) produced during compilation is platform-dependent. This is because: * Platform-dependent optimizations happen during compilation. * Fallback operations assume a particular runtime environment. 
Weight transformations and hashing during MEF caching can impact performance. While efforts to improve this through weight externalization are ongoing, compiled MEF files remain platform-specific and are not generally portable. ::: ## Usage ```shell max warm-cache [OPTIONS] ``` ## Options ### `--allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. ### `--cache-strategy ` The cache strategy to use. This defaults to model\_default, which selects the default strategy for the requested architecture. You can also force a specific strategy: continuous or paged.
**Options:**
KVCacheStrategy.MODEL\_DEFAULT | KVCacheStrategy.PAGED
### `--ce-delay-ms ` Duration of scheduler sleep prior to starting a prefill batch. Experimental for the TTS scheduler. ### `--chat-template ` Optional custom chat template to override the one shipped with the Hugging Face model config. If a path is provided, the file is read during config resolution and the content stored as a string. If None, the model's default chat template is used. ### `--config-file ` ### `--custom-architectures ` Custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name. Each module must expose an ARCHITECTURES list of architectures to register. ### `--data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--defer-resolve, --no-defer-resolve` Whether to defer resolving the pipeline config. ### `--device-graph-capture, --no-device-graph-capture` Enable device graph capture/replay for graph execution. ### `--device-memory-utilization ` The fraction of available device memory that the process should consume. This informs the KVCache workspace size: kv\_cache\_workspace = (total\_free\_memory \* device\_memory\_utilization) - model\_weights\_size. ### `--devices ` Whether to run the model on CPU (--devices=cpu), GPU (--devices=gpu) or a list of GPUs (--devices=gpu:0,1) etc. An ID value can be provided optionally to indicate the device ID to target. If not provided, the model will run on the first available GPU (--devices=gpu), or CPU if no GPUs are available (--devices=cpu). ### `--draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast` Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models. 
### `--draft-config-file ` ### `--draft-data-parallel-degree ` Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type. ### `--draft-devices ` Whether to run the model on CPU (--devices=cpu), GPU (--devices=gpu) or a list of GPUs (--devices=gpu:0,1) etc. An ID value can be provided optionally to indicate the device ID to target. If not provided, the model will run on the first available GPU (--devices=gpu), or CPU if no GPUs are available (--devices=cpu). ### `--draft-force-download, --no-draft-force-download` Whether to force download a given file if it's already present in the local cache. ### `--draft-huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--draft-model-path ` The repository ID of a Hugging Face model to use. The --model option also works as an alias. ### `--draft-quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--draft-rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--draft-section-name ` ### `--draft-served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--draft-trust-remote-code, --no-draft-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--draft-use-subgraphs, --no-draft-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--draft-vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--draft-weight-path ` Optional path or url of the model weights to use. ### `--enable-chunked-prefill, --no-enable-chunked-prefill` Enable chunked prefill to split context encoding requests into multiple chunks based on max\_batch\_input\_tokens. ### `--enable-echo, --no-enable-echo` Whether the model should be built with echo capabilities. ### `--enable-in-flight-batching, --no-enable-in-flight-batching` When enabled, prioritizes token generation by batching it with context encoding requests. ### `--enable-kvcache-swapping-to-host, --no-enable-kvcache-swapping-to-host` Whether to swap paged KVCache blocks to host memory when device blocks are evicted. ### `--enable-lora, --no-enable-lora` Enables LoRA on the server. ### `--enable-min-tokens, --no-enable-min-tokens` Whether to enable min\_tokens, which blocks the model from generating stopping tokens before the min\_tokens count is reached. ### `--enable-overlap-scheduler, --no-enable-overlap-scheduler` Whether to enable the overlap scheduler. This feature allows the scheduler to run alongside GPU execution. This helps improve GPU utilization. This is an experimental feature which may crash and burn. This feature will be enabled by default for some selected architectures. You can forcibly disable this by setting --no-enable-overlap-scheduler --force. 
### `--enable-penalties, --no-enable-penalties` Whether to apply frequency and presence penalties to the model’s output. ### `--enable-prefix-caching, --no-enable-prefix-caching` Whether to enable prefix caching for the paged KVCache. ### `--enable-prioritize-first-decode, --no-enable-prioritize-first-decode` When enabled, the scheduler always runs a TG batch immediately after a CE batch with the same requests. This may reduce time-to-first-chunk latency. Experimental for the TTS scheduler. ### `--enable-structured-output, --no-enable-structured-output` Enable structured generation/guided decoding for the server. This allows the user to pass a json schema in the response\_format field, which the LLM will adhere to. ### `--enable-variable-logits, --no-enable-variable-logits` Enable the sampling graph to accept a ragged tensor of different sequences as inputs, along with their associated logit\_offsets. This is needed to produce additional logits for echo and speculative decoding purposes. ### `--ep-size ` The expert parallelism size. Needs to be 1 (no expert parallelism) or the total number of GPUs across nodes. ### `--execute-empty-batches, --no-execute-empty-batches` Whether the scheduler should execute empty batches. ### `--force, --no-force` Skip validation of user provided flags against the architecture’s required arguments. ### `--force-download, --no-force-download` Whether to force download a given file if it’s already present in the local cache. ### `--gpu-profiling ` Whether to enable GPU profiling of the model.
**Options:**
GPUProfilingMode.OFF | GPUProfilingMode.ON | GPUProfilingMode.DETAILED
### `--host-kvcache-swap-space-gb ` The amount of host memory to use for the host KVCache in GiB. This space is only allocated when kvcache\_swapping\_to\_host is enabled. ### `--huggingface-model-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--huggingface-weight-revision ` Branch or Git revision of Hugging Face model repository to use. ### `--kv-cache-format ` Override the default data type for the KV cache. Supported values: float32, bfloat16, float8\_e4m3fn. ### `--kv-cache-page-size ` The number of tokens in a single page in the paged KVCache. ### `--kvcache-ce-watermark ` Projected cache usage threshold for scheduling CE requests, considering current and incoming requests. CE is scheduled if either projected usage stays below this threshold or no active requests exist. Higher values can cause more preemptions. ### `--lora-paths ` List of statically defined LoRA paths. ### `--max-batch-input-tokens ` The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. ### `--max-batch-size ` Maximum batch size to execute with the model. When not specified (None), this value is determined dynamically. For server launches, set this higher based on server capacity. ### `--max-batch-total-tokens ` Ensures that the sum of the context length in a batch does not exceed max\_batch\_total\_tokens. If None, the sum is not limited. ### `--max-length ` Maximum sequence length of the model. ### `--max-lora-rank ` Maximum rank of all possible LoRAs. ### `--max-num-loras ` The maximum number of active LoRAs in a batch. This controls how many LoRA adapters can be active simultaneously during inference. Lower values reduce memory usage but limit concurrent adapter usage. ### `--max-num-steps ` The number of steps to run for multi-step scheduling. -1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models). 
### `--max-queue-size-tg ` Maximum number of requests in decode queue. By default, this is max\_batch\_size. ### `--min-batch-size-tg ` Soft floor on the decode batch size. If the TG batch size is larger, the scheduler continues TG batches; if it falls below, the scheduler prioritizes CE. This is not a strict minimum. By default, this is max\_queue\_size\_tg. Experimental for the TTS scheduler. ### `--model-path ` The repository ID of a Hugging Face model to use. The --model option also works as an alias. ### `--num-speculative-tokens ` The number of speculative tokens. ### `--pipeline-role ` Whether the pipeline should serve a prefill role, a decode role, or both.
**Options:**
PipelineRole.PrefillAndDecode | PipelineRole.PrefillOnly | PipelineRole.DecodeOnly
### `--pool-embeddings, --no-pool-embeddings` Whether to pool embedding outputs. ### `--quantization-encoding ` Weight encoding type.
**Options:**
SupportedEncoding.float32 | SupportedEncoding.bfloat16 | SupportedEncoding.q4\_k | SupportedEncoding.q4\_0 | SupportedEncoding.q6\_k | SupportedEncoding.float8\_e4m3fn | SupportedEncoding.float4\_e2m1fnx2 | SupportedEncoding.gptq
### `--rope-type ` Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
**Options:**
RopeType.none | RopeType.normal | RopeType.neox | RopeType.longrope | RopeType.yarn
### `--section-name ` ### `--served-model-name ` Optional override for client-facing model name. Defaults to model\_path. ### `--speculative-method ` The speculative decoding method to use.
**Options:**
SpeculativeMethod.STANDALONE | SpeculativeMethod.EAGLE | SpeculativeMethod.MTP
### `--target ` Target API and architecture to compile for (e.g., cuda, cuda:sm\_90, hip:gfx942). When specified, uses virtual devices for compilation without requiring physical hardware. ### `--trust-remote-code, --no-trust-remote-code` Whether or not to allow for custom modelling files on Hugging Face. ### `--use-experimental-kernels ` Enables using experimental mojo kernels with max serve. The kernels could be unstable or incorrect. ### `--use-legacy-module, --no-use-legacy-module` Whether to use the legacy Module architecture (default=True for backward compatibility). Set to False to use the new Module-based architecture when available. ### `--use-subgraphs, --no-use-subgraphs` Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true. ### `--use-vendor-blas ` Enables using vendor BLAS libraries (cublas/hipblas/etc) with max serve. Currently, this just replaces matmul calls. ### `--vision-config-overrides ` Model-specific vision configuration overrides. For example, for InternVL: {"max\_dynamic\_patch": 24}. ### `--weight-path ` Optional path or url of the model weights to use. ### `--zmq-endpoint-base ` Prefix for ZMQ endpoints used for IPC. This ensures unique endpoints across MAX Serve instances on the same host. Example: lora\_request\_zmq\_endpoint = f"{zmq\_endpoint\_base}-lora\_request". --- ## Using AI coding assistants This page describes how you can accelerate your MAX and Mojo development with popular AI coding assistants (Cursor, Claude Code, GitHub Copilot, Windsurf, and others) using the following resources: - **Documentation files** (`llms.txt`): Raw text files that AI tools use for website navigation hints and direct access to up-to-date API documentation. - **Project context files** (`CLAUDE.md` and `AGENTS.md`): Markdown files in your project that provide AI assistants with build commands, coding conventions, and project structure. 
- **Cursor project rules** (`.cursor/rules/`): Modular-specific rules that activate based on the files you're editing, providing consistent guidance for MAX and Mojo development. ## Giving our docs to your AI assistant To improve code quality and accuracy from your AI assistant, add Modular's documentation to your AI tool's context. We provide different `llms.txt` files that are optimized for different use cases: - [`llms.txt`](/llms.txt) for a list of links with brief descriptions - [`llms-full.txt`](/llms-full.txt) for comprehensive documentation (requires large context window) - [`llms-python.txt`](/llms-python.txt) for MAX Python APIs, pipelines, or graph construction - [`llms-mojo.txt`](/llms-mojo.txt) for Mojo code, GPU kernels, or MAX library APIs If you're in Cursor, use the `@` syntax in the chat window to reference documentation directly. For example, to insert the MAX Python API documentation, add this to your chat: ```text @docs.modular.com/llms-python.txt ``` :::tip Cursor tip To make the docs a permanent knowledge resource, type `@docs.new` in the chat and hit Enter (select **Add new doc**), then paste the `llms.txt` URL such as `docs.modular.com/llms-python.txt`. Edit the name and click **Confirm**. Read more about the [Cursor @Docs command](https://cursor.com/docs/context/mentions#docs). ::: For all other tools (Claude Code, Copilot, Windsurf, and others), just tell the AI to fetch the documentation: ```text Read https://docs.modular.com/llms-python.txt for MAX Python API documentation ``` For more details, see the section below about [our `llms.txt` files](#modular-llmstxt-files). 
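If you script this setup (for example, to refresh a local context file alongside your project), the file-to-URL mapping above is easy to encode. A small helper, purely illustrative (`llms_url` and the dictionary keys are not part of any Modular tooling; only the file names and base URL come from the list above):

```python
# Map each use case to its llms.txt variant, per the list above.
LLMS_FILES = {
    "index": "llms.txt",          # links with brief descriptions
    "full": "llms-full.txt",      # comprehensive documentation
    "python": "llms-python.txt",  # MAX Python APIs
    "mojo": "llms-mojo.txt",      # Mojo and MAX library APIs
}

def llms_url(use_case: str) -> str:
    """Return the docs.modular.com URL for the given use case."""
    return f"https://docs.modular.com/{LLMS_FILES[use_case]}"
```

You could then fetch the chosen file with `curl` or `urllib` and save it wherever your AI tool reads local context.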
## Working in the Modular repository When you clone the [Modular repo](https://github.com/modular/modular), your AI tool will automatically discover project context files (either `CLAUDE.md` or `AGENTS.md`) that provide: - Build and test commands - Project architecture overview - Coding conventions and commit message guidelines - Common development workflows Different tools look for different files (either `CLAUDE.md` or `AGENTS.md`, which are identical)—for compatibility details, see the section below about [Modular context files](#modular-context-files). **Cursor users**: The repository also includes project rules in `.cursor/rules/` that activate based on the files you're editing. These rules provide file-specific guidance for Mojo and MAX development. See [Modular Cursor rules](#modular-cursor-rules) for the full list. ## Starting a new project For your own MAX and Mojo projects, our [`CLAUDE.md`](https://github.com/modular/modular/blob/main/CLAUDE.md) and [`AGENTS.md`](https://github.com/modular/modular/blob/main/AGENTS.md) files are probably not useful in their entirety because they're specific to the Modular repository. However, you might find some parts useful, such as the section about [Critical development notes](https://github.com/modular/modular/blob/main/CLAUDE.md#critical-development-notes), which you can copy to your own `CLAUDE.md` or `AGENTS.md` file. **Cursor users**: You also might want to copy the [Modular Cursor rules](#modular-cursor-rules) below to your `.cursor/rules/` directory. ## Resources for AI assistants Here are all the resources we provide for AI assistants. 
### Modular `llms.txt` files Our documentation supports the [llms.txt](https://llmstxt.org/) proposed standard, providing LLM-friendly documentation files: | File | Description | Best for | |------|-------------|----------| | llms.txt | Index of links with brief descriptions | Tools that can follow links | | llms-full.txt | All docs in a single file (huge) | Comprehensive context | | llms-python.txt | MAX Python APIs | Python-based MAX development | | llms-mojo.txt | Mojo stdlib, MAX AI Kernels, MAX library | Mojo and kernel development | ### Modular Cursor rules Copy these rules to your `.cursor/rules/` directory for consistent AI assistance with Mojo and MAX development: - **[`general_behavior_rules.mdc`](https://github.com/modular/modular/blob/main/.cursor/rules/general_behavior_rules.mdc)**: General rules for code creation. Emphasizes simplicity, thorough investigation, using existing solutions, descriptive naming, robust error handling, and documentation. - **[`git.mdc`](https://github.com/modular/modular/blob/main/.cursor/rules/git.mdc)**: Best practices for Git including code organization, commit strategies, branching models, and collaborative workflows. - **[`max_development.mdc`](https://github.com/modular/modular/blob/main/.cursor/rules/max_development.mdc)**: Guidelines for MAX development covering project context, task execution, code quality, testing practices, and documentation. - **[`mojo.mdc`](https://github.com/modular/modular/blob/main/.cursor/rules/mojo.mdc)**: Mojo coding standards, performance optimizations, and best practices for GPU-accelerated code. 
### Modular context files Project context files (`CLAUDE.md` and `AGENTS.md`) provide AI coding assistants with essential project information, such as: - Build and test commands - Project architecture overview - Coding conventions and commit message guidelines - Common development workflows Although `CLAUDE.md` is primarily used for Claude Code, some other AI tools also read it, but perhaps not by default when working in the repository—you might need to explicitly tell your AI to read it. The [`AGENTS.md` filename](https://agents.md/) is a more generic alternative that other coding assistants have adopted. To ensure compatibility with other tools, we provide both `CLAUDE.md` and `AGENTS.md` files—our [`AGENTS.md`](https://github.com/modular/modular/blob/main/AGENTS.md) file is just a symlink to the [`CLAUDE.md`](https://github.com/modular/modular/blob/main/CLAUDE.md) file. --- ## MAX container import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; import Requirements from '@site/src/components/Requirements'; import { requirementsNoMacWithGPU } from '@site/docs/max/requirements'; The MAX container is our official Docker container that simplifies the process of deploying a GenAI model with an OpenAI-compatible endpoint. The container includes the latest version of MAX and integrates with orchestration tools like Kubernetes. Alternatively, you can experiment with MAX on a local endpoint using the [`max serve`](/max/cli/serve) command. The result is the same because the MAX container creates an isolated environment that also uses `max serve` to create an endpoint you can interact with using our OpenAI-compatible [REST API](/max/api/serve). :::note Linux only The MAX container is currently not compatible with macOS. ::: ## Get started First, make sure you're on a system with the following requirements: Then start an endpoint with the MAX container: 1.
Make sure you have [Docker installed](https://docs.docker.com/get-started/get-docker/). 2. Agree to the [Gemma 3 license on Hugging Face](https://huggingface.co/google/gemma-3-27b-it) and set the `HF_TOKEN` environment variable: ```bash export HF_TOKEN="hf_..." ``` 3. Start the container and an endpoint for Gemma 3: ```bash docker run --gpus=1 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --env "HF_TOKEN=${HF_TOKEN}" \ -p 8000:8000 \ modular/max-nvidia-full:latest \ --model-path google/gemma-3-27b-it ``` ```bash docker run \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --env "HF_TOKEN=${HF_TOKEN}" \ -p 8000:8000 \ --group-add keep-groups \ --device /dev/kfd \ --device /dev/dri \ modular/max-amd:latest \ --model-path google/gemma-3-27b-it ``` It can take a few minutes to pull the container and then download and compile the model. When the endpoint is ready, you'll see a message like this: ```output 🚀 Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` 4. Open a new terminal and send a request using the `openai` Python API or `curl`: 1. Create a new virtual environment: ```sh mkdir quickstart && cd quickstart ``` ```sh python3 -m venv .venv/quickstart \ && source .venv/quickstart/bin/activate ``` 2. Install the OpenAI Python API: ```bash pip install openai ``` 3. Create the following file to send an inference request: ```python title="generate-text.py" from openai import OpenAI client = OpenAI( base_url="http://0.0.0.0:8000/v1", api_key="EMPTY", ) completion = client.chat.completions.create( model="google/gemma-3-27b-it", messages=[ { "role": "user", "content": "Who won the world series in 2020?" }, ], ) print(completion.choices[0].message.content) ``` 4. Run it and you should see results like this: ```sh python generate-text.py ``` ```output The **Los Angeles Dodgers** won the World Series in 2020! They defeated the Tampa Bay Rays 4 games to 2.
It was their first World Series title since 1988. It was a unique World Series as it was played in a neutral site (Globe Life Field in Arlington, Texas) due to the COVID-19 pandemic. ``` Run this command: ```sh curl -N http://0.0.0.0:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-3-27b-it", "stream": true, "messages": [ {"role": "user", "content": "Who won the World Series in 2020?"} ] }' | grep -o '"content":"[^"]*"' | sed 's/"content":"//g' | sed 's/"//g' | tr -d '\n' | sed 's/\\n/\n/g' ``` You should see results like this: ```output The **Los Angeles Dodgers** won the World Series in 2020! They defeated the Tampa Bay Rays 4 games to 2. It was their first World Series title since 1988. It was a unique World Series as it was played in a neutral site (Globe Life Field in Arlington, Texas) due to the COVID-19 pandemic. ``` For details about the OpenAI-compatible endpoint, see [our Serve API docs](/max/api/serve). To run a different model, change the `--model-path` to something else from [our model repository](https://builds.modular.com/?category=models). For information about the available containers, see the [Modular Docker Hub repositories](https://hub.docker.com/r/modular). ## Container options The `docker run` command above includes the bare minimum commands and options, but there are other `docker` options you might consider, plus several options to control features of the endpoint. ### Docker options - `--gpus`: If your system includes a compatible NVIDIA GPU, you must add the [`--gpus` option](https://docs.docker.com/reference/cli/docker/container/run/#gpus) in order for the container to access it. It doesn't hurt to include this even if your system doesn't have a [GPU compatible with MAX](/max/packages#gpu-compatibility). - `--devices`: When deploying MAX on multiple GPUs, you must specify the ID of the GPUs to use. 
For example, to use four available GPUs, you should include the following: `--devices gpu:0,1,2,3`. When you don't specify a `--devices` option, MAX defaults to using the first available GPU it discovers (equivalent to `--devices gpu:0`). You can also optionally specify `--devices cpu`. - `-v`: We use the [`-v` option](https://docs.docker.com/reference/cli/docker/container/run/#volume) to save a cache of Hugging Face and MAX models to your local disk that we can reuse across containers. You can optionally export a `MODULAR_MAX_CACHE_DIR` environment variable to change the MAX cache directory location. - `-p`: We use the [`-p` option](https://docs.docker.com/reference/cli/docker/container/run/#publish) to specify the exposed port for the endpoint. You also might need some environment variables (set with `--env`): - `HF_TOKEN`: This is required to access gated models on Hugging Face (after your account is granted access). For example: ```sh docker run \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ # highlight-start --env "HF_TOKEN=${HF_TOKEN}" \ # highlight-end -p 8000:8000 \ modular/max-nvidia-full:latest \ --model-path google/gemma-3-27b-it ``` Learn more about [`HF_TOKEN`](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hftoken) and how to create [Hugging Face access tokens](https://huggingface.co/docs/hub/en/security-tokens). ### MAX options Following the container name in the `docker run` command, you must specify a model with `--model-path`, but there are other options you might need to configure the `max serve` behavior. To see all available options, see the [`max` CLI page](/max/cli/serve), because the MAX container is basically a wrapper around that tool. - `--model-path`: This is required to specify the model you want to deploy. 
To find other GenAI models that are compatible with MAX, check out our [list of models on MAX Builds](https://builds.modular.com/?category=models). - `--max-length`: Specifies the maximum length of the text sequence (including the input tokens). We mention this one here because it's often necessary to adjust the max length when you have trouble running a large model on a machine with limited memory. For the rest of the `max serve` options, see the [`max` CLI page](/max/cli/serve). ## Container contents There are multiple MAX container options, including: - [`max-full`](/max/container#full-container) - [`max-amd`](/max/container#amd-container) - [`max-amd-base`](/max/container#amd-container) - [`max-nvidia-full`](/max/container#nvidia-container) - [`max-nvidia-base`](/max/container#nvidia-container) ### Full container The full MAX container (`max-full`) is a hardware-agnostic container that's built to deploy the latest version of MAX on both AMD and NVIDIA GPUs. You can run the container on either NVIDIA or AMD as follows: ```bash docker run --gpus=1 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --env "HF_TOKEN=${HF_TOKEN}" \ -p 8000:8000 \ modular/max-full:latest \ --model-path google/gemma-3-27b-it ``` ```bash docker run \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --env "HF_TOKEN=${HF_TOKEN}" \ --group-add keep-groups \ --device /dev/kfd \ --device /dev/dri \ -p 8000:8000 \ modular/max-full:latest \ --model-path google/gemma-3-27b-it ``` The `max-full` container includes the following: - Ubuntu 22.04 - Python 3.12 - MAX 25.4 - PyTorch (GPU) 2.6.0 - ROCm - cuDNN - CUDA 12.8 - NumPy - Hugging Face Transformers For more information, see the [full MAX container on Docker Hub](https://hub.docker.com/r/modular/max-full). 
### AMD container The AMD MAX container (`max-amd`) is great if you want an AMD-specific deployment without NVIDIA or CUDA dependencies. The AMD MAX container is available in two flavors: - `max-amd` includes all ROCm and PyTorch GPU dependencies - `max-amd-base` includes minimal dependencies, ROCm, and the AMD Driver You can run the AMD container as follows: ```bash docker run \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --env "HF_TOKEN=${HF_TOKEN}" \ --group-add keep-groups \ --device /dev/kfd \ --device /dev/dri \ -p 8000:8000 \ modular/max-amd:latest \ --model-path google/gemma-3-27b-it ``` Or, to use the base container, replace `max-amd` with `max-amd-base`. For more information, see the [full AMD container](https://hub.docker.com/r/modular/max-amd) or [base AMD container](https://hub.docker.com/r/modular/max-amd-base) on Docker Hub. ### NVIDIA container The NVIDIA MAX container is available in two flavors: - `max-nvidia-full` includes all CUDA and PyTorch GPU dependencies - `max-nvidia-base` includes minimal dependencies, PyTorch CPU, and the NVIDIA Driver (excludes CUDA) You can run the NVIDIA container as follows: ```bash docker run --gpus=1 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --env "HF_TOKEN=${HF_TOKEN}" \ -p 8000:8000 \ modular/max-nvidia-full:latest \ --model-path google/gemma-3-27b-it ``` Or, to use the base container, replace `max-nvidia-full` with `max-nvidia-base`. For more information, see the [full NVIDIA container](https://hub.docker.com/r/modular/max-nvidia-full) or [base NVIDIA container](https://hub.docker.com/r/modular/max-nvidia-base) on Docker Hub. 
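The flavor-specific commands above differ only in the image name and the device-related flags. If you script your deployments, a small helper can assemble the right `docker run` invocation; this is purely an illustrative sketch (the `max_container_command` helper is hypothetical, not an official Modular tool):

```python
# Illustrative helper: build a `docker run` argv for a MAX container,
# mirroring the NVIDIA and AMD examples shown above.

def max_container_command(vendor, model_path, port=8000):
    """Return the docker run argv for the given GPU vendor ('nvidia' or 'amd')."""
    cmd = ["docker", "run"]
    if vendor == "nvidia":
        cmd += ["--gpus=1"]
        image = "modular/max-nvidia-full:latest"
    elif vendor == "amd":
        # AMD GPUs are exposed through the kernel driver device nodes.
        cmd += ["--group-add", "keep-groups",
                "--device", "/dev/kfd", "--device", "/dev/dri"]
        image = "modular/max-amd:latest"
    else:
        raise ValueError(f"unsupported vendor: {vendor}")
    cmd += [
        "-v", "~/.cache/huggingface:/root/.cache/huggingface",
        "-v", "~/.cache/max_cache:/opt/venv/share/max/.max_cache",
        "-p", f"{port}:{port}",
        image,
        "--model-path", model_path,
    ]
    return cmd

print(" ".join(max_container_command("nvidia", "google/gemma-3-27b-it")))
```

Note that everything after the image name is passed to the container's entrypoint (here, the `--model-path` option), while Docker's own options must come before it.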
## Recommended cloud instances For best performance and compatibility with the [available models on MAX Builds](https://builds.modular.com/?category=models), we recommend that you deploy the MAX container on a cloud instance with a GPU that meets the [MAX system requirements](/max/packages#system-requirements). The Modular Platform is hardware-agnostic and optimized for both the latest NVIDIA and AMD GPUs. To take full advantage of this flexibility, Modular partners with [compute providers](https://www.modular.com/customers) that prioritize diverse hardware optionality. For enterprise-grade hardware flexibility, see our available [editions](https://www.modular.com/pricing). If you're running on AWS, GCP, or Azure and want to test MAX with cloud GPUs, we recommend the following instances: AWS instances: - [P6](https://aws.amazon.com/ec2/instance-types/p6/) instance family (B200 GPU) - [P5](https://aws.amazon.com/ec2/instance-types/p5/) instance family (H100 GPU) - [P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance family (A100 GPU) GCP instances: - [A4](https://cloud.google.com/compute/docs/gpus#b200-gpus) machine series (B200 GPU) - [A3](https://cloud.google.com/compute/docs/gpus#a3-series) machine series (H100 GPU) - [A2](https://cloud.google.com/compute/docs/gpus#a100-gpus) machine series (A100 GPU) Azure instances: - [ND_GB200_v6-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-gb200-v6-series) virtual machine - [NCads_H100_v5-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nc-family#ncads_h100_v5-series) virtual machine - [NCCads_H100_v5-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nc-family#nccads_h100_v5-series) virtual machine - [ND_H100_v5-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-family#nd_h100_v5-series) virtual machine - 
[NC_A100_v4-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nc-family#nc_a100_v4-series) virtual machine - [NDm_A100_v4-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-family#ndm_a100_v4-series) virtual machine - [ND_A100_v4-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-family#nd_a100_v4-series) virtual machine - [ND_MI300X_v5-series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-family#nd_mi300x_v5-series) virtual machine (AMD GPU) ## Logs The MAX container writes logs to stdout in JSON format, which you can consume and view via your cloud provider's platform (for example, [with AWS CloudWatch](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html)). Console log level is `INFO` by default. You can modify the log level using the `MAX_SERVE_LOGS_CONSOLE_LEVEL` environment variable. It accepts the following log levels (in order of increasing verbosity): `CRITICAL`, `ERROR`, `WARNING`, `INFO`, `DEBUG`. For example: ```bash docker run \ --env MAX_SERVE_LOGS_CONSOLE_LEVEL=DEBUG \ ... \ modular/max-nvidia-full:latest ``` Logs default to structured JSON, but if you'd like a more readable format in your console, you can disable structured logs by adding the `MODULAR_STRUCTURED_LOGGING=0` environment variable. For example: ```bash docker run --gpus=1 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ -p 8000:8000 \ # highlight-start --env "MODULAR_STRUCTURED_LOGGING=0" \ # highlight-end modular/max-nvidia-full:latest \ --model-path google/gemma-3-27b-it ``` ## Metrics The MAX container exposes a `/metrics` endpoint that follows the [Prometheus](https://prometheus.io/docs/introduction/overview/) text format. You can scrape the metrics listed below using Prometheus or another collection service.
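If you'd rather not run a full Prometheus stack while experimenting, a short script can derive a rate directly from two scrapes of the endpoint. This sketch parses the Prometheus text format with hard-coded sample payloads (in practice you'd fetch `http://localhost:8000/metrics` at each interval) and uses the `maxserve_num_output_tokens_total` counter to compute output tokens per second:

```python
# Minimal sketch: derive output tokens/sec from two Prometheus text-format
# scrapes of the MAX /metrics endpoint. Sample payloads are hard-coded here.

def counter_value(scrape, name):
    """Extract a counter's value from a Prometheus text-format scrape."""
    for line in scrape.splitlines():
        # Skip '# HELP'/'# TYPE' comment lines; match the metric name.
        if line.startswith(name):
            return float(line.split()[-1])
    raise KeyError(name)

scrape_t0 = "maxserve_num_output_tokens_total 11010\n"
scrape_t1 = "maxserve_num_output_tokens_total 12510\n"
interval_s = 10.0  # seconds between the two scrapes

counter = "maxserve_num_output_tokens_total"
tokens_per_s = (counter_value(scrape_t1, counter)
                - counter_value(scrape_t0, counter)) / interval_s
print(f"output tokens/sec: {tokens_per_s:.1f}")  # output tokens/sec: 150.0
```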
These are raw metrics and it's up to you to compute the desired time series and aggregations. For example, we provide a count for output tokens (`maxserve_num_output_tokens_total`), which you can use to calculate the output tokens per second (OTP/s). Here are all the available metrics: - `maxserve_request_time_milliseconds`: Histogram of time spent handling each request (total inference time, or TIT), in milliseconds. - `maxserve_input_processing_time_milliseconds`: Histogram of input processing time (IPT), in milliseconds. - `maxserve_output_processing_time_milliseconds`: Histogram of output generation time (OGT), in milliseconds. - `maxserve_time_to_first_token_milliseconds`: Histogram of time to first token (TTFT), in milliseconds. - `maxserve_num_input_tokens_total`: Total number of input tokens processed so far. - `maxserve_num_output_tokens_total`: Total number of output tokens processed so far. - `maxserve_input_tokens_per_request`: Histogram of input tokens per request. - `maxserve_output_tokens_per_request`: Histogram of output tokens per request. - `maxserve_request_count_total`: Total requests since start. - `maxserve_num_requests_running`: Number of requests currently running. ### Telemetry In addition to sharing these metrics via the `/metrics` endpoint, the MAX container actively sends the metrics to Modular via push telemetry (using OpenTelemetry). :::note None of the telemetry includes personally identifiable information (PII). ::: This telemetry is anonymous and helps us quickly identify problems and build better products for you. Without it, we would rely solely on user-submitted bug reports, which would give us far less insight into real-world performance. However, if you don't want to share this data with Modular, you can disable telemetry in your container. To disable telemetry, set the `MAX_SERVE_DISABLE_TELEMETRY` environment variable when you start your MAX container.
For example: ```bash docker run \ --env MAX_SERVE_DISABLE_TELEMETRY=1 \ ... \ modular/max-nvidia-full:latest ``` #### Deployment and user ID Again, the telemetry is completely anonymous by default. But if you'd like to share some information to help our team assist you in understanding your deployment performance, you can add some identity information to the telemetry with these environment variables: - `MAX_SERVE_DEPLOYMENT_ID`: Your application name. - `MODULAR_USER_ID`: Your company name. For example: ```bash docker run \ --env MAX_SERVE_DEPLOYMENT_ID='Project name' \ --env MODULAR_USER_ID='Example Inc.' \ ... \ modular/max-nvidia-full:latest ``` ## License The NVIDIA MAX container is released under the [NVIDIA Deep Learning Container license](https://developer.download.nvidia.com/licenses/NVIDIA_Deep_Learning_Container_License.pdf). ## Next steps export const docs = [ '../deploy/local-to-cloud.mdx', '../model-formats.mdx', '../serve/index.mdx', ]; --- ## Benchmark MAX on NVIDIA or AMD GPUs import InstallModular from '@site/docs/_includes/install-modular.mdx'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; import Requirements from '@site/src/components/Requirements'; import { requirementsWithDockerAndGPU } from '@site/docs/max/requirements'; In this tutorial, you'll deploy a MAX inference endpoint using our [GPU-enabled containers](/max/container) and evaluate endpoint performance with the [`max benchmark`](/max/cli/benchmark) CLI command. You'll collect key metrics such as throughput, latency, and token-processing speed to benchmark under production-like workloads and establish a baseline before scaling or integrating MAX into your deployments. :::caution Datacenter GPU required For the best performance, use an **NVIDIA B200 / H200 / H100** or **AMD MI355X / MI325X / MI300X**.
MAX can serve models on a wide range of CPUs and GPUs, but the big LLMs most customers want require an amount of memory that's only available on the latest datacenter GPUs. Specifically, this tutorial uses the [Gemma 3 27B model](https://builds.modular.com/models/gemma-3-it/27B), which requires at least 60 GiB of memory. ::: Deploying AI inference workloads requires careful performance optimization, balancing factors such as accuracy, latency, and cost. This tutorial benchmarks a [Gemma 3](https://builds.modular.com/models/gemma-3-it/27B) endpoint using the `max benchmark` CLI command, which provides key metrics to evaluate the model server's performance, including: - Request throughput - Input and output token throughput - Time-to-first-token (TTFT) - Time per output token (TPOT) Our benchmarking script is adapted from vLLM with additional features, such as client-side GPU metric collection to ensure consistent and comprehensive performance measurement that's tailored to MAX. You can see the [benchmark script source here](https://github.com/modular/modular/tree/main/max/python/max/benchmark). System requirements: :::note This tutorial is intended for production environments using Docker, which can be difficult to set up with GPU access on some systems. If you have any trouble with Docker, you can instead run benchmarks on an endpoint created with the `max serve` command—for instructions, see the [quickstart guide](/max/get-started). ::: ## Get access to the model From here on, you should be running commands on the system with your GPU. If you haven't already, open a shell to that system now. You'll first need to authorize your Hugging Face account to access the Gemma model: 1. Obtain a [Hugging Face access token](https://huggingface.co/settings/tokens) and set it as an environment variable: ```bash export HF_TOKEN="hf_..." ``` 2. Agree to the [Gemma 3 license on Hugging Face](https://huggingface.co/google/gemma-3-27b-it). 
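Hugging Face user access tokens start with the `hf_` prefix, so before pulling anything you can sanity-check the environment variable with a quick script. This is just a convenience sketch (the `hf_token_looks_valid` helper is hypothetical, not part of the official workflow):

```python
import os

def hf_token_looks_valid(token):
    """Rough shape check: Hugging Face user access tokens start with 'hf_'."""
    return bool(token) and token.startswith("hf_")

if hf_token_looks_valid(os.environ.get("HF_TOKEN")):
    print("HF_TOKEN looks OK")
else:
    print('HF_TOKEN is missing or malformed; run: export HF_TOKEN="hf_..."')
```

This only catches typos and an unset variable; whether the token actually grants access to the gated Gemma repository is verified by Hugging Face when the container downloads the model.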
## Start the model endpoint We provide a pre-configured GPU-enabled Docker container that simplifies deploying an endpoint with MAX. For more information, see [MAX container](/max/container). Use this command to pull the MAX container and start the model endpoint: ```bash docker run --rm --gpus=all \ --ipc=host \ -p 8000:8000 \ --env "HF_TOKEN=${HF_TOKEN}" \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ modular/max-nvidia-full:latest \ --model-path google/gemma-3-27b-it ``` ```bash docker run \ --device /dev/kfd \ --device /dev/dri \ --group-add video \ --ipc=host \ -p 8000:8000 \ --env "HF_TOKEN=${HF_TOKEN}" \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ modular/max-amd:latest \ --model-path google/gemma-3-27b-it ``` If you want to try a different model, see our [model repository](https://builds.modular.com/?category=models). The server is running when you see the following terminal message (note that the container prints [JSON logs by default](/max/container#logs)): ```output 🚀 Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` ## Start benchmarking Open a second terminal and install the `modular` package to get the `max` CLI tool we'll use to perform benchmarking. ### Set up your environment ### Benchmark the model To benchmark MAX with the `sonnet` dataset, use this command: ```bash max benchmark \ --model google/gemma-3-27b-it \ --backend modular \ --endpoint /v1/chat/completions \ --dataset-name sonnet \ --num-prompts 500 \ --sonnet-input-len 550 \ --output-lengths 256 \ --sonnet-prefix-len 200 ``` :::note By default, this sends requests to `localhost:8000`, but you can override this with the `--host` and `--port` arguments. In order to download the dataset, you must have permission to write to your Hugging Face cache. You might need to change permissions with `chown`.
::: When you want to save your own benchmark configurations, you can define them in a YAML file and pass it to the `--config-file` option. For example, copy our [`gemma-3-27b-sonnet-decode-heavy-prefix200.yaml`](https://github.com/modular/modular/tree/main/max/python/max/benchmark/configs/gemma-3-27b-sonnet-decode-heavy-prefix200.yaml) file from GitHub, and you can benchmark the same model with this command: ```sh max benchmark --config-file gemma-3-27b-sonnet-decode-heavy-prefix200.yaml ``` For more information, including other datasets and configuration options, see the [`max benchmark` documentation](/max/cli/benchmark). ### Use your own dataset The command above uses the `sonnet` dataset from Hugging Face, but you can also provide a path to your own dataset. For example, you can download the [ShareGPT](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered) dataset with this command: ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` You can then use the local dataset with the `--dataset-path` argument: ```bash max benchmark \ ... 
--dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ ``` ## Interpret the results Of course, your results depend on your hardware, but the structure of the output should look like this: ```output ============ Serving Benchmark Result ============ Successful requests: 50 Failed requests: 0 Benchmark duration (s): 25.27 Total input tokens: 12415 Total generated tokens: 11010 Total nonempty serving response chunks: 11010 Input request rate (req/s): inf Request throughput (req/s): 1.97837 ------------Client Experience Metrics------------- Max Concurrency: 50 Mean input token throughput (tok/s): 282.37 Std input token throughput (tok/s): 304.38 Median input token throughput (tok/s): 140.81 P90 input token throughput (tok/s): 9.76 P95 input token throughput (tok/s): 7.44 P99 input token throughput (tok/s): 4.94 Mean output token throughput (tok/s): 27.31 Std output token throughput (tok/s): 8.08 Median output token throughput (tok/s): 30.64 P90 output token throughput (tok/s): 12.84 P95 output token throughput (tok/s): 9.11 P99 output token throughput (tok/s): 4.71 ---------------Time to First Token---------------- Mean TTFT (ms): 860.54 Std TTFT (ms): 228.57 Median TTFT (ms): 809.41 P90 TTFT (ms): 1214.68 P95 TTFT (ms): 1215.34 P99 TTFT (ms): 1215.82 -----Time per Output Token (excl. 
1st token)------ Mean TPOT (ms): 46.72 Std TPOT (ms): 39.77 Median TPOT (ms): 32.63 P90 TPOT (ms): 78.24 P95 TPOT (ms): 111.87 P99 TPOT (ms): 216.31 ---------------Inter-token Latency---------------- Mean ITL (ms): 31.16 Std ITL (ms): 91.79 Median ITL (ms): 1.04 P90 ITL (ms): 176.93 P95 ITL (ms): 272.52 P99 ITL (ms): 276.72 -------------Per-Request E2E Latency-------------- Mean Request Latency (ms): 7694.01 Std Request Latency (ms): 6284.40 Median Request Latency (ms): 5667.19 P90 Request Latency (ms): 16636.07 P95 Request Latency (ms): 21380.10 P99 Request Latency (ms): 25251.18 ``` For more information about each metric, see the [MAX benchmarking key metrics](https://github.com/modular/modular/tree/main/max/python/max/benchmark#key-metrics-explained). ### Measure latency with finite request rates Latency metrics like time-to-first-token (TTFT) and time per output token (TPOT) matter most when the server isn't overloaded. An overloaded server will queue requests, which results in a massive increase in latency that varies depending on the size of the benchmark more than the actual latency of the server. Benchmarks with a larger number of prompts result in a deeper queue. If you'd like to vary the size of the queue, you can adjust the request rate with the `--request-rate` flag. This creates a stochastic request load with an average rate of `N` requests per second. ### Comparing to alternatives You can run the benchmarking script using the Modular or vLLM backends to compare performance with alternative LLM serving frameworks. Before running the benchmark, make sure you set up and launch the corresponding inference engine so the script can send requests to it. 
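Whichever backend you benchmark, the P90/P95/P99 rows in the report summarize the tails of the per-request latency distributions. As an illustration of how such tail percentiles can be computed (a nearest-rank sketch with made-up TTFT samples, not the benchmark's actual implementation):

```python
import math

def percentile(samples, p):
    """Nearest-rank percentile: smallest sample covering p percent of the data."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(p * len(ordered) / 100))  # 1-based rank
    return ordered[rank - 1]

# Hypothetical per-request TTFT measurements in milliseconds.
ttft_ms = [809.4, 812.0, 790.5, 1214.7, 860.5, 820.3, 795.1, 830.9, 1100.2, 805.6]
for p in (50, 90, 99):
    print(f"P{p} TTFT (ms): {percentile(ttft_ms, p):.1f}")
```

Because the extreme percentiles are dominated by a handful of slow requests, they only stabilize with a large number of prompts, which is one reason the benchmark defaults to hundreds of requests.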
:::tip Optional cleanup When you're done benchmarking, you can clean up the Docker image with the following command: ```bash docker rmi $(docker images -q modular/max-nvidia-full:latest) ``` ```bash docker rmi $(docker images -q modular/max-amd:latest) ``` ::: ## Next steps Now that you have detailed benchmarking results for Gemma 3 on MAX using an NVIDIA or AMD GPU, you can explore more advanced scaling optimizations: export const docs = [ '../../../mammoth/index.mdx', '../../../mammoth/orchestrator.mdx', '../../../mammoth/disaggregated-inference.mdx' ] To read more about our performance methodology, check our blog post, [MAX GPU: State of the Art Throughput on a New GenAI platform](https://www.modular.com/blog/max-gpu-state-of-the-art-throughput-on-a-new-genai-platform). You can also share your experience on the [Modular Forum](https://forum.modular.com/) and in our [Discord Community](https://discord.gg/modular). --- ## Deploy MAX on GPU in the Cloud import SmallCards from '@site/src/components/SmallCards'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import Requirements from '@site/src/components/Requirements'; import { requirementsWithGPU } from '@site/docs/max/requirements'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; In this tutorial, you'll deploy a MAX inference endpoint with Llama 3 from local testing to production on AWS, GCP, or Azure. You'll learn to serve models with an OpenAI-compatible endpoint, automate deployment using Infrastructure-as-Code templates, and optimize performance with GPU resources—establishing the foundation for production-ready LLM deployments. MAX provides a streamlined way to deploy large language models (LLMs) with production-ready features like GPU acceleration, automatic scaling, and monitoring capabilities. Whether you're building a prototype or preparing for production deployment, this tutorial will help you set up a robust serving infrastructure for Llama 3. 
And although we're using Llama 3 in these instructions, you can swap it for one of the hundreds of other LLMs from Hugging Face by browsing [our model repository](https://builds.modular.com/?category=models&modality=Chat). The tutorial is organized into the following sections: - **[Local setup](#local-setup)**: Run Llama 3 locally to verify its basic functionality. - **[Cloud deployment](#cloud-deployment)**: Deploy Llama 3 to AWS, GCP, or Azure using IaC templates and CLI commands. System requirements: ## Local setup In this section, you will set up and run Llama 3 locally to understand its capabilities and validate functionality before moving to the cloud. This part doesn't require a GPU because MAX can also run Llama 3 on CPUs, but we recommend using a [compatible GPU](/max/faq/#gpu-requirements) for the best performance. ### 1. Set up your environment Create a Python project to install our APIs and CLI tools: ### 2. Serve Llama 3 locally Next, use the `max` CLI tool to start an endpoint with the Llama 3 model locally, and ensure that the model runs as expected before deploying it in the cloud. :::note If you want to try a different model, swap the `modularai/Llama-3.1-8B-Instruct-GGUF` name in all the commands for another Hugging Face model ID from [our model repository](https://builds.modular.com/?category=models&modality=Chat). Just be aware that some Hugging Face models require access approval and might have different memory requirements. ::: 1. Generate a response to a prompt with the following command: ```bash max generate --model modularai/Llama-3.1-8B-Instruct-GGUF \ --prompt "What is the meaning of life?" \ --max-length 250 ``` 2. Start the model server using `max serve`: ```bash max serve --model modularai/Llama-3.1-8B-Instruct-GGUF ``` This starts a local server with an OpenAI-compatible endpoint. Next, we'll send it an inference request.
:::note GPU-enabled Docker containers
We also provide a pre-configured GPU-enabled Docker container that simplifies deployment. We'll use the MAX container in the [cloud deployment](#cloud-deployment) steps.
:::

### 3. Test the local endpoint

The endpoint is ready when you see this message in the terminal:

```output
Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit)
```

Then, you can test its functionality by sending a `curl` request from a new terminal:

```bash
curl -N http://0.0.0.0:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "modularai/Llama-3.1-8B-Instruct-GGUF",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Who won the World Series in 2020?"}
    ]
  }' | jq -r '.choices[].message.content'
```

You should see output like this:

```output
The Los Angeles Dodgers won the 2020 World Series. They defeated the Tampa Bay Rays in the series 4 games to 2. This was the Dodgers' first World Series title since 1988.
```

To learn more about the supported REST body parameters, see our [API reference for chat completion](/max/api/serve#operation/createChatCompletion).

Now that the model works locally, we'll transition to cloud deployment.

## Cloud deployment paths {#cloud-deployment}

We will use Infrastructure-as-Code (IaC) to create, configure, and deploy Llama 3 in the cloud. The cloud deployment instructions are divided by provider: AWS, GCP, and Azure.

### Cloud deployment overview

For AWS, we will use CloudFormation; for GCP, Deployment Manager; and for Azure, Resource Manager. These IaC templates handle resource provisioning, networking, and security configuration. This approach simplifies deployments and ensures they are repeatable.

The key steps are:

- **Create and Deploy Stack/Resources**: Use IaC templates for each cloud provider to deploy Llama 3.
- **Test the Endpoint**: Retrieve the public IP address after deployment and send a request to test the Llama 3 endpoint in the cloud. Each cloud-specific tab provides complete commands for setup, configuration, deployment, and testing. To better understand the flow of the deployment, here is a high-level overview of the architecture:
Figure 1. Architecture diagram of the cloud stack for deploying MAX.
This architecture diagram illustrates the two-phase deployment setup for serving the Llama 3 model with MAX on cloud provider infrastructure. The deployment process is divided into two phases:

* **Phase 1: Cloud stack creation**: In this initial phase, the following infrastructure is provisioned and configured to prepare for serving requests:
  * **Public IP assignment**: The cloud provider assigns a public IP to the virtual machine (VM), allowing it to be accessed externally.
  * **Firewall/Security group configuration**: Security settings, such as firewall rules or security groups, are applied to allow traffic on port 80. This setup ensures that only HTTP requests can access the instance securely.
  * **GPU compute instance setup**: A GPU-enabled VM is created to handle model inference efficiently. This instance includes:
    * **GPU drivers/runtime installation**: Necessary GPU drivers and runtime libraries are installed to enable hardware acceleration for model processing.
    * **Docker container initialization**: A Docker container is launched on the VM, where it pulls the necessary images from the Docker container registry. This registry serves as a central repository for storing Docker images, making it easy to deploy and update the application. Inside the container, MAX is set up alongside the Llama 3 model. This setup prepares the environment for serving requests but does not yet expose the endpoint to users.

:::note GPU-enabled Docker containers
The pre-configured GPU-enabled Docker container includes all necessary dependencies and configurations for running Llama 3 with GPU acceleration. The provided IaC templates initialize the MAX container. If you don't use the provided templates for infrastructure setup, you can initialize the container image with the `docker run` command. For more information, see [MAX container](/max/container).
:::

* **Phase 2: Serving the user endpoint**: Once the cloud stack is configured and the VM is set up, the deployment enters the second phase, where it starts serving user requests:
  * **HTTP endpoint exposure**: With the VM and Docker container ready, the system opens an OpenAI-compatible HTTP endpoint on port 80, allowing users to interact with the deployed Llama 3 model.
  * **Request handling by MAX**: When a user sends an HTTP request to the public IP, MAX processes the incoming request within the Docker container and forwards it to the Llama 3 model for inference. The model generates a response, which is then returned to the user via the endpoint.

:::caution
For the sake of this tutorial, we expose the public IP address of the VM to the internet. This is not recommended for direct use in production environments as it may expose your model to security risks.
:::

### Prerequisites

Be sure that you have the following prerequisites, as well as appropriate access and permissions for the cloud provider of your choice.

- **GPU resources**: You'll need access to GPU resources in your cloud account with the following specifications:
  - **Minimum GPU memory**: 24GB
  - **Supported GPU types**: [See our compatible GPUs](/max/packages#gpu-compatibility)

  This tutorial has been tested on the following NVIDIA instances: `g5.4xlarge` (A10G) on AWS, `g2-standard-8` (L4) on GCP, and `Standard_NV36ads_A10_v5` (A10G) on Azure. It has also been tested on the AMD `Standard_ND96isr_MI300X_v5` (MI300X) Azure instance. You can alter the provided cloud config files to deploy MAX on any [compatible cloud instance or virtual machine](/max/container#recommended-cloud-instances).
- **A Hugging Face user access token**: A valid Hugging Face token is required to access the model. To create a Hugging Face user access token, see [Access Tokens](https://huggingface.co/settings/tokens). You must make your token available in your environment with the following command:

```bash
export HF_TOKEN="hf_..."
``` - **Docker installation**: Install the [Docker Engine and CLI](https://docs.docker.com/engine/install/). We use a pre-configured GPU-enabled Docker container from our public repository. For more information, check out all of our [available containers](/max/container#container-contents). - **Cloud CLI tools**: Before deploying, ensure that you have the respective cloud provider CLI tools installed. - [AWS CLI v2](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) installed and configured with appropriate credentials - [Google Cloud SDK](https://cloud.google.com/sdk/docs/install) installed and initialized - [Azure CLI](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli) installed, logged in, and configured Configure the AWS CLI: ```bash aws configure ``` Log in to your AWS account: ```bash aws sso login ``` Check the credentials via `cat ~/.aws/credentials` to make sure it is set up correctly. You can also include the credentials as environment variables: ```bash export AWS_ACCESS_KEY_ID="YOUR_ACCESS_KEY_ID" export AWS_SECRET_ACCESS_KEY="YOUR_SECRET_ACCESS_KEY" ``` Initialize the Google Cloud SDK: ```bash gcloud init ``` Log in to your Google Cloud account: ```bash gcloud auth login ``` Initialize the Azure CLI: ```bash az init ``` Log into your Azure account: ```bash az login ``` ### 1. Create stack/deployment In this section, we'll walk through creating a deployment stack on AWS, GCP, and Azure. Each cloud provider has its own configuration steps, detailed below, but we simplify the setup by using Infrastructure-as-Code (IaC) templates. Start by cloning the MAX repository and navigating to the `modular/examples/cloud-configs/` directory, where the necessary IaC templates and configuration files are organized for each cloud provider. 
```bash
git clone -b stable https://github.com/modular/modular && cd modular/examples/cloud-configs
```

This directory includes all files required to deploy MAX to AWS, GCP, or Azure:

:::note AMD GPU cloud deployment
Azure provides AMD GPU virtual machines. If you want to deploy MAX with AMD GPUs on Azure, you can use the `modular/examples/cloud-configs/azure/amd/max-amd-azure.json` config file. This file defines the appropriate image and settings for AMD-based inference workloads on an [ND MI300X v5 series](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/ndmi300xv5-series) VM.
:::

```bash
modular/examples/cloud-configs/
├── aws
│   ├── max-nvidia-aws.yaml
│   └── notify.sh
├── azure
│   ├── amd
│   │   ├── max-amd-azure.json
│   │   └── notify.sh
│   └── nvidia
│       ├── max-nvidia-azure.json
│       └── notify.sh
└── gcp
    ├── max-nvidia-gcp.jinja
    └── notify.sh
```

With these IaC templates ready, choose your preferred cloud provider and follow the step-by-step instructions specific to each platform.

:::note Preparing the deployment takes some time
Stack creation may take some time to complete, and completion times differ across cloud providers.
:::

First, navigate to the AWS directory:

```bash
cd aws
```

Set the region in your environment:

```bash
export REGION="REGION" # example: `us-east-1`
```

Then, create the stack. You can explore the `max-nvidia-aws.yaml` file for AWS CloudFormation configuration information.

:::note Stack naming
The stack name must be **unique**, so be sure to change the `--stack-name` if you create multiple stacks.
::: ```bash export STACK_NAME="max-serve-stack" aws cloudformation create-stack --stack-name ${STACK_NAME} \ --template-body file://max-nvidia-aws.yaml \ --parameters \ ParameterKey=InstanceType,ParameterValue=g5.4xlarge \ ParameterKey=HuggingFaceHubToken,ParameterValue=${HF_TOKEN} \ ParameterKey=HuggingFaceRepoId,ParameterValue=modularai/Llama-3.1-8B-Instruct-GGUF \ --capabilities CAPABILITY_IAM \ --region $REGION ``` :::note GCP access requirements You must have access to `deploymentmanager.googleapis.com`, `logging.googleapis.com`, `compute.googleapis.com` and be able to use `gcloud compute firewall-rules` to configure inbound traffic. ::: First, navigate to the GCP directory: ```bash cd gcp ``` Set the project ID: ```bash PROJECT_ID="YOUR PROJECT ID" export ZONE="ZONE" # example `us-east1-d` ``` Enable the required APIs: ```bash gcloud services enable deploymentmanager.googleapis.com --project=${PROJECT_ID} && \ gcloud services enable logging.googleapis.com --project=${PROJECT_ID} && \ gcloud services enable compute.googleapis.com --project=${PROJECT_ID} ``` Create the deployment with the following command. You can explore the `max-nvidia-gcp.jinja` file for more information on the Deployment Manager configuration. :::note Deployment naming The deployment name must be **unique** so please be sure to change the `DEPLOYMENT_NAME` if you create multiple deployments. 
::: ```bash export DEPLOYMENT_NAME="max-serve-deployment" export INSTANCE_NAME="max-serve-instance" gcloud deployment-manager deployments create ${DEPLOYMENT_NAME} \ --template max-nvidia-gcp.jinja \ --properties "\ instanceName:${INSTANCE_NAME},\ zone:${ZONE},\ machineType:g2-standard-8,\ acceleratorType:nvidia-l4,\ acceleratorCount:1,\ sourceImage:common-cu123-v20240922-ubuntu-2204-py310,\ huggingFaceHubToken:${HF_TOKEN},\ huggingFaceRepoId:modularai/Llama-3.1-8B-Instruct-GGUF" \ --project ${PROJECT_ID} ``` First, navigate to the Azure directory: ```bash cd azure/nvidia ``` Set the region: ```bash export REGION="REGION" # example `westus3` ``` Then, create the resource group: :::note Resource group and deployment naming If you receive an error about resource group location conflicts, it means the resource group already exists in a different location. You can either: - Use a new resource group name - Use the existing resource group's location Additionally, the deployment name must be **unique** so please be sure to change the `DEPLOYMENT_NAME` if you create multiple deployments. 
::: ```bash export RESOURCE_GROUP_NAME="maxServeResourceGroup" export DEPLOYMENT_NAME="maxServeDeployment" az group create --name ${RESOURCE_GROUP_NAME} --location ${REGION} ``` Check the status of the resource group: ```bash az group show -n ${RESOURCE_GROUP_NAME} --query properties.provisioningState -o tsv ``` Create and encode the startup script: ```bash STARTUP_SCRIPT='#!/bin/bash sudo usermod -aG docker $USER sudo systemctl restart docker sleep 10 HF_TOKEN=$1 HUGGING_FACE_REPO_ID=${2:-modularai/Llama-3.1-8B-Instruct-GGUF} sudo docker run -d \ --restart unless-stopped \ --env "HF_TOKEN=${HF_TOKEN}" \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ --gpus 1 \ -p 80:8000 \ --ipc=host \ modular/max-nvidia-full:latest \ --model-path ${HUGGING_FACE_REPO_ID}' export STARTUP_SCRIPT=$(echo "$STARTUP_SCRIPT" | base64) ``` Then, create the deployment: :::note NVIDIA license agreement You may be required to accept the Azure Marketplace image terms for the NVIDIA AI enterprise image: ```bash az vm image terms accept --urn nvidia:nvidia-ai-enterprise:nvaie_gpu_1_gen2:latest ``` ::: :::caution Set an admin password Replace `YOUR-SECURE-PASSWORD-123` with your own secure password to be able to `ssh` into the VM that we will use later. 
::: ```bash export VM_PASSWORD="YOUR-SECURE-PASSWORD-123" az deployment group create \ --name ${DEPLOYMENT_NAME} \ --resource-group ${RESOURCE_GROUP_NAME} \ --template-file max-nvidia-azure.json \ --parameters \ adminUsername="azureuser" \ adminPassword=${VM_PASSWORD} \ vmSize="Standard_NV36ads_A10_v5" \ osDiskSizeGB=128 \ vnetAddressPrefix="10.0.0.0/16" \ subnetAddressPrefix="10.0.0.0/24" \ startupScript="${STARTUP_SCRIPT}" \ location="${REGION}" ``` First, navigate to the Azure directory: ```bash cd azure/amd ``` Set the region: ```bash export REGION="REGION" # example `westus3` ``` Then, create the resource group: :::note Resource group and deployment naming If you receive an error about resource group location conflicts, it means the resource group already exists in a different location. You can either: - Use a new resource group name - Use the existing resource group's location Additionally, the deployment name must be **unique** so please be sure to change the `DEPLOYMENT_NAME` if you create multiple deployments. 
::: ```bash export RESOURCE_GROUP_NAME="maxServeResourceGroup" export DEPLOYMENT_NAME="maxServeDeployment" az group create --name ${RESOURCE_GROUP_NAME} --location ${REGION} ``` Check the status of the resource group: ```bash az group show -n ${RESOURCE_GROUP_NAME} --query properties.provisioningState -o tsv ``` Create and encode the startup script: ```bash STARTUP_SCRIPT='#!/bin/bash sudo usermod -aG docker $USER sudo systemctl restart docker sleep 10 HF_TOKEN=$1 HUGGING_FACE_REPO_ID=${2:-modularai/Llama-3.1-8B-Instruct-GGUF} sudo docker run -d \ --restart unless-stopped \ --env "HF_TOKEN=${HF_TOKEN}" \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ -p 80:8000 \ --ipc=host \ --device /dev/kfd \ --device /dev/dri \ modular/max-amd:latest \ --model-path ${HUGGING_FACE_REPO_ID}' export STARTUP_SCRIPT=$(echo "$STARTUP_SCRIPT" | base64) ``` Then, create the deployment: :::caution Set an admin password Replace `YOUR-SECURE-PASSWORD-123` with your own secure password to be able to `ssh` into the VM that we will use later. ::: ```bash export VM_PASSWORD="YOUR-SECURE-PASSWORD-123" az deployment group create \ --name ${DEPLOYMENT_NAME} \ --resource-group ${RESOURCE_GROUP_NAME} \ --template-file max-amd-azure.json \ --parameters \ adminUsername="azureuser" \ adminPassword=${VM_PASSWORD} \ vmSize="Standard_ND96isr_MI300X_v5" \ osDiskSizeGB=256 \ vnetAddressPrefix="10.0.0.0/16" \ subnetAddressPrefix="10.0.0.0/24" \ startupScript="${STARTUP_SCRIPT}" \ location="${REGION}" ``` ### 2. Wait for resources to be ready In this step, we'll wait for the resources to be ready. Stack and deployment creation may take some time to complete. 
```bash aws cloudformation wait stack-create-complete \ --stack-name ${STACK_NAME} \ --region ${REGION} ``` ```bash gcloud deployment-manager deployments describe ${DEPLOYMENT_NAME} \ --project=${PROJECT_ID} ``` Wait for the deployment to be completed and report its status: ```bash az deployment group wait \ --name ${DEPLOYMENT_NAME} \ --resource-group ${RESOURCE_GROUP_NAME} \ --created ``` ### 3. Retrieve instance information After the resources are deployed, you'll need to get the instance information, such as the public DNS or IP address that we will use to test the endpoint. ```bash INSTANCE_ID=$(aws cloudformation describe-stacks --stack-name ${STACK_NAME} \ --query "Stacks[0].Outputs[?OutputKey=='InstanceId'].OutputValue" \ --output text \ --region ${REGION}) PUBLIC_IP=$(aws ec2 describe-instances --instance-ids ${INSTANCE_ID} \ --query 'Reservations[0].Instances[0].PublicIpAddress' \ --output text \ --region ${REGION}) echo "Instance ID: ${INSTANCE_ID}" echo "Public IP: ${PUBLIC_IP}" aws ec2 wait instance-running --instance-ids ${INSTANCE_ID} --region ${REGION} ``` First, check if the firewall rule already exists: ```bash EXISTING_RULE=$(gcloud compute firewall-rules list \ --filter="name=allow-http" \ --format="value(name)" \ --project=${PROJECT_ID}) if [ -z "$EXISTING_RULE" ]; then echo "Creating firewall rule..." 
gcloud compute firewall-rules create allow-http \ --allow tcp:80 \ --source-ranges 0.0.0.0/0 \ --target-tags http-server \ --description "Allow HTTP traffic on port 80" \ --project=${PROJECT_ID} else echo "Firewall rule 'allow-http' already exists" fi ``` Check if the instance exists and tag it with `http-server`: ```bash INSTANCE_EXISTS=$(gcloud compute instances list \ --filter="name=${INSTANCE_NAME}" \ --format="value(name)" \ --project=${PROJECT_ID}) if [ -n "$INSTANCE_EXISTS" ]; then echo "Adding tags to instance ${INSTANCE_NAME}" gcloud compute instances add-tags "${INSTANCE_NAME}" \ --project=${PROJECT_ID} \ --zone "${ZONE}" \ --tags http-server else echo "Error: Instance ${INSTANCE_NAME} not found" exit 1 fi ``` Then, get the public IP: ```bash PUBLIC_IP=$(gcloud compute instances describe "${INSTANCE_NAME}" \ --zone "${ZONE}" \ --format="get(networkInterfaces[0].accessConfigs[0].natIP)" \ --project=${PROJECT_ID}) echo "Public IP: $PUBLIC_IP" ``` ```bash PUBLIC_IP=$(az network public-ip show \ --resource-group ${RESOURCE_GROUP_NAME} \ --name maxServePublicIP \ --query ipAddress -o tsv) echo "Public IP: ${PUBLIC_IP}" ``` ### 4. Test the endpoint 1. Wait until the server is ready to test the endpoint It will take some time for the stack or deployment to pull the MAX Docker image and set it up for serving. We need to wait for the Docker logs to appear and then make sure that the Docker container is running on port `8000`. The server is ready when you see the following log: ```output Server ready on http://0.0.0.0:8000 ``` We provide a simple script to monitor the startup progress and notify you when the server is ready. For AWS, you can see the logs in the AWS CloudWatch UI within the log group `/aws/ec2/${STACK_NAME}-logs` and log stream `instance-logs`. 
Alternatively, you can use the provided bash script to monitor the logs until the server is ready: ```bash bash notify.sh ${REGION} ${STACK_NAME} ${PUBLIC_IP} ``` For GCP, first make sure that the Docker container is running on port `8000`. You can view the logs in the Compute Engine VM instances UI. Within the UI, choose **Observability**, then choose **Logs**. Alternatively, you can use the provided bash script to monitor the logs until the server is ready: ```bash bash notify.sh ${PROJECT_ID} ${INSTANCE_NAME} ${ZONE} ${PUBLIC_IP} ``` For Azure, you can monitor the Docker container status (running on port `8000`) using one of the following methods: #### Option 1: Use the monitoring script 1. Install the required dependencies for the monitoring script: - Install [sshpass](https://www.cyberciti.biz/faq/noninteractive-shell-script-ssh-password-provider/) on your local machine to enable automated SSH password authentication 2. Set up and run the monitoring script: ```bash bash notify.sh ${RESOURCE_GROUP_NAME} ${VM_PASSWORD} ${PUBLIC_IP} ``` #### Option 2: Manual SSH access 1. Connect to the VM: ```bash ssh azureuser@$PUBLIC_IP ``` > **Note:** Use the password you set previously when creating the deployment. 2. View the startup logs: ```bash sudo cat /var/log/azure/custom-script/handler.log sudo cat /var/lib/waagent/custom-script/download/0/stdout sudo cat /var/lib/waagent/custom-script/download/0/stderr sudo docker logs $(docker ps -q -f ancestor=modular/max-nvidia-full:latest) ``` > **Note:** Use the container name `modular/max-amd:latest` if you deployed MAX on an AMD instance. Both methods will help you confirm that the server is running correctly. The logs will show the startup progress and any potential issues that need to be addressed. 2. 
When the server is ready, use the public IP address that we obtained from the previous step to test the endpoint with the following `curl` request:

:::tip
After the server starts, there may be a brief delay before the cloud provider exposes the public IP address. If you receive an error, please wait approximately one minute and try again.
:::

```bash
curl -N http://$PUBLIC_IP/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "modularai/Llama-3.1-8B-Instruct-GGUF",
    "stream": true,
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Who won the World Series in 2020?"}
    ]
  }' | grep -o '"content":"[^"]*"' | sed 's/"content":"//g' | sed 's/"//g' | tr -d '\n' | sed 's/\\n/\n/g'
```

:::note Benchmarking MAX
You can also use the public IP address of your deployed MAX endpoint to benchmark the performance of Llama 3.1. MAX includes a benchmarking script that allows you to evaluate throughput, latency, and GPU utilization metrics. For more detailed instructions on benchmarking, see the [`max benchmark` docs](/max/cli/benchmark).
:::

### 5. Delete the cloud resources

Cleaning up resources to avoid unwanted costs is critical. Use the following commands to delete resources for each platform. This section provides steps to safely terminate all resources used in the tutorial.
First, delete the stack: ```bash aws cloudformation delete-stack --stack-name ${STACK_NAME} ``` Wait for the stack to be deleted: ```bash aws cloudformation wait stack-delete-complete \ --stack-name ${STACK_NAME} \ --region ${REGION} ``` ```bash gcloud deployment-manager deployments delete ${DEPLOYMENT_NAME} \ --project=${PROJECT_ID} ``` ```bash az group delete --name ${RESOURCE_GROUP_NAME} ``` ### Cost estimate When deploying Llama 3 in a cloud environment, several cost factors come into play: **Primary cost components:** - **Compute Resources**: GPU instances (like AWS `g5.4xlarge`, GCP `g2-standard-8`, or Azure `Standard_NV36ads_A10_v5`) form the bulk of the costs - **Network Transfer**: Costs associated with data ingress/egress, which is critical for high-traffic applications - **Storage**: Expenses for boot volumes and any additional storage requirements - **Additional Services**: Costs for logging, monitoring, and other supporting cloud services For detailed cost estimates specific to your use case, we recommend using these official pricing calculators: - [AWS Pricing Calculator](https://calculator.aws) - [GCP Pricing Calculator](https://cloud.google.com/products/calculator) - [Azure Pricing Calculator](https://azure.microsoft.com/en-us/pricing/calculator/) :::tip Cloud cost optimization tips: - Consider using spot/preemptible instances for non-critical workloads - Implement auto-scaling to match resource allocation with demand - Monitor and optimize network usage patterns - Set up cost alerts and budgets to avoid unexpected charges Remember to factor in your expected usage patterns, regional pricing differences, and any applicable enterprise discounts when calculating total cost of ownership (TCO). ::: ## Next steps Congratulations on successfully running MAX Pipelines locally and deploying Llama 3 to the cloud! 🎉 To stay up to date with new releases, [join our community](https://www.modular.com/community). 
And if you're interested in becoming a design partner to get early access and give us feedback, please [contact us](https://www.modular.com/request-demo). --- ## Basic operations When you build a neural network model, you need to define what computations happen at each step: multiplying inputs by weights, applying activation functions, computing loss, and so on. Operations are the functions that perform these computations on tensors. MAX provides multiple ways to call operations on tensors: - **Python operators**: Use standard operators like `+`, `-`, `*`, `/`, `@`, and `**` for common arithmetic and linear algebra operations. - **Tensor methods**: Call operations directly on [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) objects, like `x.sum()`, `x.reshape([2, 3])`, or `x.transpose(0, 1)`. - **Functional API**: Call operations from [`max.functional`](/max/api/python/functional) that take your tensor as input, such as `relu(x)` or `concat([a, b])`. Use these for activation functions, multi-tensor operations, or explicit graph construction. ### When to use functional API While tensor methods are more idiomatic for core operations, you'll need the functional API for activation functions, multi-tensor operations, and explicit graph construction. The functional API provides operations as standalone functions imported from [`max.functional`](/max/api/python/functional). Use functional operations (`F.*`) for: - **Activation functions**: Operations like [`F.relu()`](/max/api/python/functional#max.functional.relu), [`F.sigmoid()`](/max/api/python/functional#max.functional.sigmoid), and [`F.tanh()`](/max/api/python/functional#max.functional.tanh) don't have tensor method equivalents. - **Multi-tensor operations**: Operations that require multiple tensor inputs, like [`F.concat()`](/max/api/python/functional#max.functional.concat). - **Explicit graph construction**: When building computation graphs explicitly, functional operations provide more direct control. 
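To ground the distinction, it helps to see what one of these functional operations actually computes. ReLU, for example, is an element-wise clamp at zero. The following sketch uses plain Python lists, not the MAX API, to show the arithmetic that an activation like `F.relu()` performs on a tensor's elements:

```python
# Element-wise ReLU semantics, sketched with plain Python lists.
# Illustrative only: MAX's F.relu operates on Tensor objects, not lists.
def relu(values: list[float]) -> list[float]:
    """relu(x) = max(0, x), applied to each element."""
    return [max(0.0, v) for v in values]

print(relu([1.0, -4.0, 9.0, -16.0]))  # negatives clamp to zero: [1.0, 0.0, 9.0, 0.0]
```

In MAX, the same operation is written as `F.relu(x)` on a `Tensor`, with the element-wise arithmetic handled by the graph.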
## Perform arithmetic operations You can use standard Python operators for basic arithmetic on tensors. The `+`, `-`, `*`, and `/` operators perform element-wise operations on tensors: {/* @sync: _examples/basic-ops/arithmetic.py */} ```python from max.tensor import Tensor a = Tensor.constant([1.0, 2.0, 3.0]) b = Tensor.constant([4.0, 5.0, 6.0]) # Element-wise operations addition = a + b subtraction = a - b multiplication = a * b division = a / b print(addition) print(multiplication) ``` The expected output is: ```output TensorType(dtype=float32, shape=[Dim(3)], device=cpu:0): [5.0, 7.0, 9.0] TensorType(dtype=float32, shape=[Dim(3)], device=cpu:0): [4.0, 10.0, 18.0] ``` For more complex mathematical operations, MAX provides several approaches. In this example, `abs()` finds the absolute value, and the `**` operator performs exponentiation, which work seamlessly with tensors. [`F.sqrt()`](/max/api/python/functional#max.functional.sqrt) uses the functional API since there's no built-in function or tensor method for square root: {/* @sync: _examples/basic-ops/tensor_math_ops.py */} ```python import max.functional as F from max.tensor import Tensor x = Tensor.constant([1.0, -4.0, 9.0, -16.0]) # Built-in functions using dunder methods absolute = abs(x) # Uses __abs__ power = x ** 2 # Uses __pow__ # Functional API for operations without built-ins square_root = F.sqrt(abs(x)) # F.sqrt requires non-negative values print(f"Absolute value: {absolute}") print(f"Power (x**2): {power}") print(f"Square root: {square_root}") ``` The expected output is: ```output Absolute value: TensorType(dtype=float32, shape=[Dim(4)], device=cpu:0): [1.0, 4.0, 9.0, 16.0] Power (x**2): TensorType(dtype=float32, shape=[Dim(4)], device=cpu:0): [1.0, 16.0, 81.0, 256.0] Square root: TensorType(dtype=float32, shape=[Dim(4)], device=cpu:0): [1.0, 2.0, 3.0, 4.0] ``` :::note Functional API equivalents Some mathematical operations require the functional API. 
[`F.exp(x)`](/max/api/python/functional#max.functional.exp) computes e^x where e is Euler's number (approximately 2.718), while [`F.log(x)`](/max/api/python/functional#max.functional.log) computes the natural logarithm. ::: ## Manipulate tensor shapes Shape operations reorganize tensor data without changing the underlying values. These operations are essential for preparing data for different layers in neural networks. ### Reshape tensors The [`reshape()`](/max/api/python/tensor#max.tensor.Tensor.reshape) method changes the shape of a tensor while preserving the total number of elements. The following example transforms a 12-element vector into different layouts—the total number of elements remains constant across all shapes: {/* @sync: _examples/basic-ops/tensor_reshape.py */} ```python from max.tensor import Tensor # Create a 1-D tensor with 12 elements x = Tensor.constant([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) print(f"Original shape: {x.shape}") # Reshape to 3x4 matrix matrix = x.reshape([3, 4]) print(f"Reshaped to 3x4: {matrix.shape}") print(matrix) # Reshape to 2x2x3 cube cube = x.reshape([2, 2, 3]) print(f"Reshaped to 2x2x3: {cube.shape}") ``` The expected output is: ```output Original shape: [Dim(12)] Reshaped to 3x4: [Dim(3), Dim(4)] TensorType(dtype=float32, shape=[Dim(3), Dim(4)], device=cpu:0): [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0] Reshaped to 2x2x3: [Dim(2), Dim(2), Dim(3)] ``` ### Transpose tensors The [`transpose()`](/max/api/python/tensor#max.tensor.Tensor.transpose) method swaps two dimensions of a tensor: {/* @sync: _examples/basic-ops/tensor_transpose.py */} ```python from max.tensor import Tensor # Create a 2x3 matrix x = Tensor.constant([[1, 2, 3], [4, 5, 6]]) print(f"Original shape: {x.shape}") print(x) # Transpose to 3x2 y = x.transpose(0, 1) print(f"Transposed shape: {y.shape}") print(y) ``` The expected output is: ```output Original shape: [Dim(2), Dim(3)] TensorType(dtype=float32, shape=[Dim(2), Dim(3)], 
device=cpu:0): [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] Transposed shape: [Dim(3), Dim(2)] TensorType(dtype=float32, shape=[Dim(3), Dim(2)], device=cpu:0): [1.0, 4.0, 2.0, 5.0, 3.0, 6.0] ``` The element at position `[i, j]` in the original tensor moves to position `[j, i]` in the transposed tensor. For the common case of transposing the last two dimensions, you can use the `.T` property: {/* @sync: _examples/basic-ops/tensor_transpose_t.py */} ```python from max.tensor import Tensor # Create a 2x3 matrix x = Tensor.constant([[1, 2, 3], [4, 5, 6]]) # Transpose last two dimensions using .T y = x.T print(f"Transposed shape: {y.shape}") print(y) ``` The expected output is: ```output Transposed shape: [Dim(3), Dim(2)] TensorType(dtype=float32, shape=[Dim(3), Dim(2)], device=cpu:0): [1.0, 4.0, 2.0, 5.0, 3.0, 6.0] ``` The `.T` property is equivalent to calling `transpose(-1, -2)` and works on tensors of any rank. When you need to rearrange dimensions in more complex ways, use [`permute()`](/max/api/python/tensor#max.tensor.Tensor.permute) to specify a new order for all dimensions. This is useful for converting between different layout conventions. 
In the following example, [`permute(0, 2, 1)`](/max/api/python/tensor#max.tensor.Tensor.permute) rearranges the dimensions so dimension 0 stays in place, dimension 2 moves to position 1, and dimension 1 moves to position 2: {/* @sync: _examples/basic-ops/tensor_permute.py */} ```python from max.tensor import Tensor # Create a 3D tensor (batch_size=2, channels=3, length=4) x = Tensor.constant([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]) print(f"Original shape: {x.shape}") # Rearrange to (batch, length, channels) y = x.permute(0, 2, 1) print(f"Permuted shape: {y.shape}") ``` The expected output is: ```output Original shape: [Dim(2), Dim(3), Dim(4)] Permuted shape: [Dim(2), Dim(4), Dim(3)] ``` ### Concatenate tensors The [`F.concat()`](/max/api/python/functional#max.functional.concat) function joins multiple tensors along a specified dimension. This operation requires the functional API since it operates on multiple tensors: {/* @sync: _examples/basic-ops/concat.py */} ```python import max.functional as F from max.tensor import Tensor a = Tensor.constant([[1, 2], [3, 4]]) b = Tensor.constant([[5, 6], [7, 8]]) # Concatenate along axis 0 (rows) vertical = F.concat([a, b], axis=0) print(f"Concatenated along axis 0: {vertical.shape}") print(vertical) # Concatenate along axis 1 (columns) horizontal = F.concat([a, b], axis=1) print(f"Concatenated along axis 1: {horizontal.shape}") print(horizontal) ``` The expected output is: ```output Concatenated along axis 0: [Dim(4), Dim(2)] TensorType(dtype=float32, shape=[Dim(4), Dim(2)], device=cpu:0): [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] Concatenated along axis 1: [Dim(2), Dim(4)] TensorType(dtype=float32, shape=[Dim(2), Dim(4)], device=cpu:0): [1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0] ``` Concatenating along axis 0 stacks tensors vertically, while concatenating along axis 1 joins them horizontally. 
Use [`F.concat()`](/max/api/python/functional#max.functional.concat) since there's no tensor method equivalent for multi-tensor operations. ## Apply reduction operations Reduction operations aggregate tensor values along one or more dimensions, producing smaller tensors or scalars. Use tensor methods for reductions: {/* @sync: _examples/basic-ops/tensor_reductions.py */} ```python import max.functional as F from max.tensor import Tensor x = Tensor.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # Reduce along different dimensions sum_all = x.sum() # Sum all elements sum_rows = x.sum(axis=0) # Sum each column sum_cols = x.sum(axis=1) # Sum each row print(f"Sum of all elements: {sum_all}") print(f"Sum of each column: {sum_rows}") print(f"Sum of each row: {sum_cols}") # Other reductions mean_val = x.mean() max_val = x.max() min_val = F.min(x) # min() requires functional API print(f"Mean: {mean_val}") print(f"Max: {max_val}") print(f"Min: {min_val}") ``` The expected output is: ```output Sum of all elements: TensorType(dtype=float32, shape=[], device=cpu:0): 21.0 Sum of each column: TensorType(dtype=float32, shape=[Dim(3)], device=cpu:0): [5.0, 7.0, 9.0] Sum of each row: TensorType(dtype=float32, shape=[Dim(2)], device=cpu:0): [6.0, 15.0] Mean: TensorType(dtype=float32, shape=[], device=cpu:0): 3.5 Max: TensorType(dtype=float32, shape=[], device=cpu:0): 6.0 Min: TensorType(dtype=float32, shape=[], device=cpu:0): 1.0 ``` When you specify an axis, the reduction operates along that dimension. Without an axis, the reduction operates on all elements, producing a scalar. 
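You can sanity-check the reduction results above with plain Python (illustration only, not the MAX API):

```python
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

sum_all = sum(v for row in x for v in row)   # 21.0
sum_rows = [sum(col) for col in zip(*x)]     # sum each column: [5.0, 7.0, 9.0]
sum_cols = [sum(row) for row in x]           # sum each row: [6.0, 15.0]
mean_val = sum_all / 6                       # 3.5
```

Reducing along axis 0 collapses the rows (one value per column), while axis 1 collapses the columns (one value per row).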
Common reduction operations include: - [`sum()`](/max/api/python/tensor#max.tensor.Tensor.sum): Sum of elements (tensor method) - [`mean()`](/max/api/python/tensor#max.tensor.Tensor.mean): Average of elements (tensor method) - [`max()`](/max/api/python/tensor#max.tensor.Tensor.max): Maximum value (tensor method) - [`F.min()`](/max/api/python/functional#max.functional.min): Minimum value (functional API only) :::note Keepdims behavior Unlike NumPy and PyTorch, MAX reduction operations keep dimensions by default. To remove the reduced dimension, specify `keepdims=False`. ::: :::tip Functional API equivalents All reduction operations are also available as functional API calls like [`F.sum(x)`](/max/api/python/functional#max.functional.sum). Use these when building explicit graphs or when you prefer function-style syntax. ::: ## Perform matrix operations Matrix operations are fundamental to neural networks. MAX provides efficient implementations for common matrix operations. Use the `@` operator for matrix multiplication: {/* @sync: _examples/basic-ops/tensor_matmul.py */} ```python from max.tensor import Tensor # Create two matrices x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]]) # 2x2 w = Tensor.constant([[5.0, 6.0], [7.0, 8.0]]) # 2x2 # Matrix multiply using @ operator result = x @ w print("Matrix multiplication result:") print(result) ``` The expected output is: ```output Matrix multiplication result: TensorType(dtype=float32, shape=[Dim(2), Dim(2)], device=cpu:0): [19.0, 22.0, 43.0, 50.0] ``` The `@` operator performs standard matrix multiplication (using the `__matmul__` dunder method). The result is computed as `result[i, j] = sum(x[i, k] * w[k, j])`. :::tip Functional API equivalent Matrix multiplication is also available as [`F.matmul(x, w)`](/max/api/python/functional#max.functional.matmul). Use this when building explicit graphs or when you prefer function-style syntax. 
::: ## Add activation functions Activation functions are only available through the functional API. [`F.relu()`](/max/api/python/functional#max.functional.relu) sets negative values to zero, [`F.sigmoid()`](/max/api/python/functional#max.functional.sigmoid) maps values to (0, 1), and [`F.tanh()`](/max/api/python/functional#max.functional.tanh) maps values to (-1, 1): {/* @sync: _examples/basic-ops/functional_activations.py */} ```python import max.functional as F from max.tensor import Tensor x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]) # Apply activation functions relu_output = F.relu(x) sigmoid_output = F.sigmoid(x) tanh_output = F.tanh(x) print(f"ReLU: {relu_output}") print(f"Sigmoid: {sigmoid_output}") print(f"Tanh: {tanh_output}") ``` The expected output is: ```output ReLU: TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [0.0, 0.0, 0.0, 1.0, 2.0, 3.0] Sigmoid: TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [0.119, 0.269, 0.5, 0.731, 0.881, 0.953] Tanh: TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [-0.964, -0.762, 0.0, 0.762, 0.964, 0.995] ``` ## Generate random tensors The [`max.random`](/max/api/python/random) module provides functions for creating tensors with random values. Random tensors are essential for weight initialization and data augmentation. 
### Create random values [`random.uniform()`](/max/api/python/random#max.random.uniform) generates values uniformly distributed between `low` and `high`, while [`random.normal()`](/max/api/python/random#max.random.normal) generates values from a Gaussian distribution with the specified mean and standard deviation: {/* @sync: _examples/basic-ops/random_basic.py */} ```python from max import random # Uniform distribution between 0 and 1 uniform_tensor = random.uniform([3, 3], low=0.0, high=1.0) print("Uniform distribution:") print(uniform_tensor) # Normal (Gaussian) distribution normal_tensor = random.normal([3, 3], mean=0.0, std=1.0) print("\nNormal distribution:") print(normal_tensor) ``` The expected output is (values will vary since they're random): ```output Uniform distribution: TensorType(dtype=float32, shape=[Dim(3), Dim(3)], device=cpu:0): [0.234, 0.789, 0.456, 0.123, 0.890, 0.567, 0.345, 0.678, 0.901] Normal distribution: TensorType(dtype=float32, shape=[Dim(3), Dim(3)], device=cpu:0): [-0.52, 1.18, -0.73, 0.31, -1.04, 0.65, 0.19, -0.38, 0.87] ``` ### Initialize weights The random module provides specialized weight initialization functions following common neural network initialization schemes. [`random.xavier_uniform()`](/max/api/python/random#max.random.xavier_uniform) and [`random.kaiming_uniform()`](/max/api/python/random#max.random.kaiming_uniform) generate weights with distributions designed to maintain stable gradients during training. 
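These schemes differ mainly in how they pick the sampling bound from the layer's fan-in and fan-out. Here is a rough pure-Python sketch of the math (the formulas below are the commonly used ones; MAX's exact implementation may differ):

```python
import math
import random

def xavier_uniform_bound(fan_in: int, fan_out: int) -> float:
    # Bound chosen to keep activation variance roughly constant layer to layer
    return math.sqrt(6.0 / (fan_in + fan_out))

def kaiming_uniform_bound(fan_in: int) -> float:
    # Larger bound compensates for ReLU zeroing roughly half its inputs
    return math.sqrt(6.0 / fan_in)

# Sample a 3x3 weight matrix the way Xavier uniform would
bound = xavier_uniform_bound(3, 3)  # sqrt(6/6) = 1.0
weights = [[random.uniform(-bound, bound) for _ in range(3)] for _ in range(3)]
```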
Xavier initialization works well with sigmoid and tanh activations, while Kaiming initialization is optimized for ReLU activations: {/* @sync: _examples/basic-ops/random_init.py */} ```python from max import random # Xavier/Glorot initialization (for sigmoid/tanh activations) xavier_weights = random.xavier_uniform([3, 3]) print("Xavier uniform initialization:") print(xavier_weights) # He/Kaiming initialization (for ReLU activations) he_weights = random.kaiming_uniform([3, 3]) print("\nKaiming uniform initialization:") print(he_weights) ``` The expected output is (values will vary): ```output Xavier uniform initialization: TensorType(dtype=float32, shape=[Dim(3), Dim(3)], device=cpu:0): [-0.432, 0.789, -0.156, 0.234, -0.678, 0.345, -0.901, 0.567, -0.123] Kaiming uniform initialization: TensorType(dtype=float32, shape=[Dim(3), Dim(3)], device=cpu:0): [-0.721, 0.543, -0.234, 0.876, -0.456, 0.198, -0.654, 0.321, -0.789] ``` ## Build layers You can combine operations to implement neural network layers from scratch. The following example shows a simple linear layer: the `linear_layer` function uses the `@` operator for matrix multiplication and the `+` operator for bias addition, while the activation step uses [`F.relu()`](/max/api/python/functional#max.functional.relu) from the functional API. Pre-built layers like [`nn.Linear`](/max/api/python/nn#max.nn.Linear) work this way internally. 
Understanding operations lets you build custom layers when you need behavior beyond what standard layers provide: {/* @sync: _examples/basic-ops/compose_linear_layer.py */} ```python import max.functional as F from max import random from max.dtype import DType from max.tensor import Tensor def linear_layer(x: Tensor, weights: Tensor, bias: Tensor) -> Tensor: """Apply a linear transformation: y = xW + b.""" # Matrix multiply input by weights output = x @ weights # Add bias term output = output + bias return output # Create input (batch_size=2, input_features=4) x = Tensor.constant([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]) # Initialize weights (input_features=4, output_features=3) weights = random.xavier_uniform([4, 3]) # Initialize bias (output_features=3) bias = Tensor.zeros([3], dtype=DType.float32) # Apply linear transformation output = linear_layer(x, weights, bias) print(f"Output shape: {output.shape}") print(output) # Add activation function (requires functional API) activated = F.relu(output) print(f"\nAfter ReLU: {activated}") ``` The expected output is (weight values will vary): ```output Output shape: [Dim(2), Dim(3)] TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [-0.234, 0.567, -0.123, -0.891, 1.234, -0.456] After ReLU: TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [0.0, 0.567, 0.0, 0.0, 1.234, 0.0] ``` ## Next steps Now that you understand basic operations, continue learning about building neural networks: - **[Building graphs](/max/graph/what-is-a-graph)**: Use operations in explicit graph construction for production deployment. - **[Neural network modules](/max/api/python/nn)**: Build models using the `max.nn` module with pre-built layers like `Linear`, `Conv2d`, and `ReLU`. - **[Custom operations](/max/develop/custom-ops)**: Implement your own operations in Mojo when the built-in operations don't meet your performance or functionality needs. 
---

## Build an MLP block as a module

import InstallModular from '@site/docs/_includes/install-modular.mdx';
import MDXListing from '@site/src/components/Listing/MDXListing';

Multilayer perceptrons (MLPs) are a fundamental component of many neural networks. An MLP consists of a sequence of linear (fully connected) layers interspersed with non-linear activation functions, forming the backbone of many deep learning architectures. While you can build MLPs by manually composing layers, wrapping an MLP block in a dedicated, reusable `Module` offers better modularity and code organization, especially in larger projects using the MAX framework.

In this tutorial, you'll create a flexible MLP block as a custom [`Module`](/max/api/python/nn/module#max.nn.module.Module). This block lets you specify input and output dimensions, the number and size of hidden layers, and the activation function. By the end, you'll have a reusable component you can easily integrate into various graphs written in MAX.

Throughout this tutorial, you'll learn to define a class structure that inherits from `Module`, implement initialization and computation methods, create a helpful string representation for debugging, and configure your module with various parameters. These skills form the foundation for building custom, reusable components in MAX's neural network framework.

:::note
Modules automatically create computational graphs when called within a [`Graph`](/max/api/python/graph/Graph) context. This provides a familiar PyTorch-like interface while leveraging MAX's optimized graph execution. If you're new to MAX graphs, start with our tutorial on [getting started with MAX graphs](/max/develop/get-started-with-max-graph-in-python).
:::

## Set up

Create a Python project to install our APIs and CLI tools:

When you install the `modular` package, you'll get access to the `max` Python APIs.
## Define the `MLPBlock` class structure

First, you'll define the basic structure of your custom module using the [`Module`](/max/api/python/nn/module#max.nn.module.Module) class. To create the `MLPBlock` module, you need to implement two methods:

1. `__init__(self, ...)`: The constructor, where you define the sub-layers (like `Linear`) and parameters the module will use.
2. `__call__(self, x)`: The method that defines the computation graph when data `x` (a [`TensorValue`](/max/api/python/graph/TensorValue)) is passed through the module.

Start by creating a file named `mlp.py` and import the necessary libraries. Then, define the `MLPBlock` class structure:

```python title="mlp.py"
from __future__ import annotations

from collections.abc import Callable
from typing import Any

from max.dtype import DType
from max.graph import DeviceRef, TensorValue, ops
from max.nn import legacy as nn


class MLPBlock(nn.Module):
    def __init__(
        self,
        # TODO: Add parameters
    ) -> None:
        super().__init__()
```

After importing the required components from [`dtype`](/max/api/python/dtype), [`graph`](/max/api/python/graph), and [`nn`](/max/api/python/nn) (neural network), you create your `MLPBlock` class that inherits from `Module`. For now, the constructor only calls `super().__init__()`; you'll add the configuration parameters in the next step.

## Implement the `__init__` method

Next, implement the constructor (`__init__`). This method takes the configuration parameters and creates the necessary layers and activation functions for your MLP. You'll use [`Sequential`](/max/api/python/nn/sequential) to create a sequential container for your layers to build the computation graph.
Modify the `__init__` method in your `MLPBlock` class as follows: ```python title="mlp.py" class MLPBlock(nn.Module): def __init__( self, in_features: int, out_features: int, hidden_features: list[int] | None = None, activation: Callable[[TensorValue], TensorValue] | None = None, ) -> None: super().__init__() # Use empty list if no hidden features provided hidden_features = hidden_features or [] # Default to ReLU activation if none provided activation = activation or ops.relu # Build the sequence of layers layers = [] current_dim = in_features # Add hidden layers with their activations for i, h_dim in enumerate(hidden_features): layers.append( nn.Linear( in_dim=current_dim, out_dim=h_dim, dtype=DType.float32, device=DeviceRef.CPU(), has_bias=True, name=f"hidden_{i}", ) ) layers.append(activation) current_dim = h_dim # Add the final output layer layers.append( nn.Linear( in_dim=current_dim, out_dim=out_features, dtype=DType.float32, device=DeviceRef.CPU(), has_bias=True, name="output", ) ) # Create Sequential module with the layers self.model = nn.Sequential(layers) ``` Here's how this implementation works: You initialize an empty list `layers` to build your network structure. The code iterates through `hidden_features`, appending a [`Linear`](/max/api/python/nn/Linear) layer instance and the provided `activation` function for each hidden dimension, updating `current_dim` as you go. After processing all hidden layers, you add the output `Linear` layer. Finally, you create a [`Sequential`](/max/api/python/nn/sequential) module with these layers, which handles the sequential application of operations. Note that we use default values for `dtype` and `device` to simplify the interface. In a production environment, you might want to expose these parameters to allow for different data types and devices. ## Implement the `__call__` method In MAX, the `__call__` method is a special Python method that gets invoked when you use a module instance as if it were a function. 
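This is ordinary Python behavior, independent of MAX. A minimal illustration (`Scale` is a hypothetical class, not part of the MAX API):

```python
class Scale:
    """A minimal callable class, mimicking how a module instance is invoked."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        # Runs whenever the instance is used like a function: scale(2.0)
        return self.factor * x

scale = Scale(3.0)
print(scale(2.0))  # 6.0
```

MAX modules work the same way: the instance stores configuration and sub-layers, and calling it runs the computation.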
For example, when you write `output = mlp_block(input_tensor)` in your code, Python automatically calls the `__call__` method. This is a key part of how MAX builds computation graphs—when you call a module with an input tensor, you're adding that module's operations to the computation graph that will eventually be compiled and executed. The `__call__` method defines how an input [`TensorValue`](/max/api/python/graph/TensorValue) flows through the module, building the computation graph. Since you're using `Sequential`, your implementation requires only two lines of code: ```python title="mlp.py" def __call__(self, x: TensorValue) -> TensorValue: return self.model(x) ``` This method takes an input `TensorValue` `x` and passes it through the `self.model` sequential container, which automatically applies each layer in sequence. ## Implement a custom string representation To better understand the structure of your MLP blocks when they're printed, you'll implement the `__repr__` method to display useful information about the number of layers in each block: ```python title="mlp.py" def __repr__(self) -> str: layers = list(self.model) linear_count = sum( 1 for layer in layers if layer.__class__.__name__ == "Linear" ) activation_count = len(layers) - linear_count return f"MLPBlock({linear_count} linear layers, {activation_count} activations)" ``` This method counts the linear layers and activation functions separately, making it clear how many of each type exist in your MLP block. ## Run the `MLPBlock` module Now, run the `MLPBlock` by creating instances with different configurations. Note that MAX uses a static graph representation that gets compiled before execution, which differs from frameworks like PyTorch where tensors flow dynamically through the network. The examples below show how to instantiate MLP blocks with various configurations. To actually execute these blocks with data, you would integrate them into a larger MAX graph execution context. 
Create a new file called `main.py` with the following code: ```python title="main.py" from max.graph import ops from mlp import MLPBlock if __name__ == "__main__": print("--- Simple MLP Block ---") # 1. Simple MLP (no hidden layers) simple_mlp = MLPBlock( in_features=10, out_features=20, hidden_features=[], activation=ops.relu, ) print(simple_mlp) print("-" * 30) # 2. MLP with one hidden layer print("--- MLP Block (1 Hidden Layer) ---") mlp_one_hidden = MLPBlock( in_features=10, out_features=5, hidden_features=[32], activation=ops.relu, ) print(mlp_one_hidden) print("-" * 30) # 3. Deeper MLP with multiple hidden layers and GELU print("--- Deeper MLP Block (3 Hidden Layers, GELU) ---") deep_mlp = MLPBlock( in_features=64, out_features=10, hidden_features=[128, 64, 32], activation=ops.gelu, ) print(deep_mlp) print("-" * 30) ``` Execute the `main.py` file. This will instantiate the `MLPBlock()` with various configurations and print their representations, showing the layers defined within. The following is the expected output: ```output --- Simple MLP Block --- MLPBlock(1 linear layers, 0 activations) ---------------------------------------- --- MLP Block (1 Hidden Layer) --- MLPBlock(2 linear layers, 1 activations) ---------------------------------------- --- Deeper MLP Block (3 Hidden Layers, GELU) --- MLPBlock(4 linear layers, 3 activations) ---------------------------------------- ``` - The simple MLP has 1 linear layer and 0 activations (since there are no hidden layers) - The MLP with one hidden layer has 2 linear layers (input→hidden, hidden→output) and 1 activation - The deeper MLP has 4 linear layers (input→hidden1, hidden1→hidden2, hidden2→hidden3, hidden3→output) and 3 activations ## Conclusion In this tutorial, you successfully created a reusable `MLPBlock` class in MAX and learned how to integrate it with computational graphs. You learned how to: 1. Define the class structure inheriting from `Module` 2. 
Implement the `__init__` method to dynamically create a sequence of linear layers (`nn.Linear`) and activation functions, then wrap them in `nn.Sequential` 3. Implement a simple `__call__` method that leverages the sequential container 4. Create a custom string representation for debugging 5. Instantiate and inspect the custom module with various configurations This `MLPBlock` provides a clean and modular way to incorporate standard MLP structures into your MAX projects. You can now easily modify it further, perhaps by adding layer normalization layers, or experiment with different activation functions from `max.graph.ops`. This pattern of creating reusable modules is fundamental to building complex and maintainable models in MAX. ## Next steps export const docs = [ '../../develop/get-started-with-max-graph-in-python.mdx', '../../develop/build-custom-ops.mdx', '../../graph/quantize.mdx', ]; --- ## Build custom ops for GPUs import Requirements from '@site/src/components/Requirements'; import { requirementsWithGPU } from '@site/docs/max/requirements'; import MDXListing from '@site/src/components/Listing/MDXListing'; Mojo is our not-so-secret weapon to achieve architecture-independent performance for all types of AI workloads. In this tutorial, you'll learn to write custom graph operations (ops) in Mojo that run on GPUs and CPUs, and then load them into a MAX graph written in Python. We'll start with a simple custom op that just adds `1` to every element in the graph's input tensor, using an API that abstracts-away the CPU and GPU device management. Then you'll learn to write specialized functions for CPUs and GPUs. (GPU functions that run in parallel are also known as _kernels_.) Before you begin, you should have a basic understanding of MAX graphs, which are computation graphs written in Python. These graphs are the foundation for high-performance AI models that run on MAX. 
To learn more, see our tutorial to [get started with MAX graphs in Python](/max/develop/get-started-with-max-graph-in-python). It also helps if you know the [Mojo language basics](/mojo/manual/basics).

## Requirements

Although these examples work on both CPUs and GPUs, in order for the GPU code paths to run, your system must meet the [GPU requirements](/max/faq/#gpu-requirements).

## Get the examples

This tutorial is a walkthrough of a couple of code examples from [our GitHub repo](https://github.com/modular/modular). Start by cloning and running one of the examples:

1. Clone the repo:

   ```sh
   git clone https://github.com/modular/modular.git
   ```

2. To ensure you have a compatible developer environment, we recommend using [`pixi`](https://pixi.sh/latest/) to create a virtual environment and manage the package dependencies. If you don't have it, install it:

   ```sh
   curl -fsSL https://pixi.sh/install.sh | sh
   ```

   Then restart your terminal for the changes to take effect.

3. Make sure everything works by running the first example:

   ```sh
   cd modular/examples/custom_ops
   pixi run python addition.py
   ```

   The exact output will vary based on random initialization of the input tensor, but the "Graph result" and "Expected result" should match:

   ```output
   Graph result: [[1.7736697 1.4688652 1.7971799 1.4553597 1.8967733 1.3691401 1.1297637 1.7047229 1.1314526 1.3924606]
   # ... shortened for brevity

   Expected result: [[1.7736697 1.4688652 1.7971799 1.4553597 1.8967733 1.3691401 1.1297637 1.7047229 1.1314526 1.3924606]
   # ... shortened for brevity
   ```

Now let's dive into the implementation details to understand how this custom operation works under the hood.

## Example 1: Learn the custom op basics

To learn how to create a custom op, let's look at a simple "hello world" example that adds `1` to each element of a tensor—doing so in parallel on a GPU, if available.
### Define the custom op in Mojo Take a look at the custom op defined in [`custom_ops/kernels/add_one.mojo`](https://github.com/modular/modular/tree/main/max/examples/custom_ops/kernels/add_one.mojo). You'll see a Mojo [struct](/mojo/manual/structs) called `AddOne` with an `execute()` function. Every custom op must be defined with this general format, as described below. Depending on the purpose of your custom op, the `execute()` function will accept zero or more inputs and produce one or more outputs, as specified by the function arguments. Let's inspect the struct and function signatures: ```mojo title="kernels/add_one.mojo" @compiler.register("add_one") struct AddOne: @staticmethod fn execute[ target: StaticString, ]( output: OutputTensor, x: InputTensor[dtype = output.dtype, rank = output.rank], ctx: DeviceContextPtr, ) raises: # See below for the rest ``` The struct must include the [`@compiler.register()`](/max/api/mojo-decorators/compiler-register) decorator, which registers the custom op with MAX. The `add_one` name we set here is the name we'll use to add the op to our Python graph in the next section. The rest of the `execute()` signature describes the custom op's graph node to the graph compiler: the parameters, inputs, and outputs: - There's one [compile-time parameter](/mojo/manual/parameters/#parameterized-functions), `target`, which tells the function what kind of hardware it's being compiled for (either `"cpu"` or `"gpu"`; we'll use this in the next example). - The runtime arguments include the op's inputs and outputs, which take the form of [`InputTensor`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/#inputtensor) and [`OutputTensor`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/#outputtensor), respectively. 
These are specialized versions of the [`ManagedTensorSlice`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/ManagedTensorSlice) type, which represents a tensor of a specific rank and datatype whose memory is managed outside of the operation.

:::note
The `execute()` function must include `target` as the first parameter and `output` as the first argument.
:::

Now let's look at the body of the `execute()` function, which is highlighted in the following code. The op's core computation that adds `1` to each element in the tensor happens in the `elementwise_add_one()` closure function.

```mojo title="kernels/add_one.mojo" {11-18}
@compiler.register("add_one")
struct AddOne:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ) raises:
        @parameter
        @always_inline
        fn elementwise_add_one[
            width: Int
        ](idx: IndexList[x.rank]) -> SIMD[x.dtype, width]:
            return x.load[width](idx) + 1

        foreach[elementwise_add_one, target=target](output, ctx)
```

We call `elementwise_add_one()` using [`foreach()`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/foreach/), which distributes an elementwise computation in parallel across all elements in the output tensor. At compile time, `foreach()` optimizes the computation for the hardware it's running on, distributing parallel workloads to make the most efficient use of computational resources. This means the same code runs with optimal performance on CPU or GPU with no changes required.

Also notice that we pass the `target` parameter to the `foreach()` function, which allows it to optimize for the hardware device. (We'll do this ourselves in the [next example](#example-2-write-device-specific-kernels).)
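To build intuition for what `foreach()` does conceptually, here's a toy pure-Python model that applies an elementwise function in fixed-width chunks (the real implementation is hardware-optimized Mojo; the names below are hypothetical, for illustration only):

```python
def foreach_sim(elementwise_fn, data, width=4):
    """Toy model: apply elementwise_fn to data in chunks of `width` elements."""
    out = []
    for i in range(0, len(data), width):
        # One "SIMD" chunk; the real foreach() dispatches these in parallel
        out.extend(elementwise_fn(data[i:i + width]))
    return out

def add_one(chunk):
    return [v + 1 for v in chunk]

print(foreach_sim(add_one, [0.5, 1.5, 2.5, 3.5, 4.5]))
# [1.5, 2.5, 3.5, 4.5, 5.5]
```

The `width` parameter loosely mirrors the SIMD width that `foreach()` selects for the target hardware at compile time.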
### Add the custom op to a Python graph Let's now look at the corresponding [`custom_ops/addition.py`](https://github.com/modular/modular/tree/main/max/examples/custom_ops/addition.py) file where we load the `add_one` custom op into a graph using [`ops.custom()`](/max/api/python/graph/ops#max.graph.ops.custom). Here's the code that specifies the path to the Mojo custom op and adds it to the graph: ```python title="addition.py" mojo_kernels = Path(__file__).parent / "kernels" rows = 5 columns = 10 dtype = DType.float32 device = CPU() if accelerator_count() == 0 else Accelerator() graph = Graph( "addition", forward=lambda x: ops.custom( name="add_one", device=DeviceRef.from_device(device), values=[x], out_types=[ TensorType( dtype=x.dtype, shape=x.tensor.shape, device=DeviceRef.from_device(device), ) ], )[0].tensor, input_types=[ TensorType( dtype, shape=[rows, columns], device=DeviceRef.from_device(device), ), ], custom_extensions=[mojo_kernels], ) ``` :::note Make sure the directory you pass to `custom_extensions` is a Mojo package containing an `__init__.mojo` file (which can be empty). ::: The [`Graph()`](/max/api/python/graph/Graph.md) takes an input tensor with five rows and ten columns, runs the custom `add_one` operation on it, and returns the result. That's basically it. We created a custom op and added it to a graph! Now we can run an inference, using [`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession). We start by loading the graph onto the selected `device` (see above; either a CPU or GPU "accelerator"): ```python title="addition.py" session = InferenceSession( devices=[device], ) model = session.load(graph) ``` Finally, we generate some random data and pass it as input: ```python title="addition.py" # Fill an input matrix with random values. x_values = np.random.uniform(size=(rows, columns)).astype(np.float32) # Create a buffer and move it to the device (CPU or GPU). 
x = Buffer.from_numpy(x_values).to(device) # Run inference with the input tensor. result = model.execute(x)[0] # Copy values back to the CPU to be read. assert isinstance(result, Buffer) result = result.to(CPU()) ``` Notice that the [`Buffer`](/max/api/python/driver#max.driver.Buffer) is initially resident on the host (CPU), so we move it to the accelerator to be ready for use with the graph on that device (if the device is the host CPU and not an accelerator, the `to()` function is a no-op). Likewise, after we get results, we move the result back to the CPU to read it. As shown above, you can run this example as follows: ```sh pixi run python addition.py ``` ## Example 2: Write device-specific kernels Now let's write some GPU code of our own. The `add_one` custom op above uses [`foreach()`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/foreach/) to run our custom op computation. This function is a device-independent abstraction to perform calculations on each element of a tensor, and it offers great out-of-the-box performance for a range of CPUs and GPUs. However, there might also be situations in which you want to write your own hardware-specific algorithms. So let's now look at the [`vector_addition.mojo`](https://github.com/modular/modular/tree/main/max/examples/custom_ops/kernels/vector_addition.mojo) custom op, which adds two vectors together in parallel on a GPU. This example uses a programming model that may be familiar if you've used general purpose GPU programming in CUDA or similar frameworks: We're going to write separate functions for GPUs and CPUs, and the GPU function (kernel) is written specifically for parallelism across GPU threads. 
The vector addition op in `kernels/vector_addition.mojo` looks like this: ```mojo title="kernels/vector_addition.mojo" @compiler.register("vector_addition") struct VectorAddition: @staticmethod fn execute[ target: StaticString, ]( output: OutputTensor[rank=1], lhs: InputTensor[dtype = output.dtype, rank = output.rank], rhs: InputTensor[dtype = output.dtype, rank = output.rank], ctx: DeviceContextPtr, ) raises: @parameter if target == "cpu": _vector_addition_cpu(output, lhs, rhs, ctx) elif target == "gpu": _vector_addition_gpu(output, lhs, rhs, ctx) else: raise Error("No known target:", target) ``` Using the [parametric `if`](/mojo/manual/decorators/parameter), Mojo checks at compile time if the target device has a GPU or not. If it does, it uses the `_vector_addition_gpu()` function; otherwise, it uses `_vector_addition_cpu()`. Compile-time specialization like this is a unique and powerful feature of Mojo that makes it easy to optimize code for specific hardware. The `_vector_addition_gpu()` function looks like this: ```mojo title="kernels/vector_addition.mojo" fn _vector_addition_gpu( output: ManagedTensorSlice[mut=True], lhs: ManagedTensorSlice[dtype = output.dtype, rank = output.rank], rhs: ManagedTensorSlice[dtype = output.dtype, rank = output.rank], ctx: DeviceContextPtr, ) raises: alias BLOCK_SIZE = 16 var gpu_ctx = ctx.get_device_context() var vector_length = output.dim_size(0) @parameter fn vector_addition_gpu_kernel(length: Int): var tid = block_dim.x * block_idx.x + thread_idx.x if tid < UInt(length): var idx = IndexList[output.rank](Int(tid)) var result = lhs.load[1](idx) + rhs.load[1](idx) output.store[1](idx, result) var num_blocks = ceildiv(vector_length, BLOCK_SIZE) gpu_ctx.enqueue_function_experimental[vector_addition_gpu_kernel]( vector_length, grid_dim=num_blocks, block_dim=BLOCK_SIZE ) ``` The `vector_addition_gpu_kernel()` closure function runs once per thread on the GPU, adding an element from the `lhs` vector to the matching element in the 
`rhs` vector and then saving the result at the correct position in the `output` vector. The kernel is then launched across a grid of `num_blocks` thread blocks, each containing `BLOCK_SIZE` threads. The block size is arbitrary here, and is not tuned for the specific GPU hardware this will be run on. The previously-used `foreach()` abstraction will do hardware-specific tuning for this style of dispatch, and is what we recommend for simple elementwise calculations like this. However, this example shows how you might mentally map CUDA C-style code to thread-level GPU operations in MAX.

You can execute this one in the same manner as the previous example, by running the Python-based graph that uses the custom op:

```sh
pixi run python vector_addition.py
```

## Conclusion

Mojo is an incredible language for programming accelerators: Python-like high-level syntax, systems language performance, and unique language features designed for modern heterogeneous computation. In the examples above, we've introduced the basics of how to write custom ops for MAX graphs, place them in a one-operation graph in Python, and run them on an available CPU or GPU. We showed how to use high-level abstractions like [`foreach()`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/foreach/) to get great performance on GPUs and CPUs, and also how to program GPU-specific functions to control the parallelism yourself.

## Next steps

- [Get started with MAX Graph in Python](/max/develop/get-started-with-max-graph-in-python)
- [Optimize custom ops for GPUs with Mojo](/max/develop/custom-ops-matmul)

---

## Write hardware-agnostic custom ops for PyTorch

When working with PyTorch models, you might encounter performance bottlenecks in specific operations that could benefit from custom optimization.
You might also want to experiment with novel GPU algorithms or implement cutting-edge research ideas that aren't yet available in standard frameworks. Rather than rewriting your entire model or switching frameworks, you can write high-performance kernels in Mojo and integrate them into your existing PyTorch workflows, enabling both optimization and experimentation in your familiar development environment.

This tutorial demonstrates how to enhance a PyTorch model by implementing a custom grayscale image conversion operation in Mojo: you'll convert an RGB image into a grayscale image by integrating a custom op written in Mojo and running it in PyTorch, keeping your familiar PyTorch development experience while unlocking the performance benefits that Mojo provides for compute-intensive operations.

## Set up

Let's start by creating a Python project and installing the necessary tools: the `modular` package along with PyTorch (`torch`), Pillow (`pillow`) for image processing, and NumPy for array operations. When you install the `modular` package, you'll get access to the `max` Python APIs and the Mojo compiler—everything needed to build high-performance custom operations.

## Build the PyTorch interface

Now let's create the PyTorch side of our integration. We'll build a simple function that uses our custom operation, but first we need to establish the interface. Create a new file called `grayscale.py` and add the following code:

```python title="grayscale.py"
from pathlib import Path

import torch


@torch.compile
def grayscale(pic: torch.Tensor):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    # We'll call our custom operation here
    return output
```

The `grayscale` function transforms a color image (with red, green, and blue channels) into a single-channel grayscale image.
While PyTorch has built-in operations for this, implementing it as a Mojo custom op demonstrates the integration pattern you can apply to any performance-critical computation. Now we need to bridge to our Mojo implementation. ## Integrate the custom operation The [`max.torch`](/max/api/python/torch) module provides [`CustomOpLibrary`](/max/api/python/torch/#max.torch.CustomOpLibrary), which allows you to load and use compiled Mojo operations directly in PyTorch. Update the `grayscale.py` file to include the following code: ```python title="grayscale.py" from pathlib import Path import torch from max.torch import CustomOpLibrary # Load the compiled Mojo package containing our custom operations mojo_kernels = Path(__file__).parent / "operations" ops = CustomOpLibrary(mojo_kernels) @torch.compile def grayscale(pic: torch.Tensor): output = pic.new_empty(pic.shape[:-1]) # Remove color channel dimension ops.grayscale(output, pic) # Call our Mojo custom op return output ``` The `CustomOpLibrary` loads operations from a compiled Mojo package file (`.mojopkg`). Once loaded, you can call these operations just like any other PyTorch function. The operations automatically handle data movement between PyTorch and MAX, and integrate with PyTorch when needed. The compilation of your Mojo code into a `.mojopkg` file is handled when you run your Python script. You don't need to manually invoke the Mojo compiler or manage build steps. ## Implement the Mojo kernel Now for the core implementation—our high-performance Mojo kernel. 
We'll create a Mojo package that defines our grayscale conversion operation. Create a new file called `grayscale.mojo` inside the `operations` folder and add the following code:

```mojo title="grayscale.mojo"
from compiler import register
from max.tensor import InputTensor, OutputTensor, foreach
from runtime.asyncrt import DeviceContextPtr
from utils.index import IndexList


@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        img_out: OutputTensor[dtype = DType.uint8, rank=2],
        img_in: InputTensor[dtype = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        @parameter
        @always_inline
        fn color_to_grayscale[
            simd_width: Int
        ](idx: IndexList[img_out.rank]) -> SIMD[DType.uint8, simd_width]:
            @parameter
            fn load(idx: IndexList[img_in.rank]) -> SIMD[DType.float32, simd_width]:
                return img_in.load[simd_width](idx).cast[DType.float32]()

            var row = idx[0]
            var col = idx[1]

            # Load RGB values
            var r = load(IndexList[3](row, col, 0))
            var g = load(IndexList[3](row, col, 1))
            var b = load(IndexList[3](row, col, 2))

            # Apply standard grayscale conversion formula
            var gray = 0.21 * r + 0.71 * g + 0.07 * b
            return min(gray, 255).cast[DType.uint8]()

        foreach[color_to_grayscale, target=target, simd_width=1](img_out, ctx)
```

First, the `@register("grayscale")` decorator makes this operation available to PyTorch under the name `grayscale`. This is the `ops.grayscale(output, pic)` function called in the PyTorch model. Then, the `color_to_grayscale` function takes an index list and returns a SIMD vector containing the grayscale value, while the [`foreach`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/foreach/) primitive automatically parallelizes the operation across available compute units. Finally, the `target` parameter allows the same code to run on both CPU and GPU.
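As a quick sanity check on what the kernel computes, the same weighted conversion can be expressed in a few lines of NumPy (a reference implementation for testing, not part of the tutorial's files; the 0.21/0.71/0.07 weights match the Mojo code above):

```python
import numpy as np

def grayscale_reference(img: np.ndarray) -> np.ndarray:
    """Reference grayscale conversion matching the Mojo kernel's weights."""
    rgb = img.astype(np.float32)
    gray = 0.21 * rgb[..., 0] + 0.71 * rgb[..., 1] + 0.07 * rgb[..., 2]
    return np.minimum(gray, 255).astype(np.uint8)

# One white pixel and one black pixel: 255 * (0.21 + 0.71 + 0.07) truncates to 252.
pixels = np.array([[[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)
assert grayscale_reference(pixels).tolist() == [[252, 0]]
```

Comparing this function's output against the custom op's output on the same array is an easy way to validate the kernel.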
This approach allows Mojo to handle the low-level optimization details: memory layout, vectorization, and parallelization, while you focus on the algorithm.

## Run the example

Now that we have a PyTorch model and a Mojo kernel, we can test it with a real image. Create a new file called `main.py` in the root of your project and add the following code:

```python title="main.py"
import numpy as np
from PIL import Image
import torch

from grayscale import grayscale

# Load a test image
image = Image.open("test_image.jpg")

# Convert to PyTorch tensor and move to GPU for processing
image_tensor = torch.from_dlpack(np.array(image)).cuda()

# Apply our custom grayscale operation
gray_image = grayscale(image_tensor)

# Convert back to PIL Image for display or further processing
result = Image.fromarray(gray_image.cpu().numpy())
```

You can run the example with the following command:

```bash
python main.py
```

The following image shows the result of the grayscale operation.
This example demonstrates the complete pipeline: loading an image, converting it to a PyTorch tensor, processing it with our Mojo custom op, and converting back to a standard Python image format. The operation automatically leverages GPU acceleration when available, providing significant performance improvements over CPU-only implementations.

## Next steps

Now that you've written a Mojo kernel for your PyTorch model, check out these other tutorials:

- [Build custom ops for GPUs](/max/develop/build-custom-ops)
- [Get started with MAX Graph in Python](/max/develop/get-started-with-max-graph-in-python)

---

## Optimize custom ops for GPUs with Mojo

Building high-performance AI workloads for GPUs can be a daunting task, but Modular simplifies the experience with our custom op API that allows you to build graph operators for both GPUs and CPUs. In this tutorial, we'll teach you some strategies you can use to improve the performance of your GPU custom ops written in Mojo.

For demonstration purposes, we'll show you how to incrementally improve the performance of a custom matrix multiplication (matmul) op. We're not teaching you how to build a matmul op, because MAX already contains leading-edge implementations of matmul that you can use with the [MAX graph API](/max/api/python/graph/Graph). Rather, we're using a basic matmul operation (AKA "GPU kernel") to teach you GPU programming strategies that might help with your other GPU code written in Mojo.

:::note
This page also isn't meant to teach you Mojo. For that, see the [Mojo get started tutorial](/mojo/manual/get-started).
:::

As you progress through this tutorial, you'll learn the following:

- How to define a custom matrix multiplication operation for a MAX graph.
- How to use Mojo's high-performance GPU programming abstractions to progressively optimize a matrix multiplication.
- How to access GPU hardware features, like Tensor Cores, from MAX.

Let's get started.

## Requirements

To use a GPU, your system must meet the [GPU requirements](/max/faq/#gpu-requirements).

## Run and compare the results

To get a sense of how each implementation of the custom op (kernel) performs, download the code and run the benchmark script:

1. Get the example code from GitHub:

   ```bash
   git clone https://github.com/modular/modular.git
   ```

2. If you don't have it, install [`pixi`](https://pixi.sh/latest/):

   ```sh
   curl -fsSL https://pixi.sh/install.sh | sh
   ```

   Then restart your terminal for the changes to take effect.

3. Run all the matmul examples:

   ```bash
   cd modular/examples/custom_ops
   pixi run python matrix_multiplication.py
   ```

   As long as you have a compatible GPU, this will compile, run, and print the results of each matmul implementation that we'll discuss below.

4. Run the `matmul` benchmarks to see the impact of each optimization:

   ```bash
   pixi run mojo benchmarks.mojo --matmul
   ```

   This runs each implementation again, but also benchmarks each one and prints a comparison table.
For example (this is running on a `g5-2xlarge` instance; your results will vary): ```output --------------------------------------------------------------------------------------------------------- | name | met (ms) | iters | GFLOPS/s | GElems/s | --------------------------------------------------------------------------------------------------------- | cpu/naive | 1647.331583 | 2 | 1.3183084343256959 | 0.0006415126201098278 | | gpu/naive | 2.842315817535545 | 422 | 764.0569378680037 | 0.37180386270949084 | | gpu/coalescing | 1.0952283930000002 | 1000 | 1982.8659792606384 | 0.9648982867448362 | | gpu/tiled | 0.981560302 | 1000 | 2212.48874427279 | 1.076636858526905 | | gpu/tiled_register | 0.39235773577501637 | 3058 | 5534.977195518529 | 2.693419559863031 | | gpu/block_tiled | 0.38266733939393943 | 3135 | 5675.141033565811 | 2.7616258070879858 | | gpu/block_tiled_vectorized | 0.3684924709677419 | 3255 | 5893.447739370804 | 2.8678577807157195 | | gpu/tensor_core | 0.18174263374734928 | 6602 | 11949.266252072646 | 5.814728103198369 | --------------------------------------------------------------------------------------------------------- ``` ## Introduction AI models in [MAX](/max/intro) are built as computational graphs using the [MAX graph API](/max/develop/get-started-with-max-graph-in-python). MAX contains within it a powerful graph compiler that can take these graphs and optimize them for best performance on a wide range of hardware. Each node in a MAX graph is defined by an operation that performs a calculation on zero or more inputs and produces one or more outputs. These inputs and outputs tend to be in the form of tensors, and the operations are usually data-parallel calculations that are accelerated on CPUs or GPUs. In MAX, these operations are written using [Mojo](/mojo/manual/), a Python-family language built for high-performance computation. 
Matrix multiplications are key components in modern AI models, accounting for a sizable fraction of the GPU workload when running these models. Optimizations applied to matrix multiplication calculations can have a significant impact on the throughput of models on GPUs. To review, a matrix multiplication involves multiplying two matrices, A and B, to produce a new matrix C.
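Concretely, the definition can be sketched as a triple loop in plain Python (a reference implementation for illustration, not the MAX code):

```python
def matmul_reference(A, B):
    """C = A @ B for an MxK matrix A and a KxN matrix B, by definition."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):  # K multiply-adds per output value
                C[i][j] += A[i][k] * B[k][j]
    return C

assert matmul_reference([[1.0, 2.0], [3.0, 4.0]],
                        [[5.0, 6.0], [7.0, 8.0]]) == [[19.0, 22.0], [43.0, 50.0]]
```

The inner `k` loop is where all the memory traffic comes from: every output value walks a full row of A and a full column of B.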
Each value in the output matrix is the dot product of a row from A and a column from B. In a worst case scenario, when multiplying an MxK matrix by a KxN matrix, calculating one output value requires loading `2 * K` values and performing `K` floating-point multiplications. ### Structure of the custom operation The matrix multiplication algorithms demonstrated here are encapsulated within a custom MAX graph operation. AI models in MAX are built from a graph of operations like this, and in this case we're demonstrating one of these operations running in isolation. The `matrix_multiplication.py` file exercises seven different matrix multiplication algorithms using a single-operation graph and shows that the results of multiplying two random matrices are the same for each. These results can be seen by running at the command line using ```sh pixi run python matrix_multiplication.py ``` The single-operation graph is constructed using the following function: ```python def matrix_multiplication( a: NDArray[np.float32], b: NDArray[np.float32], algorithm: str, session: InferenceSession, device: Device, ) -> Buffer: dtype = DType.float32 a_tensor = Buffer.from_numpy(a).to(device) b_tensor = Buffer.from_numpy(b).to(device) mojo_kernels = Path(__file__).parent / "kernels" with Graph( "matrix_multiplication_graph", input_types=[ TensorType( dtype, shape=a_tensor.shape, device=DeviceRef.from_device(device), ), TensorType( dtype, shape=b_tensor.shape, device=DeviceRef.from_device(device), ), ], custom_extensions=[mojo_kernels], ) as graph: a_value, b_value = graph.inputs output = ops.custom( name="matrix_multiplication", device=DeviceRef.from_device(device), values=[a_value, b_value], out_types=[ TensorType( dtype=a_value.tensor.dtype, shape=[a_value.tensor.shape[0], b_value.tensor.shape[1]], device=DeviceRef.from_device(device), ) ], parameters={"algorithm": algorithm}, )[0].tensor graph.output(output) print("Compiling...") model = session.load(graph) print("Executing...") result 
= model.execute(a_tensor, b_tensor)[0] return result.to(CPU()) ``` A single `matrix_multiplication` operation is used, and the algorithm variant is specified by the `algorithm` compile-time parameter. The custom operation itself is defined in Mojo within the `operations/matrix_multiplication.mojo` file. The `MatrixMultiplication` struct hosts all of the setup code for taking in the matrix tensors, branching execution based on whether the operation is running on CPU or GPU, and then selecting and running a specific algorithm. Mojo supports compile-time specialization of code based on parameters like target hardware, and that is also extended here to user-supplied algorithm choice. Compiling only the code paths used for a particular piece of hardware avoids run-time branching and allows full utilization of an accelerator or CPU. Each algorithm is contained within its own function in `operations/matrix_multiplication.mojo`. Next, we'll discuss how each works. ### Matrix multiplication algorithms The algorithms demonstrated in this example follow steps 1-6 in the progression detailed by Simon Boehm in [his excellent blog post](https://siboehm.com/articles/22/CUDA-MMM) on writing performant matrix multiplications. Each algorithm is represented by a shortened parameter name from the following list: 1. **naive**: Naive matrix multiplication with no optimizations. 2. **coalescing**: Applying memory coalescing. 3. **tiled**: Reworking to use shared memory tiling. 4. **tiled_register**: Using shared memory tiling and register tiling. 5. **block_tiled**: Introducing block tiling. 6. **block_tiled_vectorized**: Block tiling with vectorized memory access. 7. **tensor_core**: Using Tensor Cores for matrix multiplication. The last algorithm is not from Simon's original list, but shows how to access Tensor Core hardware using MAX in Mojo. Each algorithm is meant to show a progressive improvement in performance of matrix multiplication on GPUs. 
To start with, the impact of each optimization can be seen by running a set of benchmarks against the algorithms in sequence:

```sh
pixi run mojo benchmarks.mojo --matmul
```

The results on an NVIDIA A100 GPU for 32-bit floats and input matrices sized to 4096x4096 look like the following at the time this is written:

| Algorithm              | GFLOPS/s |
|------------------------|---------:|
| naive                  |      292 |
| coalescing             |     2936 |
| tiled                  |     3943 |
| tiled_register         |     7078 |
| block_tiled            |    10661 |
| block_tiled_vectorized |    10663 |

The specific numbers may vary for your GPU, but the general progression should be the same.

### Layouts and LayoutTensor

The `matrix_multiplication` custom operation uses [layouts](/mojo/std/layout/layout/) and [`LayoutTensor`](/mojo/std/layout/layout_tensor/LayoutTensor) to represent the input and output matrices, so it's helpful to understand a little bit about these types before getting started.

A _layout_ represents a mapping from a set of logical coordinates to a single, one-dimensional coordinate—such as an array index or memory offset. For example, a layout could represent a 2x6, row-major matrix:

```mojo
my_layout = Layout.row_major(2, 6)
print_layout(my_layout)
```

```plaintext
      0    1    2    3    4    5
   +----+----+----+----+----+----+
 0 |  0 |  1 |  2 |  3 |  4 |  5 |
   +----+----+----+----+----+----+
 1 |  6 |  7 |  8 |  9 | 10 | 11 |
   +----+----+----+----+----+----+
```

A `LayoutTensor` consists of a layout and a pointer to memory. For example, if you create a `LayoutTensor` using the layout shown above, the value at (1, 1) is stored at memory offset 7. A layout tensor can point to an existing buffer, or you can allocate memory to store the tensor data.

One `LayoutTensor` method you'll see a lot in the following sections is `tile()`, which returns a new `LayoutTensor` that is a subset of the original, but points to the same underlying data.
For example, you can extract a 2x2 tile of the above tensor:

```mojo
tile = my_tensor.tile[2, 2](0, 1)
```

The layout of the extracted tile looks like this:

```plaintext
      0    1
   +----+----+
 0 |  2 |  3 |
   +----+----+
 1 |  8 |  9 |
   +----+----+
```

This just scratches the surface of layouts and `LayoutTensor`, which provide powerful tools for manipulating data and writing parallel algorithms. For more information, see the [Introduction to layouts](/mojo/manual/layout/layouts) and [Using `LayoutTensor`](/mojo/manual/layout/tensors) sections of the Mojo Manual.

## Kernel 1: Naive matrix multiplication with no optimizations

The very first algorithm to start with is a "naive" matrix multiplication, one that expresses the problem but makes no attempt at optimizing for how GPUs actually work. In Mojo, a basic matrix multiplication looks like the following:

```mojo
fn naive_matrix_multiplication[
    dtype: DType,
    a_layout: Layout,
    b_layout: Layout,
    c_layout: Layout,
    BM: Int,
    BN: Int,
](
    a: LayoutTensor[dtype, a_layout, MutAnyOrigin],
    b: LayoutTensor[dtype, b_layout, MutAnyOrigin],
    c: LayoutTensor[dtype, c_layout, MutAnyOrigin],
):
    var M = a.dim[0]()
    var N = b.dim[1]()
    var K = b.dim[0]()

    var row = block_dim.x * block_idx.x + thread_idx.x
    var col = block_dim.y * block_idx.y + thread_idx.y

    var dst_reg: c.element_type = 0

    if row < UInt(M) and col < UInt(N):
        for k_index in range(K):
            dst_reg = dst_reg + a[row, k_index] * b[k_index, col]
        c[row, col] = dst_reg
```

However, if you glance up at the benchmark table in the previous section, you'll see that the naive matrix multiplication is roughly only 2.7% as fast as the sixth algorithm on our list. There's clearly a lot of upside if this core algorithm can be improved.
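To see exactly which output element each GPU thread owns, here's a plain-Python mirror of the naive kernel's index arithmetic (an illustration with hypothetical 2x2 blocks to keep it small; the bounds guard handles matrices that don't divide evenly into blocks):

```python
from math import ceil

def naive_matmul_sim(A, B, block=2):
    """One simulated thread per (row, col) output element, as in the naive kernel."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for block_x in range(ceil(M / block)):
        for block_y in range(ceil(N / block)):
            for tx in range(block):
                for ty in range(block):
                    row = block * block_x + tx  # block_dim.x * block_idx.x + thread_idx.x
                    col = block * block_y + ty
                    if row < M and col < N:  # mask threads past the matrix edges
                        acc = 0.0
                        for k in range(K):
                            acc += A[row][k] * B[k][col]
                        C[row][col] = acc
    return C

# A 3x3 input with 2x2 blocks exercises the bounds guard.
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
I = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
assert naive_matmul_sim(A, I) == A  # multiplying by the identity returns A
```

Every simulated thread loads a full row of A and column of B from "global memory" with no reuse, which is exactly the inefficiency the following kernels attack.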
## Kernel 2: Applying memory coalescing

As one quick optimization that has an outsized impact, global memory accesses can be coalesced by swapping the thread indices for columns and rows:

```mojo
var row = block_dim.y * block_idx.y + thread_idx.y
var col = block_dim.x * block_idx.x + thread_idx.x
```

With this change, adjacent threads access values in the same row of the input matrices, which are contiguous in memory. This leads to an almost tenfold jump in benchmarks on an A100.

## Kernel 3: Reworking to use shared memory tiling

Shared memory on the GPU is far faster to access than global memory, so the next step is to rework the matrix multiplication to tile the computation and load values into shared memory. The signature for this kernel is the same as the previous two, with the addition of a new `BK` parameter, representing the tile size along the K axis. The input matrices A and B are loaded into shared memory in tiles of size BM x BK and BK x BN, respectively. Within the tile, values are accessed from shared memory, significantly reducing the memory access latency between arithmetic operations. Since each value in shared memory is used by BK threads (32 in this case), this greatly reduces the number of reads from global memory.
This version corresponds to "Kernel 3: Shared Memory Cache-Blocking" in Simon's blog post. Each thread is still computing a single output value, but it calculates a partial result for each tile worth of input data, and accumulates the partial results to calculate the final value.

We'll walk through the interesting part of the kernel. The `LayoutTensor` `tile()` method provides a view to one tile of a tensor, without copying any data. It serves multiple purposes here. First, for the destination tile:

```mojo
var col = thread_idx.x % UInt(BN)
var row = thread_idx.x // UInt(BN)

var dst = c.tile[BM, BN](Int(block_idx.y), Int(block_idx.x))
```

The `dst` value is the tile of the output tensor that the current block is responsible for generating. `dst` is a 32x32 tile of the output tensor, but instead of the 32x32 thread blocks used for previous kernels, this kernel is invoked with a one-dimensional thread block of 32*32 threads. The threads are mapped onto the output tile in row-major order.
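That thread-to-tile mapping is just integer division and remainder, which you can verify in a few lines of Python (using a hypothetical 4-wide tile instead of 32 to keep the output short):

```python
BN = 4  # hypothetical tile width; the kernel uses BN = 32

# Thread tid covers tile coordinate (tid // BN, tid % BN), row-major.
mapping = [(tid // BN, tid % BN) for tid in range(2 * BN)]

# Threads 0-3 fill row 0 left to right, then threads 4-7 fill row 1.
assert mapping == [(0, 0), (0, 1), (0, 2), (0, 3),
                   (1, 0), (1, 1), (1, 2), (1, 3)]
```

Consecutive `thread_idx.x` values land on consecutive columns of the same row, which is what lets each warp's loads coalesce.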
As in the previous example, accessing memory in this order is more efficient for the GPU, since it can coalesce adjacent memory accesses for adjacent threads in the same warp.

Next, the kernel allocates layout tensors in shared memory to hold cached tiles of the input tensors.

```mojo
var a_smem = LayoutTensor[
    dtype,
    Layout.row_major(BM, BK),
    MutAnyOrigin,
    address_space = AddressSpace.SHARED,
].stack_allocation()
var b_smem = LayoutTensor[
    dtype,
    Layout.row_major(BK, BN),
    MutAnyOrigin,
    address_space = AddressSpace.SHARED,
].stack_allocation()

var dst_reg: c.element_type = 0
```

The kernel then iterates across the input matrices, loading tiles into shared memory. The `copy_dram_to_sram_async()` function deserves special note. This takes the place of the CUDA pattern of instructing each thread which value or values to copy to shared memory. The `thread_layout` parameter associates individual threads with values, and the function ensures efficient memory copies.

```mojo
for block in range(b.dim[0]() // BK):
    comptime load_a_layout = Layout.row_major(NUM_THREADS // BK, BK)
    comptime load_b_layout = Layout.row_major(BK, NUM_THREADS // BK)

    var a_tile = a.tile[BM, BK](Int(block_idx.y), block)
    var b_tile = b.tile[BK, BN](block, Int(block_idx.x))

    copy_dram_to_sram_async[thread_layout=load_a_layout](a_smem, a_tile)
    copy_dram_to_sram_async[thread_layout=load_b_layout](b_smem, b_tile)

    async_copy_wait_all()
    barrier()

    @parameter
    for k in range(BK):
        dst_reg += a_smem[row, k] * b_smem[k, col]

    barrier()
```

The `async_copy_wait_all()` and `barrier()` calls ensure that all threads have completed their memory copies before proceeding to the next step, the inner loop, which accumulates the partial results for the current output tile. The `@parameter for` construct tells the compiler to unroll this loop at compile time, which is possible because `BK` is a compile-time parameter.

Finally, after all of the tiles have been processed, the results are written to the output tensor.
```mojo
dst[row, col] += dst_reg
```

All together, this kernel improves overall performance by ~30% over the previous optimization.

:::note
While faster on an A100, this kernel may not show gains over the previous one on all GPUs.
:::

## Kernel 4: Using shared memory tiling and register tiling

Expanding upon the advantages of using shared memory tiling, the partial results can be accumulated in tiled registers and then the final results transferred from there to global memory. In this version, each thread is responsible for calculating multiple values of C, further reducing the memory bandwidth required for each calculation. Specifically, each thread calculates a column of 8 results.
This kernel adds another new parameter to the signature, `TM`, which specifies the size of the register tile. You'll notice the code looks very similar to the previous kernel, except that the results are stored to a register tile. Also, note that the inner loop is structured so that the thread uses a single value from B for all of the calculations for a single register tile. Here's the interesting part of this kernel:

```mojo
var col = thread_idx.x % UInt(BN)
var row = thread_idx.x // UInt(BN)

var dst = c.tile[BM, BN](Int(block_idx.y), Int(block_idx.x)).tile[TM, 1](
    Int(row), Int(col)
)

var a_smem = tb[dtype]().row_major[BM, BK]().shared().alloc()
var b_smem = tb[dtype]().row_major[BK, BN]().shared().alloc()

var dst_reg = tb[dtype]().layout[TM]().local().alloc()
dst_reg.copy_from(dst)

for block in range(b.dim[0]() // BK):
    comptime load_a_layout = Layout.row_major(NUM_THREADS // BK, BK)
    comptime load_b_layout = Layout.row_major(BK, NUM_THREADS // BK)

    var a_tile = a.tile[BM, BK](Int(block_idx.y), block)
    var b_tile = b.tile[BK, BN](block, Int(block_idx.x))

    copy_dram_to_sram_async[thread_layout=load_a_layout](a_smem, a_tile)
    copy_dram_to_sram_async[thread_layout=load_b_layout](b_smem, b_tile)

    async_copy_wait_all()
    barrier()

    @parameter
    for k in range(BK):
        var a_tile = a_smem.tile[TM, 1](Int(row), k)
        var b_tile = b_smem.tile[1, BN](k, 0)
        var b_val = b_tile[0, col]

        @parameter
        for t in range(TM):
            dst_reg[t] += a_tile[t, 0] * b_val

    barrier()

dst.copy_from(dst_reg)
```

This gives a nearly 80% improvement over the previous tiling implementation.

## Kernel 5: Introducing block tiling

We can further increase the arithmetic intensity of the calculation using a 2-D block tiling strategy. In this kernel, each thread is responsible for calculating the output values for a `TM`x`TN` tile of the output tensor (where `TM` and `TN` are kernel parameters; in this case we're using 8x8 tiles).
In addition to caching a block's worth of the A & B matrices in shared memory, each thread copies 8-unit vectors of the input matrices into local storage to further reduce memory access latency. Then it uses the [`outer_product_acc()`](/mojo/std/layout/math/outer_product_acc/) function to calculate and accumulate the outer products of the two vectors. ```mojo var partition_col = Int(thread_idx.x % UInt(BN // TN)) var partition_row = Int(thread_idx.x // UInt(BN // TN)) var dst = c.tile[BM, BN](Int(block_idx.y), Int(block_idx.x)).tile[TM, TN]( partition_row, partition_col ) var a_smem = tb[dtype]().row_major[BM, BK]().shared().alloc() var b_smem = tb[dtype]().row_major[BK, BN]().shared().alloc() var dst_reg = tb[dtype]().row_major[TM, TN]().local().alloc() dst_reg.copy_from(dst) var a_reg = tb[dtype]().layout[TM]().local().alloc() var b_reg = tb[dtype]().layout[TN]().local().alloc() var ntiles = b.dim[0]() // BK for block in range(ntiles): comptime load_a_layout = Layout.row_major(NUM_THREADS // BK, BK) comptime load_b_layout = Layout.row_major(BK, NUM_THREADS // BK) var a_tile = a.tile[BM, BK](Int(block_idx.y), block) var b_tile = b.tile[BK, BN](block, Int(block_idx.x)) copy_dram_to_sram_async[thread_layout=load_a_layout](a_smem, a_tile) copy_dram_to_sram_async[thread_layout=load_b_layout](b_smem, b_tile) async_copy_wait_all() barrier() @parameter for k in range(BK): var a_tile = a_smem.tile[TM, 1](partition_row, k) var b_tile = b_smem.tile[1, TN](k, partition_col) a_reg.copy_from(a_tile) b_reg.copy_from(b_tile) outer_product_acc(dst_reg, a_reg, b_reg) barrier() dst.copy_from(dst_reg) ``` In the above benchmarks, this provides an additional 50% boost over the previous algorithm. ## Kernel 6: Block tiling with vectorized memory access As a final optimization to our block-tiling kernel, memory accesses can be vectorized to improve memory access bandwidth. 
The only new thing in this kernel is the use of the [`LayoutTensor.vectorize()`](/mojo/std/layout/layout_tensor/LayoutTensor#vectorize) method to produce vectorized views of the tensors, allowing multiple values to be copied as a single SIMD vector. ```mojo from sys.info import simd_width_of comptime simd_width = simd_width_of[dtype]() var partition_col = Int(thread_idx.x % UInt(BN // TN)) var partition_row = Int(thread_idx.x // UInt(BN // TN)) var dst = c.tile[BM, BN](Int(block_idx.y), Int(block_idx.x)).tile[TM, TN]( partition_row, partition_col ) var dst_vec = dst.vectorize[1, simd_width]() var a_smem = tb[dtype]().col_major[BM, BK]().shared().alloc() var b_smem = tb[dtype]().row_major[BK, BN]().shared().alloc() var dst_reg = tb[dtype]().row_major[TM, TN]().local().alloc() var dst_reg_vec = dst_reg.vectorize[1, simd_width]() dst_reg_vec.copy_from(dst_vec) var a_reg = tb[dtype]().layout[TM]().local().alloc() var b_reg = tb[dtype]().layout[TN]().local().alloc() var ntiles = b.dim[0]() // BK for block in range(ntiles): comptime load_a_layout = Layout.row_major(NUM_THREADS // BK, BK) comptime load_b_layout = Layout.row_major(BK, NUM_THREADS // BK) var a_tile = a.tile[BM, BK](Int(block_idx.y), block) var b_tile = b.tile[BK, BN](block, Int(block_idx.x)) copy_dram_to_sram_async[thread_layout=load_a_layout]( a_smem.vectorize[simd_width, 1](), a_tile.vectorize[simd_width, 1]() ) copy_dram_to_sram_async[thread_layout=load_b_layout]( b_smem.vectorize[1, simd_width](), b_tile.vectorize[1, simd_width]() ) async_copy_wait_all() barrier() @parameter for k in range(BK): var a_tile = a_smem.tile[TM, 1](partition_row, k) var b_tile = b_smem.tile[1, TN](k, partition_col) a_reg.copy_from(a_tile) b_reg.copy_from(b_tile) outer_product_acc(dst_reg, a_reg, b_reg) barrier() dst_vec.copy_from(dst_reg_vec) ``` From beginning to end, we've realized more than a 36X improvement in matrix multiplication speed within our MAX custom operation! 
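The tiling strategy underlying kernels 3 through 6 is easy to validate numerically: summing per-tile partial products over the K dimension reproduces the full matrix product. Here's a NumPy sketch (illustration only, with hypothetical tile sizes that evenly divide the matrix dimensions):

```python
import numpy as np

def block_tiled_matmul(A, B, BM=4, BN=4, BK=4):
    """Accumulate C one BM x BN output tile and BK-wide input tile at a time."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, BM):
        for j in range(0, N, BN):
            for k in range(0, K, BK):  # one partial product per K-tile
                C[i:i + BM, j:j + BN] += A[i:i + BM, k:k + BK] @ B[k:k + BK, j:j + BN]
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 8)).astype(np.float32)
B = rng.standard_normal((8, 8)).astype(np.float32)
assert np.allclose(block_tiled_matmul(A, B), A @ B, atol=1e-4)
```

Each `+=` corresponds to one pass of the kernel's `for block in range(...)` loop, with the sliced sub-blocks playing the role of the shared-memory tiles.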
## Kernel 7: Using Tensor Cores for matrix multiplication

Modern GPUs have dedicated hardware units for performing accelerated matrix multiplication called Tensor Cores. These Tensor Cores can perform matrix multiplications an order of magnitude or more faster than general purpose GPU hardware. However, they can be a challenge to work with. MAX contains interfaces that make it more ergonomic to program these dedicated hardware units.

The following is an example of how to perform the same calculation as the above, only on a Tensor Core:

```mojo
fn tensor_core_matrix_multiplication[
    dtype: DType,
    layout_a: Layout,
    layout_b: Layout,
    layout_c: Layout,
    BM: Int,
    BN: Int,
    BK: Int,
    WM: Int,
    WN: Int,
    MMA_M: Int,
    MMA_N: Int,
    MMA_K: Int,
](
    A: LayoutTensor[dtype, layout_a, MutAnyOrigin],
    B: LayoutTensor[dtype, layout_b, MutAnyOrigin],
    C: LayoutTensor[dtype, layout_c, MutAnyOrigin],
):
    comptime M = C.shape[0]()
    comptime N = C.shape[1]()
    comptime K = A.shape[1]()

    warp_y = warp_id() // UInt(BN // WN)
    warp_x = warp_id() % UInt(BN // WN)

    C_warp_tile = C.tile[BM, BN](Int(block_idx.y), Int(block_idx.x)).tile[WM, WN](
        Int(warp_y), Int(warp_x)
    )

    mma_op = TensorCore[A.dtype, C.dtype, Index(MMA_M, MMA_N, MMA_K)]()

    A_sram_tile = tb[A.dtype]().row_major[BM, BK]().shared().alloc()
    B_sram_tile = tb[B.dtype]().row_major[BK, BN]().shared().alloc()

    c_reg = (
        tb[C.dtype]()
        .row_major[WM // MMA_M, (WN * 4) // MMA_N]()
        .local()
        .alloc()
        .fill(0)
    )

    for k_i in range(K // BK):
        barrier()

        A_dram_tile = A.tile[BM, BK](Int(block_idx.y), k_i)
        B_dram_tile = B.tile[BK, BN](k_i, Int(block_idx.x))

        copy_dram_to_sram_async[thread_layout = Layout.row_major(4, 8)](
            A_sram_tile.vectorize[1, 4](), A_dram_tile.vectorize[1, 4]()
        )
        copy_dram_to_sram_async[thread_layout = Layout.row_major(4, 8)](
            B_sram_tile.vectorize[1, 4](), B_dram_tile.vectorize[1, 4]()
        )

        async_copy_wait_all()
        barrier()

        A_warp_tile = A_sram_tile.tile[WM, BK](Int(warp_y), 0)
        B_warp_tile = B_sram_tile.tile[BK, WN](0, Int(warp_x))

        @parameter
        for mma_k in range(BK // MMA_K):
            @parameter
            for mma_m in range(WM // MMA_M):
                @parameter
                for mma_n in range(WN // MMA_N):
                    c_reg_m_n = c_reg.tile[1, 4](mma_m, mma_n)

                    A_mma_tile = A_warp_tile.tile[MMA_M, MMA_K](mma_m, mma_k)
                    B_mma_tile = B_warp_tile.tile[MMA_K, MMA_N](mma_k, mma_n)

                    a_reg = mma_op.load_a(A_mma_tile)
                    b_reg = mma_op.load_b(B_mma_tile)

                    var d_reg_m_n = mma_op.mma_op(
                        a_reg,
                        b_reg,
                        c_reg_m_n,
                    )

                    c_reg_m_n.copy_from(d_reg_m_n)

    @parameter
    for mma_m in range(WM // MMA_M):
        @parameter
        for mma_n in range(WN // MMA_N):
            var C_mma_tile = C_warp_tile.tile[MMA_M, MMA_N](mma_m, mma_n)
            var c_reg_m_n = c_reg.tile[1, 4](mma_m, mma_n)

            mma_op.store_d(C_mma_tile, c_reg_m_n)
```

## Conclusion

In this tutorial, we've demonstrated how to create a custom MAX graph operation that performs matrix multiplication using various algorithms and run that on a GPU. We ran benchmarks of each algorithm, showing the performance benefits of various algorithmic improvements. Each improvement was described in detail, showing pathways to get the most speed out of modern GPUs using MAX and Mojo.

## Next Steps

- Follow [our tutorial for building a custom operation from scratch](/max/develop/build-custom-ops).
- See the [GPU programming](/mojo/manual/gpu/gpu-basics) page in the Mojo manual.
- See the [`gpu`](/mojo/std/gpu/) module for details on Mojo's GPU programming functions and types, and the documentation on [`@compiler.register`](/mojo/manual/decorators/compiler-register/), which shows how to register custom graph operations.
- Join our [Modular Forum](https://forum.modular.com/) and [Discord community](https://discord.gg/modular) to share your experiences and get support.

We're excited to see what you'll build with MAX! Share your projects and experiences with us using `#ModularAI` on social media.

---

## Intro to custom ops

Custom operations (custom ops) extend [MAX Graph's Python](/max/model-formats#max-graph) inference APIs with custom [Mojo](/mojo/manual) kernels.
Whether you need to optimize performance of functions, implement custom algorithms, or create hardware-specific versions of existing operators, custom ops provide the flexibility you need. The [custom ops](/max/api/python/graph/ops#custom) API provides complete control over MAX Graph while handling kernel integration and optimization pipelines automatically.

Try it now with our [custom ops examples](https://github.com/modular/modular/tree/main/max/examples/custom_ops) on GitHub or follow the [Build custom ops for GPUs](/max/develop/build-custom-ops) tutorial and [let us know what you think](https://www.modular.com/community).

### How it works

A custom op consists of two main components that work together to integrate your custom implementation into the MAX execution pipeline:

1. A custom function implementation written in Mojo that defines your computation
2. A registration process that connects your function to the graph execution system

Under the hood, custom ops utilize high-level abstractions that handle memory management, device placement, and optimization. The graph compiler integrates your custom op implementation into the execution flow.

For more information:

- Follow the [Build custom ops for GPUs tutorial](/max/develop/build-custom-ops)
- Learn more about [GPU programming with Mojo](/mojo/manual/gpu/basics)
- Explore the [Custom ops GitHub examples](https://github.com/modular/modular/tree/main/max/examples/custom_ops)
- Reference the [MAX Graph custom ops API](/max/api/python/graph/ops#custom)

## Mojo custom ops in PyTorch

You can also use Mojo to write high-performance kernels for existing PyTorch models without migrating your entire workflow to MAX. This approach allows you to replace specific performance bottlenecks in your PyTorch code with optimized Mojo implementations. Custom operations in PyTorch can now be written using Mojo, letting you experiment with new GPU algorithms in a familiar PyTorch environment.
These custom operations are registered using the [`CustomOpLibrary`](/max/api/python/torch#max.torch.CustomOpLibrary) class in the [`max.torch`](/max/api/python/torch) package.

### How it works

1. Write your kernel implementation in Mojo.
2. Register your custom operation using `CustomOpLibrary` from `max.torch`.
3. Replace specific operations in your existing PyTorch model with your Mojo implementation.

This allows you to keep your existing PyTorch workflows while gaining access to Mojo's performance capabilities for targeted optimizations.

For more information, see the [Extending PyTorch with custom operations in Mojo](https://github.com/modular/modular/tree/main/max/examples/pytorch_custom_ops) example.

---

## Data types (dtype)

Data types (dtypes) define how numbers are stored in tensors. A *dtype* specifies how each element in a tensor is represented in memory, and every tensor has exactly one dtype that applies to all its elements. Choosing the right dtype affects your model's memory usage, numerical precision, and compatibility with different hardware.

The [`DType`](/max/api/python/dtype#max.dtype.DType) enum in MAX provides all supported data types:

{/* @sync: _examples/dtypes/dtype_intro.py */}

```python
from max.dtype import DType

# DType is an enum that defines how numbers are stored in tensors
# Access dtypes as attributes of the DType class
print(DType.float32)  # 32-bit floating point
print(DType.int32)  # 32-bit integer
print(DType.bool)  # Boolean values
```

Each dtype has three key characteristics:

- **Precision**: How accurately numbers are represented (more bits = more precision).
- **Range**: The minimum and maximum values that can be stored.
- **Memory**: How many bytes each element requires.
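To make these three characteristics concrete, you can query the matching NumPy dtypes with `numpy.finfo` and `numpy.iinfo`. This is a standalone NumPy sketch (not a MAX API) that assumes MAX's `DType.float32`, `DType.float16`, and `DType.int32` mirror their NumPy counterparts, which the interoperability section below relies on:

```python
import numpy as np

# Standalone NumPy sketch: precision, range, and memory per element for
# dtypes that have direct MAX equivalents.
for dt in (np.float32, np.float16):
    info = np.finfo(dt)
    print(
        f"{np.dtype(dt).name}: {info.bits // 8} bytes, "
        f"max {info.max:.3e}, ~{info.precision} decimal digits"
    )

i32 = np.iinfo(np.int32)
print(f"int32: {i32.bits // 8} bytes, range [{i32.min}, {i32.max}]")
```

Note how `float16` halves the memory of `float32` at the cost of both range and precision, which is the trade-off summarized in the table below.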
## Common dtypes

MAX supports all standard NumPy and PyTorch dtypes:

| DType | Size | Description | Use case |
|------------------|---------|----------------------------------------|---------------------------------------|
| `DType.bfloat16` | 2 bytes | 16-bit brain float (8 exp, 7 mantissa) | ML training, better range than fp16 |
| `DType.bool` | 1 byte | Boolean true or false | Masks, conditional logic |
| `DType.float16` | 2 bytes | 16-bit IEEE floating point | GPU inference, memory savings |
| `DType.float32` | 4 bytes | 32-bit IEEE floating point | Default for training and development |
| `DType.int32` | 4 bytes | 32-bit signed integer | Indices, counts, discrete values |
| `DType.int64` | 8 bytes | 64-bit signed integer | Large indices, token IDs |
| `DType.int8` | 1 byte | 8-bit signed integer | Quantized models, extreme compression |

For the complete list including float8 variants and all integer types, see the [DType API reference](/max/api/python/dtype#DType).

## Specify dtype when creating tensors

When you create a tensor, you can specify its dtype using the `dtype` parameter in the format of `DType.{dtype_name}`:

{/* @sync: _examples/dtypes/specify_dtype.py */}

```python
from max.driver import CPU
from max.dtype import DType
from max.tensor import Tensor

# Create a tensor with float32 (default for most operations)
float_tensor = Tensor.ones([2, 3], dtype=DType.float32, device=CPU())
print(f"Float tensor dtype: {float_tensor.dtype}")

# Create a tensor with int32 for indices or counts
int_tensor = Tensor.constant([1, 2, 3], dtype=DType.int32, device=CPU())
print(f"Int tensor dtype: {int_tensor.dtype}")
```

The expected output is:

```output
Float tensor dtype: DType.float32
Int tensor dtype: DType.int32
```

In this example, the [`ones()`](/max/api/python/tensor#max.tensor.Tensor.ones) function creates a tensor filled with ones, and the [`constant()`](/max/api/python/tensor#max.tensor.Tensor.constant) function creates a tensor filled with the given values.
The `dtype` parameter specifies the dtype of the tensor. If you don't specify a dtype, MAX uses:

- `float32` for CPU devices.
- `bfloat16` for accelerator devices (GPUs).

## Check tensor dtype

Every tensor has a [`dtype`](/max/api/python/tensor#max.tensor.Tensor.dtype) property that returns its data type:

{/* @sync: _examples/dtypes/check_dtype.py */}

```python
from max.driver import CPU
from max.dtype import DType
from max.tensor import Tensor

# Create tensors of different types
weights = Tensor.ones([3, 3], dtype=DType.float32, device=CPU())
indices = Tensor.constant([0, 1, 2], dtype=DType.int64, device=CPU())

# Check the dtype of each tensor
print(f"Weights dtype: {weights.dtype}")  # DType.float32
print(f"Indices dtype: {indices.dtype}")  # DType.int64

# Compare dtypes directly
if weights.dtype == DType.float32:
    print("Weights are float32")
```

The expected output is:

```output
Weights dtype: DType.float32
Indices dtype: DType.int64
Weights are float32
```

In this example, the weights tensor is a `float32` tensor and the indices tensor is an `int64` tensor.

## Convert between dtypes

The [`cast()`](/max/api/python/tensor#max.tensor.Tensor.cast) method converts a tensor from one dtype to another. This is useful when you need to convert a tensor from a floating-point type to an integer type, or from a higher precision type to a lower precision type.
For example:

{/* @sync: _examples/dtypes/cast_dtype.py */}

```python
from max.driver import CPU
from max.dtype import DType
from max.tensor import Tensor

# Create a float32 tensor
x = Tensor.constant([1.7, 2.3, 3.9], dtype=DType.float32, device=CPU())
print(f"Original dtype: {x.dtype}")  # DType.float32

# Cast to int32 (truncates decimal values)
y = x.cast(DType.int32)
print(f"After cast to int32: {y.dtype}")  # DType.int32

# Cast to float64 for higher precision
z = x.cast(DType.float64)
print(f"After cast to float64: {z.dtype}")  # DType.float64
```

The expected output is:

```output
Original dtype: DType.float32
After cast to int32: DType.int32
After cast to float64: DType.float64
```

In this example, the original tensor is a `float32` tensor; after casting to `int32`, it becomes an `int32` tensor, and after casting to `float64`, it becomes a `float64` tensor.

## DType properties and methods

The `DType` enum provides useful properties and methods for inspecting types:

{/* @sync: _examples/dtypes/dtype_properties.py */}

```python
from max.dtype import DType

# Check memory size of different dtypes
print(f"float32 size: {DType.float32.size_in_bytes} bytes")  # 4
print(f"float32.is_float(): {DType.float32.is_float()}")  # True
print(f"int32.is_integral(): {DType.int32.is_integral()}")  # True
print(f"float8_e4m3fn.is_float8(): {DType.float8_e4m3fn.is_float8()}")  # True
```

The expected output is:

```output
float32 size: 4 bytes
float32.is_float(): True
int32.is_integral(): True
float8_e4m3fn.is_float8(): True
```

For more information, see the [DType API reference](/max/api/python/dtype#DType).

## Interoperability with NumPy and PyTorch tensors

MAX provides seamless dtype conversion with NumPy and PyTorch for working with existing data pipelines.

### Use DLPack for tensor conversion

DLPack is a standardized in-memory tensor format and protocol that lets array and tensor libraries share data across devices and frameworks with zero or minimal copies.
The recommended way to convert NumPy arrays to MAX tensors is through DLPack, which enables zero-copy conversion when possible:

{/* @sync: _examples/dtypes/dlpack_conversion.py */}

```python
import numpy as np
from max.tensor import Tensor

# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Convert to MAX tensor using DLPack (zero-copy when possible)
tensor = Tensor.from_dlpack(np_array)

print(f"NumPy dtype: {np_array.dtype}")  # float32
print(f"MAX tensor dtype: {tensor.dtype}")  # DType.float32
print(f"MAX tensor shape: {tensor.shape}")  # [2, 2]
```

The expected output is:

```output
NumPy dtype: float32
MAX tensor dtype: DType.float32
MAX tensor shape: [Dim(2), Dim(2)]
```

In this example, the [`from_dlpack()`](/max/api/python/tensor#max.tensor.Tensor.from_dlpack) method converts the NumPy array to a MAX tensor. You can use this method when converting data from other libraries to MAX.

:::note NumPy compatibility
Not all MAX dtypes have NumPy equivalents. For example, `bfloat16` and `float8` types are not natively supported by NumPy. When working with these types, you may need to cast to a compatible type first.
:::

MAX also provides dtype conversion for PyTorch and NumPy integration. The [`from_torch()`](/max/api/python/dtype#max.dtype.DType.from_torch) method converts a PyTorch dtype to a MAX dtype.
For example:

{/* @sync: _examples/dtypes/pytorch_to_max.py */}

```python
import torch
from max.dtype import DType

# PyTorch tensor
pt_tensor = torch.randn(10, 10, dtype=torch.float16)

# Convert PyTorch dtype to MAX dtype
# API: DType.from_torch(dtype)
#   dtype: PyTorch dtype
#   Returns: Corresponding MAX DType
#   Raises: ValueError if dtype not supported
#   Raises: RuntimeError if torch not installed
max_dtype = DType.from_torch(pt_tensor.dtype)
print(f"PyTorch {pt_tensor.dtype} → MAX {max_dtype}")  # float16 → DType.float16
```

The expected output is:

```output
PyTorch torch.float16 → MAX DType.float16
```

Other conversion functions you can use are:

- [`to_numpy()`](/max/api/python/dtype#max.dtype.DType.to_numpy): Convert a MAX dtype to a NumPy dtype.
- [`to_torch()`](/max/api/python/dtype#max.dtype.DType.to_torch): Convert a MAX dtype to a PyTorch dtype.

## Memory optimization

Understanding dtype memory usage is critical for deploying large models. The [`size_in_bytes`](/max/api/python/dtype#max.dtype.DType.size_in_bytes) property lets you calculate exact memory requirements.
{/* @sync: _examples/dtypes/calculate_memory.py */}

```python
from max.dtype import DType


def calculate_memory(shape: list[int], dtype: DType) -> int:
    """Calculate memory usage in bytes for a tensor."""
    # API: dtype.size_in_bytes
    #   Returns: Size of dtype in bytes (int)
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    bytes_used = num_elements * dtype.size_in_bytes
    return bytes_used


# Compare dtypes for same tensor
shape = [1024, 1024, 1024]  # 1B elements
float32_mb = calculate_memory(shape, DType.float32) / (1024**2)
float16_mb = calculate_memory(shape, DType.float16) / (1024**2)
int8_mb = calculate_memory(shape, DType.int8) / (1024**2)

print(f"float32: {float32_mb:.1f} MB")  # 4096.0 MB
print(f"float16: {float16_mb:.1f} MB")  # 2048.0 MB (50% reduction)
print(f"int8: {int8_mb:.1f} MB")  # 1024.0 MB (75% reduction)
```

## Type validation

Use dtype checking methods to write code that validates inputs at runtime. For example:

{/* @sync: _examples/dtypes/validate_dtypes.py */}

```python
from max.dtype import DType


def validate_weights_dtype(dtype: DType) -> None:
    """Ensure weights use a floating-point type."""
    # API: dtype.is_float()
    #   Returns: True if dtype is any floating-point type
    if not dtype.is_float():
        raise TypeError(f"Weights must be float type, got {dtype}")


def validate_indices_dtype(dtype: DType) -> None:
    """Ensure indices use an integer type."""
    # API: dtype.is_integral()
    #   Returns: True if dtype is any integer type (signed or unsigned)
    if not dtype.is_integral():
        raise TypeError(f"Indices must be integer type, got {dtype}")


# Usage
weights_dtype = DType.float16
indices_dtype = DType.int32

validate_weights_dtype(weights_dtype)  # OK
validate_indices_dtype(indices_dtype)  # OK
```

## Next steps

Now that you understand dtypes, continue learning:

{/* - [Basic operations](/max/develop/basic-ops) - Use dtypes in tensor operations */}

- [Building graphs](/max/develop/get-started-with-max-graph-in-python): Specify dtypes in computation graphs.
- [Quantization](/max/graph/quantize): Quantize weights to reduce memory usage and improve performance.

---

## Get started with MAX graphs

import InstallModular from '@site/docs/_includes/install-modular.mdx';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import MDXListing from '@site/src/components/Listing/MDXListing';

MAX provides a high-performance computation framework that lets you build and execute efficient machine learning models. It provides a flexible way to define computational workflows as graphs, where each node represents an operation (like matrix multiplication or addition) and edges represent the flow of data. By using the MAX Python API, you can create optimized machine learning models that run faster and more efficiently on modern hardware.

In this tutorial, you'll build a graph using the Python [`Graph`](/max/api/python/graph/Graph) API with an [`ops` function](/max/api/python/graph/ops). To do this, you will complete the following steps:

1. [Build a simple graph that adds two numbers](#build-the-graph)
2. [Create an inference session to load and compile the graph](#create-inference-session)
3. [Execute the graph with input data](#execute-the-graph)

By the end of this tutorial, you'll have an understanding of how to construct basic computational graphs, set up inference sessions, and run computations using the MAX Python API.

## Set up your environment

Create a Python project to install our APIs and CLI tools. Then, create a working directory.
Create a folder called `max_ops`:

```sh
mkdir max_ops
cd max_ops
```

You can check your MAX version like this:

```sh
max --version
```

You can check your Python version like this:

```sh
python --version
```

If you're working from a clone of the repository, change folders to your working directory instead:

```sh
cd src/quickstart
```

You can check your MAX version like this:

```sh
pixi run max --version
```

You can check your Python version like this:

```sh
pixi run python --version
```

:::tip
To clear cached data while iterating on graph builds, you can use `pixi clean` to remove the MEF cache and other environment data:

```sh
pixi clean
```

This removes the entire pixi environment, so you'll need to reinstall packages afterward.
:::

If you have any questions along the way, ask them on [our Discord channel](https://discord.gg/modular).

## 1. Build the graph {#build-the-graph}

Now with our environment and packages set up, let's create the graph. This graph will define a computational workflow that adds two tensors together.

Let's start by creating a new file called `addition.py` inside of your working directory and add the following libraries:

```python
import numpy as np
from max import engine
from max.driver import CPU, Buffer
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops
```

To create a computational graph, use the [`Graph()`](/max/api/python/graph/Graph) class from the MAX Python API. When initializing, specify a name for the graph and define the types of inputs it will accept.

```python
def add_tensors(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # 1. Build the graph
    input_type = TensorType(
        dtype=DType.float32, shape=(1,), device=DeviceRef.CPU()
    )
    with Graph(
        "simple_add_graph", input_types=(input_type, input_type)
    ) as graph:
        lhs, rhs = graph.inputs
        out = ops.add(lhs, rhs)
        graph.output(out)
```

Inside the context manager, access the graph's inputs using the [`inputs`](/max/api/python/graph/Graph#max.graph.Graph.inputs) property. This returns a symbolic tensor representing the input arguments. The symbolic tensor is a placeholder that represents the shape and type of data that will flow through the graph during execution, rather than containing the actual numeric values as in eager execution.

Then use the [`add()`](/max/api/python/graph/ops#max.graph.ops.add) function from the [`ops`](/max/api/python/graph/ops) package to add the two input tensors. This creates a new symbolic tensor representing the sum.

Finally, set the output of the graph using the [`output()`](/max/api/python/graph/Graph#max.graph.Graph.output) method. This specifies which tensors should be returned when the graph is executed.

Now, add a `print()` function to the graph to see what's created.

```python
def add_tensors(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # 1. Build the graph
    # ...
    print("final graph:", graph)
```

The output will show us the structure of our graph, including the input it expects and the operations it will perform. This helps us understand how our graph will process data when we use it.

Next, let's load the graph into an inference session.

## 2. Create an inference session {#create-inference-session}

Now that our graph is constructed, let's set up an environment where it can operate. This involves creating an inference session and loading our graph into it.

Create an [`InferenceSession()`](/max/api/python/engine#max.engine.InferenceSession) instance that loads and runs the graph inside the `add_tensors()` function.

```python
def add_tensors(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # 1. Build the graph
    # ...

    # 2. Create an inference session
    session = engine.InferenceSession(devices=[CPU()])
    model = session.load(graph)
```

This step transforms our abstract graph into a computational model that's ready for execution.

:::tip Debugging graph compilation errors
If you encounter errors during `session.load(graph)` (graph compilation), you can enable detailed debugging information by setting the `MODULAR_MAX_DEBUG` environment variable:

```bash
export MODULAR_MAX_DEBUG=True
python addition.py
```

This provides detailed stack traces for graph lowering failures. This can help you diagnose the problem and fix it, but it does make graph creation slower and should only be used when debugging compilation errors.
:::

To ensure our model is set up correctly, let's examine its input requirements. Print the graph's input metadata by using the [`input_metadata`](/max/api/python/engine#max.engine.Model.input_metadata) property.

```python
def add_tensors(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # 1. Build the graph
    # ...

    # 2. Create an inference session
    session = engine.InferenceSession(devices=[CPU()])
    model = session.load(graph)
    # highlight-start
    for tensor in model.input_metadata:
    # highlight-end
        print(
            f"name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}"
        )
```

This will output the exact specifications of the input our model expects, helping us prepare appropriate data for processing.

Next, let's execute the graph.

## 3. Execute the graph {#execute-the-graph}

To give the model something to add, create two inputs of a shape and a data type that match our graph's input requirements. Then pass the inputs to the [`execute()`](/max/api/python/engine#max.engine.Model.execute) function:

```python
def add_tensors(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # ...

    # 2. Create an inference session
    # ...

    # 3. Execute the graph
    # highlight-start
    output = model.execute(a, b)[0]
    result = output.to_numpy()
    # highlight-end
    return result
```

Notice that the [`execute()`](/max/api/python/engine#max.engine.Model.execute) function returns a list of outputs. We take the first element, which is a [`Buffer`](/max/api/python/driver#max.driver.Buffer), and convert it to a NumPy array with `to_numpy()`.

## 4. Run the example

Now that we've built our graph, created an inference session, and defined how to execute the graph, let's put it all together and run our complete example.

At the end of your `addition.py` file, add the following code:

```python
if __name__ == "__main__":
    input0 = np.array([1.0], dtype=np.float32)
    input1 = np.array([1.0], dtype=np.float32)
    result = add_tensors(input0, input1)
    print("result:", result)
```

This passes your arguments `input0` and `input1` to the `add_tensors()` function.

Then, run the Python file from the command line:

```sh
python addition.py
```

Or, if you're using pixi:

```sh
pixi run python addition.py
```

You've successfully created your first graph using the MAX Python API. Let's examine what was printed to the terminal:

```output
final graph: mo.graph @simple_add_graph(%arg0: !mo.tensor<[1], f32>, %arg1: !mo.tensor<[1], f32>) -> !mo.tensor<[1], f32> attributes {argument_names = ["input0", "input1"], result_names = ["output0"]} {
  %0 = rmo.add(%arg0, %arg1) : (!mo.tensor<[1], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[1], f32>
  mo.output %0 : !mo.tensor<[1], f32>
}
```

This output shows:

- Two input tensors (`%arg0`, `%arg1`) of shape `[1]` and float32 type
- The addition operation connecting them
- One output tensor of matching shape/type

The metadata lines confirm both input tensors match the required specifications.

```output
name: input0, shape: [1], dtype: DType.float32
name: input1, shape: [1], dtype: DType.float32
```

The result shows the addition worked correctly:

$$
[1.0] + [1.0] = [2.0]
$$

```output
result: [2.]
```

Now that you've built your first MAX graph that performs addition, you can explore more complex examples:

- [MAX graph API example](https://github.com/modular/modular/tree/main/max/examples/max-graph)
- [MAX graph implementation of Llama3](https://github.com/modular/modular/tree/main/max/python/max/pipelines/architectures)

## Next steps

export const docs = [
    '../../develop/build-custom-ops.mdx',
    '../../develop/serve-custom-model-architectures.mdx',
    '../../develop/build-an-mlp-block.mdx',
];

---

## Model developer guide

MAX is a high-performance framework built for production-ready neural network model development and deployment across a wide variety of hardware. It supports two programming patterns for constructing an AI model:

- *Eager-style execution* for an enhanced developer experience during model development and debugging.
- *Explicit graph construction* when you need low-level control over compilation and deployment.

The eager-style execution pattern provides a familiar developer experience inspired by PyTorch's eager execution, allowing you to write natural Python code with familiar operators and syntax. Since most of the time spent in model development goes to verifying model correctness, we recommend starting with eager-style execution. You write PyTorch-style code and get feedback on shape and type errors during development, while MAX uses lazy evaluation to automatically build and optimize computation graphs behind the scenes, delivering better performance than PyTorch.
Take, for example, the following code:

{/* @sync: _examples/index/eager_simple.py */}

```python
from max import functional as F
from max.driver import CPU
from max.tensor import Tensor

# Create tensor from Python data
x = Tensor.constant([1.0, -2.0, 3.0, -4.0, 5.0], device=CPU())
y = F.relu(x)

# Results are available right away
print(y)
```

The expected output is:

```output
Tensor([1 0 3 0 5], dtype=DType.float32, device=Device(type=cpu,id=0))
```

When you run this code, [`Tensor.constant()`](/max/api/python/tensor#max.tensor.Tensor.constant) creates a tensor and [`F.relu()`](/max/api/python/functional#max.functional.relu) performs the ReLU activation. The `print(y)` statement triggers execution. MAX compiles and runs the operations, then displays the result.

:::note
Eager-style execution performance improvements are in progress. Initial compilation times may be higher than expected, so eager-style execution should not be used as a general replacement for NumPy in production.
:::

This is different from explicit graph construction, where you define the complete graph structure upfront, compile it separately, then execute it with data. Here's the same `relu()` operation using explicit graph construction:

{/* @sync: _examples/index/graph_simple.py */}

```python
import numpy as np
from max.driver import CPU
from max.dtype import DType
from max.engine import InferenceSession
from max.graph import Graph, TensorType, ops

# Step 1: Define the graph structure
cpu = CPU()
input_type = TensorType(DType.float32, shape=[5], device=cpu)
with Graph("relu_graph", input_types=[input_type]) as graph:
    x = graph.inputs[0]
    y = ops.relu(x)
    graph.output(y)

# Step 2: Compile the graph
session = InferenceSession(devices=[cpu])
model = session.load(graph)

# Step 3: Execute with data
input_data = np.array([1.0, -2.0, 3.0, -4.0, 5.0], dtype=np.float32)
result = model.execute(input_data)
print(np.from_dlpack(result[0]))
```

The expected output is:

```output
[1. 0. 3. 0. 5.]
```

Both produce the same result, but explicit graph construction gives you full control over graph structure, data flow, and device placement. Eager-style execution lets you write natural Python code while MAX handles graph building and optimization automatically.

## Use standard Python operators

You can perform arithmetic operations using Python operators on tensors:

{/* @sync: _examples/index/arithmetic_ops.py */}

```python
from max.tensor import Tensor

a = Tensor.constant([1.0, 2.0, 3.0])
b = Tensor.constant([4.0, 5.0, 6.0])

c = a + b  # Addition
d = a * b  # Element-wise multiplication

print(c)
print(d)
```

For operations beyond basic arithmetic, use the functional API:

{/* @sync: _examples/index/functional_api.py */}

```python
from max import functional as F
from max.driver import CPU
from max.tensor import Tensor

# Force CPU execution to avoid GPU compiler issues
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]], device=CPU())

y = F.sqrt(x)  # Element-wise square root
z = F.softmax(x, axis=-1)  # Softmax along last axis

print(f"Input: {x}")
print(f"Square root: {y}")
print(f"Softmax: {z}")
```

The [`max.functional`](/max/api/python/functional) module (typically imported as `F`) provides operations like [`relu()`](/max/api/python/functional#max.functional.relu), [`softmax()`](/max/api/python/functional#max.functional.softmax), [`sqrt()`](/max/api/python/functional#max.functional.sqrt), and many more.

{/* For more information, see [Basic operations](/max/develop/basic-ops). */}

## Inspect values while debugging

One of the biggest advantages of the eager-style API is that you can inspect intermediate values at any point in your code.
This makes debugging straightforward:

{/* @sync: _examples/index/when_to_use_eager.py */}

```python
from max import functional as F
from max.driver import CPU
from max.dtype import DType
from max.tensor import Tensor


def debug_forward_pass(x: Tensor) -> Tensor:
    """Forward pass with intermediate inspection."""
    # Can print/inspect at any point
    print(f"Input: {x}")

    z = x * 2
    print(f"After multiply: {z}")

    h = F.relu(z)
    print(f"After ReLU: {h}")

    return h


x = Tensor.constant([-1.0, 0.0, 1.0, 2.0], dtype=DType.float32, device=CPU())
result = debug_forward_pass(x)
```

The expected output is:

```output
Input: Tensor([-1 0 1 2], dtype=DType.float32, device=Device(type=cpu,id=0))
After multiply: Tensor([-2 0 2 4], dtype=DType.float32, device=Device(type=cpu,id=0))
After ReLU: Tensor([0 0 2 4], dtype=DType.float32, device=Device(type=cpu,id=0))
```

If shapes don't match or types are incompatible, you get clear errors showing exactly which operation failed, right where you wrote it.

## Follow a complete workflow example

Here's a more realistic example showing a forward pass through a simple model:

{/* @sync: _examples/index/eager_workflow.py */}

```python
from max import functional as F
from max import random
from max.driver import CPU
from max.dtype import DType
from max.tensor import Tensor

# Create input data
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]], dtype=DType.float32, device=CPU())

# Create random weights
w = random.gaussian(
    [2, 2], mean=0.0, std=0.1, dtype=DType.float32, device=CPU()
)

# Forward pass - each operation executes as you write it
z = x @ w  # Matrix multiply
h = F.relu(z)  # Activation
out = h.mean()  # Reduce to scalar

# Inspect intermediate results anytime
print(f"Input shape: {x.shape}")
print(f"After matmul: {z.shape}")
print(f"Output: {out}")
```

Shape and type validation happens as you write operations. If the matrix dimensions didn't align for `x @ w`, you'd get an error at that exact line showing the shape mismatch.
Try modifying the code to see how validation errors appear at the line that caused them.

## Understand when execution happens

With eager-style execution, MAX uses lazy evaluation to optimize performance. When you write operations like `y = F.relu(x)`, MAX doesn't compute the result immediately. Instead, it builds up a computation graph by recording each operation as you call it. This deferred execution allows MAX to analyze the entire sequence of operations and optimize them before running anything.

Think of it like writing a recipe: as you add each step (slice vegetables, heat oil, add ingredients), you're building instructions but not actually cooking yet. MAX does the same thing: it records your tensor operations into a graph, then compiles and executes that graph only when you need actual computed values.

Execution happens automatically when an operation requires concrete values:

- **Printing tensors**: `print(x)` needs real values to display them. MAX compiles and executes the graph to produce those values.
- **Accessing scalar values**: [`x.item()`](/max/api/python/tensor#max.tensor.Tensor.item) must return a concrete Python number (like `3.14`) rather than a symbolic operation, which triggers execution.
- **Indexing tensors**: `x[0]` needs to extract the actual value at that position.
- **Converting to other formats**: Passing a MAX `Tensor` to `np.from_dlpack(x)` requires a real memory buffer for `x`, which forces execution.
- **Module forward passes**: Calling `module(x)` or `module.forward(x)` on an [`nn.Module`](/max/api/python/nn#max.nn.Module) executes the graph to produce the output.

This all happens automatically and transparently behind the scenes. MAX detects when it needs actual values and handles compilation and execution for you.
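The recipe analogy can be sketched in plain Python. The following toy class is an illustration of the deferred-execution idea only (it is not how MAX is implemented): each operation is recorded as a thunk, and nothing runs until a concrete value is requested:

```python
class LazyTensor:
    """Toy deferred-execution tensor: records ops, computes on demand."""

    def __init__(self, compute):
        self._compute = compute  # thunk that produces the concrete values
        self._value = None       # cache once materialized

    @classmethod
    def constant(cls, data):
        return cls(lambda: list(data))

    def relu(self):
        # Record the op by wrapping it in a new thunk; nothing runs yet.
        return LazyTensor(lambda: [max(0.0, v) for v in self.materialize()])

    def materialize(self):
        # In a real framework this would be triggered by print(), indexing,
        # item(), or a DLPack conversion.
        if self._value is None:
            self._value = self._compute()
        return self._value


x = LazyTensor.constant([1.0, -2.0, 3.0, -4.0, 5.0])
y = x.relu()             # the "graph" is recorded here, but not executed
print(y.materialize())   # execution happens here: [1.0, 0.0, 3.0, 0.0, 5.0]
```

A real system also fuses and optimizes the recorded operations before running them, which is where the performance benefit of deferral comes from.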
## Next steps Now that you understand the basic programming patterns in MAX, continue learning about the core tensor concepts that apply to both eager-style execution and explicit graph construction: - **[Tensor fundamentals](/max/develop/tensors)**: Learn what tensors are, how to create them, and how to inspect their properties like shape, dtype, and device. - **[Data types (dtype)](/max/develop/dtypes)**: Understand how dtypes control tensor precision and memory usage, and how to convert between MAX, NumPy, and PyTorch types. --- ## Bring your own fine-tuned model to MAX pipelines import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; In this tutorial, you'll integrate a fine-tuned custom model into MAX pipelines. More specifically, we will start with the initial configuration and then demonstrate how to download a model from the Hugging Face Hub. If the model is not already available in a supported quantized GGUF format, we'll show you how to convert it to prepare for ingestion into the MAX pipelines. Finally, we will explore how to use the quantized GGUF model via the MAX pipelines CLI. ## About model customization Model customization in machine learning typically involves modifying a pre-trained model to better suit specific tasks or datasets. One effective approach is fine-tuning, where a model trained on a large dataset is further trained (or fine-tuned) on a smaller, task-specific dataset. In this tutorial, we focus on [Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685). LoRA (and its quantized variant [QLoRA](https://arxiv.org/abs/2305.14314)) allows for efficient adaptation of large models by only updating a small set of additional parameters, preserving the original model's structure by integrating LoRA layers without altering the primary architecture. 
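The core LoRA idea can be shown with a toy NumPy sketch (illustrative only; real implementations such as PEFT operate on transformer layers): the frozen weight `W` is augmented with a low-rank product `B @ A`, and "merging" simply folds that product back into `W`.

```python
import numpy as np

rng = np.random.default_rng(0)

d, r = 8, 2                      # model dim 8, LoRA rank 2 (r << d)
W = rng.standard_normal((d, d))  # frozen pretrained weight
A = rng.standard_normal((r, d)) * 0.01  # trainable low-rank factors
B = rng.standard_normal((d, r)) * 0.01

x = rng.standard_normal(d)

# During fine-tuning: adapted output = (W + B @ A) x, training only A and B
y_adapted = W @ x + B @ (A @ x)

# "Merging" folds the low-rank update into the base weight
W_merged = W + B @ A
y_merged = W_merged @ x

print(np.allclose(y_adapted, y_merged))  # the merged model is equivalent
```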
For this tutorial, we assume the LoRA weights have already been merged into the base model, such as **Llama 3.1**. Merged weights can be produced by major fine-tuning libraries, for example with [unsloth `save_pretrained_merged`](https://docs.unsloth.ai/basics/saving-models/saving-to-gguf) or the [PEFT model merging](https://huggingface.co/docs/peft/en/developer_guides/model_merging) APIs. If you want to serve a LoRA adapter with MAX without merging the weights into your base model, see the [LoRA adapters](/max/serve/lora-adapters) guide. ## Step 1: Set up Hugging Face access To interact with models hosted on Hugging Face, you need secure access via either SSH or an access token. Follow the instructions in the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-git-ssh) to set up SSH. We can verify our configuration by running: ```sh ssh -T git@hf.co ``` A successful setup will display `Hi , welcome to Hugging Face`. ## Step 2: Set up MAX pipelines Next, clone the [MAX GitHub repository](https://github.com/modular/modular) and navigate to the MAX pipelines directory: ```sh git clone -b stable https://github.com/modular/modular && cd modular cd src/max ``` ## Step 3: Include the `huggingface_hub` CLI We'll use the `pixi` CLI to create a virtual environment and install the required packages. If you don't have `pixi`, you can install it with this command: ```sh curl -fsSL https://pixi.sh/install.sh | sh ``` Now install the `huggingface_hub` library to enable interactions with the Hugging Face Hub. This package facilitates the download and management of models and datasets: ```sh pixi add --pypi huggingface_hub hf_transfer ``` With the Hugging Face Hub CLI installed, we can proceed to the next steps of downloading and converting our model. ## Step 4: Convert to GGUF format If your model is already in the [GGUF format](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md), you can skip this conversion step and proceed directly to the next step.
If not, here are the most common methods to convert a model to a quantized GGUF format suitable for deployment: - **Automated conversion via Hugging Face space**: We can use the [gguf-my-repo](https://huggingface.co/spaces/ggml-org/gguf-my-repo) space for a streamlined conversion to a supported quantized GGUF format. Remember to log in. For this tutorial, we choose the `Q4_K_M` quantization method. You can see all the supported quantization encodings in the [`QuantizationEncoding` module](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding). For demonstration, we will choose [mlabonne/FineLlama-3.1-8B](https://huggingface.co/mlabonne/FineLlama-3.1-8B). After conversion, the model will be available under your Hugging Face username, ready for download and deployment. ![](images/max-pipeline-bring-your-own-model/gguf-my-repo.png) The following will download the converted GGUF model: ```sh pixi run huggingface-cli download \ /FineLlama-3.1-8B-Q4_K_M-GGUF \ --repo-type model \ --local-dir ./models ``` - **Manually convert via llama.cpp script**: Alternatively, use the [llama.cpp converter script](https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py) to convert your model manually. ```sh git clone https://github.com/ggerganov/llama.cpp # If your model is available on Hugging Face. # Ensure you replace with the appropriate # repository or model ID from Hugging Face. # Otherwise skip this command. pixi run huggingface-cli download \ --repo-type model \ --local-dir ./models python llama.cpp/convert_hf_to_gguf.py models ``` With all the requirements in place, we are now ready to use our custom model in MAX pipelines. ## Step 5: Run the custom model With our fine-tuned Llama 3.1 model successfully converted to GGUF format, we're ready to put it into action using MAX pipelines. For this demonstration, we'll be using our converted model file `finellama-3.1-8b-q4_k_m.gguf`.
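As an optional sanity check before serving (not part of the MAX workflow), GGUF files begin with the four ASCII magic bytes `GGUF`, so you can quickly confirm that a download or conversion produced a valid GGUF file:

```python
def looks_like_gguf(path: str) -> bool:
    """Return True if the file starts with the GGUF magic bytes."""
    with open(path, "rb") as f:
        return f.read(4) == b"GGUF"

# Example usage (the path is illustrative):
# print(looks_like_gguf("./models/finellama-3.1-8b-q4_k_m.gguf"))
```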
First, let's install the necessary CLI tool. The `max` CLI is included in the `modular` package, which we can install using `pixi`: ```bash pixi global install modular -c https://conda.modular.com/max-nightly -c conda-forge ``` Before running our model, it's worth noting that MAX pipelines offer various configuration options. You can explore them by running `max --help`. :::note If you use private or gated models, you must set your [Hugging Face access token](https://huggingface.co/docs/hub/en/security-tokens) first. For example: ```bash export HF_TOKEN="hf_..." ``` Then you can run a `max` command to execute a private or gated model. ::: Now, let's run our custom model. We'll use the `max generate` command, specifying our model configuration and a test prompt: ```bash max generate \ --model modularai/Llama-3.1-8B-Instruct-GGUF \ --quantization-encoding "q4_k" \ --weight-path "./models/finellama-3.1-8b-q4_k_m.gguf" \ --prompt "What is the meaning of life?" ``` It generates an answer similar to the following: ```output The meaning of life is a question that has been pondered by philosophers, scientists, and spiritual leaders for centuries. It is a question that has no definitive answer, as it is deeply personal and subjective to each individual. However, many have attempted to provide their own interpretations or explanations. One interpretation of the meaning of life is that it is simply to live and experience the world around us. This view suggests that the purpose of life is to experience all that it has to offer, whether it be through the senses, emotions, or intellectual pursuits. In this sense, the meaning of life is not necessarily tied to any specific goal or achievement, but rather to the process of living itself. Another interpretation is that the meaning of life is to find purpose and meaning in our lives. This view suggests that we are here to seek out our own unique purpose and to strive to achieve it.
This can be achieved through various means, such as through our work, relationships, or personal pursuits. A third interpretation is that the meaning of life is to connect with something larger than ourselves. This view suggests that we are here to connect with a higher power, whether it be through religion, spirituality, or a sense of awe and wonder at the universe. In this sense, the meaning of life is to find a sense of purpose and connection that transcends our individual lives. Ultimately, the meaning of life is a question that each person must answer for themselves. It is a question that requires us to reflect on our own values, beliefs, and experiences. As the saying goes, "Ask a flower" - the meaning of life is not something that can be answered in words, but rather in the experience of living itself. ``` For more information on quantization, see the [Quantization](/max/graph/quantize) documentation. ## Next steps Congratulations on successfully integrating your fine-tuned Llama3.1 model into the MAX pipelines! 🎉 We have navigated through setting up secure access, downloading and converting models, and finally running your custom model in MAX pipelines. We encourage you to further customize your models via the MAX Graph API, test your pipeline and explore other MAX features including how to **deploy your fine-tuned model on GPU using MAX Serve**. Here are some other topics to explore next: export const docs = [ '../../develop/serve-custom-model-architectures.mdx', '../../develop/custom-ops.mdx', '../../graph/quantize.mdx', ]; --- ## Model pipeline import MDXListing from '@site/src/components/Listing/MDXListing'; When you build models in MAX, whether it's a GPT, Llama, or a custom architecture, you're creating the layers, attention mechanisms, and forward pass logic that define how the model processes inputs. But to actually serve that model as an endpoint that can handle production requests, you need to connect it to MAX's serving infrastructure. 
That's where an **inference pipeline** comes in. A **pipeline** is the bridge between your model and MAX's serving framework. The pipeline performs any pre- and post-processing for the model and orchestrates the inference workflow. For example, the pipeline loads model weights, manages key-value caches, batches requests, and calls your tokenizer to encode/decode the inputs/outputs. You can use the [`pipelines`](/max/api/python/pipelines/) API to make any model architecture compatible with MAX, whether you're adapting an existing model or implementing a new one from scratch. The pipeline system uses a registry pattern where model architectures register their capabilities, and the infrastructure handles the execution details. When you point MAX at a model, the registry looks up the architecture, validates compatibility, downloads weights, compiles the model, and returns a ready-to-use pipeline. This architecture separates concerns cleanly: - **Modules** define model architectures and hold weights. - **Pipelines** orchestrate the inference loop and manage state. - **Registry** maps model identifiers to implementations. - **Compilation** transforms your model into optimized executables for the target device. Pipelines let you focus on model architecture while MAX handles the production infrastructure—batching, caching, compilation, and serving. ## Building blocks Before diving into pipeline components, it's helpful to understand the two foundational packages that pipelines build on: [`max.nn`](/max/api/python/nn) and [`max.kv_cache`](/max/api/python/kv_cache). ### Neural network module The [`max.nn`](/max/api/python/nn) (neural network) package provides reusable neural network layers that serve as the bridge between the MAX Graph API and model implementations. 
The `max.nn` package includes common components like: - **Core modules**: [`Module`](/max/api/python/nn#max.nn.Module), [`Linear`](/max/api/python/nn#max.nn.Linear), [`Embedding`](/max/api/python/nn#max.nn.Embedding), [`Sequential`](/max/api/python/nn#max.nn.Sequential), [`ModuleList`](/max/api/python/nn#max.nn.ModuleList). - **Normalization**: [`RMSNorm`](/max/api/python/nn#max.nn.RMSNorm), [`GemmaRMSNorm`](/max/api/python/nn#max.nn.GemmaRMSNorm). - **Positional encodings**: [`RotaryEmbedding`](/max/api/python/nn#max.nn.RotaryEmbedding), [`TransposedRotaryEmbedding`](/max/api/python/nn#max.nn.TransposedRotaryEmbedding). - **Utilities**: [`module_dataclass`](/max/api/python/nn#max.nn.module_dataclass) for creating module dataclasses. These components are core to building model architectures. :::note Legacy modules For legacy layer-based components like attention mechanisms and transformers, use `max.nn.legacy`: ```python from max.nn.legacy import Module, Layer, Linear from max.nn.legacy.attention import AttentionWithRope ``` ::: The [`Module`](/max/api/python/nn#max.nn.Module) base class standardizes how layers manage weights and devices. Here's an example of building a simple multi-layer perceptron: ```python from max.driver import Accelerator from max.nn import Module, Linear class MLP(Module): fc1: Linear fc2: Linear def forward(self, x): return self.fc2(self.fc1(x)) # Create a model with two linear layers model = MLP(fc1=Linear(10, 20), fc2=Linear(20, 5)) # Weights are tracked automatically through the module hierarchy for name, param in model.parameters: print(f"{name}: {param.shape}") # fc1.weight: [20, 10] # fc1.bias: [20] # fc2.weight: [5, 20] # fc2.bias: [5] # Move all parameters to an accelerator (GPU) model.to(Accelerator()) ``` In this example, the `Module` base class automatically tracks all parameters through the module hierarchy, letting you iterate over them or inspect them. 
The `to()` method provides a simple way to move the entire model and all its parameters to a different device with a single call. ### KV cache module The [`max.kv_cache`](/max/api/python/kv_cache) package provides cache management for transformer inference. The main component is [`PagedKVCacheManager`](/max/api/python/kv_cache/#max.kv_cache.PagedKVCacheManager), which handles memory allocation for key-value pairs across generation steps. For most use cases, you don't interact with the cache manager directly. The pipeline handles cache management automatically using paged attention based on the [`supported_encodings`](/max/api/python/pipelines/registry/#max.pipelines.lib.registry.SupportedArchitecture.supported_encodings) in your architecture config. ### How modules and pipelines work together Understanding the relationship between modules and pipelines is key to working with MAX. When you build a model for the pipeline system, you define your architecture using the [`Module`](/max/api/python/nn/#max.nn.Module) class. MAX then compiles your module into an optimized executable for the target device, and the pipeline orchestrates execution. Here's the workflow: 1. **Define your model**: You create a [`Module`](/max/api/python/nn#max.nn.Module) that defines your model architecture. The module's [`forward`](/max/api/python/nn#max.nn.Module.forward) method defines what computations happen when processing inputs. 2. **MAX compiles your model**: MAX compiles the module into an optimized executable for the target device. This compilation happens once, and you can reuse the result for many inference calls. 3. **Pipelines orchestrate execution**: The pipeline receives pre-tokenized context objects (which the tokenizer creates), manages the KV cache, calls your model's [`PipelineModel.execute()`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.execute) method, samples output tokens, and returns results. 
This separation lets you work at the right level of abstraction: use [`max.nn`](/max/api/python/nn) to define model architectures, let MAX handle compilation, and rely on pipelines for production serving. ### Request execution example To see how these components work together in practice, here's a complete example of generating text with a pipeline. The [`PIPELINE_REGISTRY`](/max/api/python/pipelines/registry/) is the central system that maps model architectures to their compiled pipelines and tokenizers. When you call [`retrieve()`](/max/api/python/pipelines/registry/#max.pipelines.lib.registry.PipelineRegistry.retrieve) with a [`PipelineConfig`](/max/api/python/pipelines/config/), the registry looks up the model's architecture from its Hugging Face config, validates compatibility, downloads weights if needed, compiles the model, and returns both a tokenizer and a ready-to-use pipeline instance. ```python import asyncio from max.interfaces import RequestID, TextGenerationInputs, TextGenerationRequest from max.pipelines import PIPELINE_REGISTRY, PipelineConfig from max.pipelines.core import TextContext # 1. Configure and retrieve the pipeline and tokenizer config = PipelineConfig( model_path="meta-llama/Llama-3.1-8B-Instruct", max_length=512, ) tokenizer, pipeline = PIPELINE_REGISTRY.retrieve(config) # 2. Get the KV cache manager from the pipeline kv_cache_manager = pipeline.kv_managers[0] # 3. Create a request with the text prompt request = TextGenerationRequest( request_id=RequestID(), model_name=config.model_path, prompt="Explain how neural networks work", ) # 4. Create a context object (tokenization happens here) context = asyncio.run(tokenizer.new_context(request)) # 5. Allocate space in the KV cache for this request kv_cache_manager.claim(context.request_id) # 6. 
Run the generation loop generated_text = "" while True: # Allocate KV cache for the next token kv_cache_manager.alloc(context, num_steps=1) # Execute the pipeline with the current context inputs = TextGenerationInputs[TextContext]( batches=[{context.request_id: context}], num_steps=1, ) output = pipeline.execute(inputs) # Decode and accumulate generated tokens for token in output[context.request_id].tokens: generated_text += asyncio.run( tokenizer.decode(token, skip_special_tokens=True) ) # Check if generation is complete if output[context.request_id].is_done: break print(generated_text) ``` In this example, you can see the key phases of pipeline execution: 1. The registry maps the model path to the appropriate tokenizer and compiled pipeline. 2. The cache manager tracks memory allocation for the request's key-value pairs across generation steps. 3. The `new_context()` method handles tokenization internally and creates a context object that tracks the request's state throughout generation. 4. The pipeline processes tokens, the model executes, and the sampler selects new tokens until completion. Notice how the pipeline itself is stateless: all request-specific state lives in the `context` object and the KV cache manager. The pipeline orchestrates execution based on the inputs it receives. For more information on the stateless nature of the pipeline system, see [Stateless orchestration](#stateless-orchestration) below. ## Core components Now that you understand how modules and pipelines work together, you can explore the specific components that make up the pipeline system. ### Top-level interfaces The [`max.interfaces`](/max/api/python/interfaces/) package defines the contracts that all pipeline components must implement. These abstractions enable MAX to work uniformly across different model architectures and tasks. The key interfaces are: - [`Pipeline`](/max/api/python/interfaces#max.interfaces.Pipeline): Abstract base class for all pipelines.
Defines `execute()` and `release()` methods that all pipeline implementations must provide. - [`PipelineInputs`](/max/api/python/interfaces#max.interfaces.PipelineInputs): Base class for inputs to a pipeline, such as text generation requests or embeddings requests. - [`PipelineOutput`](/max/api/python/interfaces#max.interfaces.PipelineOutput): Protocol for pipeline outputs. Must implement `is_done` to signal when generation is complete. - [`PipelineTokenizer`](/max/api/python/interfaces#max.interfaces.PipelineTokenizer): Interface for tokenizers that convert between text and token IDs, and create context objects for requests. - [`PipelineModel`](/max/api/python/pipelines/interfaces#max.pipelines.lib.interfaces.PipelineModel): Abstract base class for model implementations. Defines methods like `execute()`, `calculate_max_seq_len()`, and input preparation methods that all architectures must implement. These interfaces are task-agnostic. Specialized variants like [`TextGenerationInputs`](/max/api/python/interfaces#max.interfaces.TextGenerationInputs) and [`TextGenerationOutput`](/max/api/python/interfaces#max.interfaces.TextGenerationOutput) extend them for specific use cases. ### Pipeline registry The [`PIPELINE_REGISTRY`](/max/api/python/pipelines/registry/#max.pipelines.lib.registry.PipelineRegistry) is a singleton that tracks all available model architectures. When you run the `max serve` command with a model, the registry: 1. Looks up the model's architecture from its Hugging Face config. 2. Validates that it supports the requested encoding and settings. 3. Returns the appropriate tokenizer and pipeline for that architecture. 
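Conceptually, this lookup is a dictionary keyed by architecture name. A simplified plain-Python sketch of the flow (hypothetical names, not MAX's internals) looks like this:

```python
class ToyRegistry:
    """Simplified sketch of an architecture registry (not MAX's API)."""

    def __init__(self):
        self._architectures = {}

    def register(self, name, supported_encodings, pipeline_factory):
        self._architectures[name] = (set(supported_encodings), pipeline_factory)

    def retrieve(self, architecture_name, encoding):
        # 1. Look up the architecture by name
        if architecture_name not in self._architectures:
            raise ValueError(f"unknown architecture: {architecture_name}")
        encodings, factory = self._architectures[architecture_name]
        # 2. Validate the requested encoding
        if encoding not in encodings:
            raise ValueError(f"{architecture_name} does not support {encoding}")
        # 3. Build and return the pipeline (the real system compiles here)
        return factory()


registry = ToyRegistry()
registry.register("LlamaForCausalLM", ["bfloat16", "q4_k"], lambda: "llama-pipeline")
print(registry.retrieve("LlamaForCausalLM", "q4_k"))
```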
You can interact with the registry directly to retrieve a model's tokenizer and compiled pipeline: ```python from max.pipelines import PIPELINE_REGISTRY, PipelineConfig # Create configuration for a model config = PipelineConfig( model_path="meta-llama/Llama-3.1-8B-Instruct", ) # Retrieve tokenizer and compiled pipeline tokenizer, pipeline = PIPELINE_REGISTRY.retrieve(config) # Or get a factory for deferred compilation tokenizer, pipeline_factory = PIPELINE_REGISTRY.retrieve_factory(config) pipeline = pipeline_factory() # Compile when ready ``` In this example, `retrieve()` returns a ready-to-use pipeline, while `retrieve_factory()` returns a callable that performs compilation when invoked. The factory pattern is useful when you need to pass the pipeline across process boundaries, since it avoids serializing the compiled model. ### Supported architecture A [`SupportedArchitecture`](/max/api/python/pipelines/registry/#max.pipelines.lib.registry.SupportedArchitecture) configuration defines each model architecture. This bridges the gap between Hugging Face model conventions and MAX's execution system. When you point MAX at a Hugging Face model (like `meta-llama/Llama-3.1-8B-Instruct`), MAX downloads and reads the model's `config.json` file. Inside that file is an `architectures` field listing the model class name (like `"LlamaForCausalLM"`). The registry uses this name to look up the corresponding `SupportedArchitecture`, which tells MAX which model (which subclass of `PipelineModel`) to use, what quantization formats it supports, how to load weights, and which tokenizer to instantiate. 
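That lookup key comes straight from the model's `config.json`; reading it amounts to parsing one JSON field. Here's a minimal sketch of the convention (not MAX code), using a trimmed example of the kind of config Hugging Face models ship with:

```python
import json

# Trimmed example of a Hugging Face config.json
config_text = """
{
  "architectures": ["LlamaForCausalLM"],
  "hidden_size": 4096,
  "num_hidden_layers": 32
}
"""

config = json.loads(config_text)
architecture_name = config["architectures"][0]
print(architecture_name)  # LlamaForCausalLM
```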
Here's how you define an architecture: ```python from max.graph.weights import WeightsFormat from max.interfaces import PipelineTask from max.nn.legacy.kv_cache import KVCacheStrategy from max.pipelines.lib import ( RopeType, SupportedArchitecture, SupportedEncoding, TextTokenizer, ) llama_arch = SupportedArchitecture( # Must match the HuggingFace model class name name="LlamaForCausalLM", # The type of task this architecture supports task=PipelineTask.TEXT_GENERATION, # Example models that use this architecture example_repo_ids=[ "meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", ], # Quantization support default_encoding=SupportedEncoding.q4_k, supported_encodings={ SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED], SupportedEncoding.q4_k: [KVCacheStrategy.PAGED], }, # Implementation classes pipeline_model=Llama3Model, tokenizer=TextTokenizer, # Weight handling default_weights_format=WeightsFormat.safetensors, weight_adapters={ WeightsFormat.safetensors: convert_safetensor_state_dict, WeightsFormat.gguf: convert_gguf_state_dict, }, # Architecture-specific settings rope_type=RopeType.normal, multi_gpu_supported=True, ) ``` The `name` field must match the `architectures` field in the model's Hugging Face `config.json`. Common architecture names include `LlamaForCausalLM`, `DeepseekV3ForCausalLM`, `Qwen3VLMoeForConditionalGeneration`, and more. However, if you are using a custom architecture, you will need to use a custom name that is specific to your architecture. The `task` field determines which `Pipeline` subclass MAX uses to orchestrate execution. Different types of models serve different purposes, and each task type has its own execution strategy: - **`TEXT_GENERATION`**: Autoregressive text generation for chat and completion use cases. MAX uses `TextGenerationPipeline` to handle the prefill and decode loop, KV cache management, and token sampling. - **`EMBEDDINGS_GENERATION`**: Vector embeddings from text input. 
MAX uses `EmbeddingsPipeline` to produce dense vector representations suitable for semantic search and retrieval. For example, if you are building a model for text generation, you will set `task=PipelineTask.TEXT_GENERATION`. Set the `default_encoding` to the quantization format you want to use by default. `supported_encodings` maps quantization formats to compatible KV cache strategies. For example, if you are using a model that supports `q4_k` quantization, you will set `supported_encodings={SupportedEncoding.q4_k: [KVCacheStrategy.PAGED]}`. For more information on quantization, see the [quantization guide](/max/graph/quantize). The `pipeline_model` field is the class that builds and executes the model. The `tokenizer` field is the class that handles text encoding and decoding. Finally, the `weight_adapters` field is a dictionary of functions that convert weights from other formats into the default format. Together, these fields configure the pipeline for your model. If you are implementing a new model architecture, you typically know its structure well enough to set each field directly. ### Pipeline model The [`PipelineModel`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel) abstract class defines the interface for model implementations. Every model architecture must implement these methods: ```python from max.pipelines.lib import PipelineModel, ModelInputs, ModelOutputs class MyModel(PipelineModel): @classmethod def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int: """Return the maximum sequence length this model supports.""" ... def execute(self, model_inputs: ModelInputs) -> ModelOutputs: """Run inference on the compiled model.""" ...
def prepare_initial_token_inputs( self, context_batch, kv_cache_inputs=None, return_n_logits=1 ) -> ModelInputs: """Prepare inputs for the first forward pass (prefill).""" ... def prepare_next_token_inputs( self, next_tokens, prev_model_inputs ) -> ModelInputs: """Prepare inputs for subsequent forward passes (decode).""" ... ``` In this example, these methods define the interface that all model implementations must provide. The `execute()` method runs your model's forward pass on the compiled executable. The separation between `prepare_initial_token_inputs` and `prepare_next_token_inputs` reflects the two phases of autoregressive generation: 1. **Prefill**: Process the entire prompt at once, building up the KV cache. 2. **Decode**: Generate tokens one at a time, reusing the cached keys/values. ### When to customize `ModelInputs` For most text-to-text transformer models, you can use the default `ModelInputs` implementation provided by MAX. Start with the default `ModelInputs` and only create a custom subclass if your model's forward pass requires additional tensors beyond the standard transformer inputs. Custom `ModelInputs` are only necessary for: - Multimodal models that process both text and image inputs, which require custom input structures to handle image tensors, pixel values, or image embeddings alongside text tokens. - Models with unique input patterns, for example mixture-of-experts models with routing tensors or retrieval-augmented models with document embeddings. If you're implementing a standard decoder-only language model (like Llama, Mistral, or similar architectures), you likely don't need to subclass `ModelInputs`. The default implementation handles token IDs, position IDs, attention masks, and KV cache inputs, which covers most use cases. ## Pipeline execution Now let's look at how pipelines coordinate execution at runtime.
When a pipeline runs, it orchestrates three main components: the KV cache manager (which tracks key-value pairs across generation steps), the model (which executes the forward pass), and the sampler (which selects the next token based on the model's output logits). The `execute()` method ties these together in a generation loop. ### Stateless orchestration A core design principle of the MAX pipeline system is that **pipelines are stateless orchestrators**. The pipeline itself does not own or maintain per-request state. Instead, it operates on the state passed to it through inputs: - **Context objects** track all request-specific information (tokens, sampling parameters, generation status). You pass these into `execute()`, and the pipeline updates them but doesn't store them internally. - **KV cache manager** owns the allocation and lifecycle of cached key-value pairs across all requests. The pipeline uses the cache manager but doesn't own it. ## Custom architecture registration Now let's see how you can extend MAX with your own model architectures. MAX uses a convention-based registration system. When you point MAX at a custom architecture directory, it imports the module and looks for an `ARCHITECTURES` list to register. 
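That discovery step boils down to a dynamic import plus an attribute lookup. Here's a simplified sketch (not MAX's actual loader; the function name is illustrative):

```python
import importlib
import sys


def discover_architectures(package_dir: str, module_name: str):
    """Import a module from a directory and return its ARCHITECTURES list."""
    sys.path.insert(0, package_dir)
    try:
        module = importlib.import_module(module_name)
        # Modules without an ARCHITECTURES list contribute nothing
        return getattr(module, "ARCHITECTURES", [])
    finally:
        sys.path.pop(0)
```

In the real system, each discovered architecture is then registered with `PIPELINE_REGISTRY` so that subsequent lookups by name resolve to your implementation.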
### Custom architecture structure A custom architecture directory typically contains these files: ```text my_model/ ├── __init__.py # Exports ARCHITECTURES list ├── arch.py # Defines SupportedArchitecture config ├── model.py # Implements PipelineModel subclass with model architecture ├── model_config.py # (Optional) Custom model configuration └── weight_adapters.py # (Optional) Functions to convert weight formats ``` The `__init__.py` file exports an `ARCHITECTURES` list that MAX discovers: ```python # my_model/__init__.py from .arch import my_arch ARCHITECTURES = [my_arch] ``` The `arch.py` file defines the architecture configuration: ```python # my_model/arch.py from max.graph.weights import WeightsFormat from max.interfaces import PipelineTask from max.nn.legacy.kv_cache import KVCacheStrategy from max.pipelines.lib import ( SupportedArchitecture, SupportedEncoding, TextTokenizer, ) from .model import MyModel my_arch = SupportedArchitecture( name="MyModelForCausalLM", # Must match HuggingFace config task=PipelineTask.TEXT_GENERATION, default_encoding=SupportedEncoding.bfloat16, supported_encodings={ SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED], }, pipeline_model=MyModel, tokenizer=TextTokenizer, default_weights_format=WeightsFormat.safetensors, ) ``` ### Load custom architectures Use the `--custom-architectures` flag to load your architecture: ```bash max serve --custom-architectures ./my_model --model path/to/weights ``` MAX imports your module, finds the `ARCHITECTURES` list, and registers each architecture with `PIPELINE_REGISTRY`. Your custom architecture then overrides any built-in architecture with the same name. For a complete working example, see the [custom-models example](https://github.com/modular/modular/tree/main/max/examples/custom-models) in the Modular repository. ## Configuration flow Now let's see how configuration flows from user input to a running pipeline. When you run a model, configuration flows through several layers: 1. 
**Start with user arguments**: MAX collects CLI or API arguments into a `PipelineConfig` object that specifies the model path, quantization settings, and runtime parameters. 2. **Load model metadata**: The registry fetches the model's Hugging Face config to perform architecture lookup and extract hyperparameters like hidden size, number of layers, and vocabulary size. 3. **Validate compatibility**: The registry checks that the architecture supports the requested encoding and KV cache strategy. 4. **Instantiate pipeline**: Finally, the registry constructs and returns the tokenizer and compiled pipeline ready for inference. ```python from max.pipelines import PipelineConfig config = PipelineConfig( # Model specification (Hugging Face repo ID or local path) model_path="meta-llama/Llama-3.1-8B-Instruct", # Sequence limits max_length=4096, # Batching max_batch_size=32, ) ``` The `PipelineConfig` consolidates all settings and provides defaults based on the model and hardware. See the [`PipelineConfig` reference](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig) for all available options. ## Next steps Now that you understand the pipeline architecture, continue learning: - [Build an MLP block as a module](/max/develop/build-an-mlp-block): Start building custom modules using the `max.nn` package. - [Serve custom model architectures](/max/develop/serve-custom-model-architectures): Implement a complete custom architecture with the pipeline system. - [Bring your own fine-tuned model to MAX pipelines](/max/develop/max-pipeline-bring-your-own-model): Use your own fine-tuned weights with MAX pipelines. 
--- ## Serve custom model architectures import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; MAX comes with built-in support for popular model architectures like `Gemma3ForCausalLM`, `Qwen2ForCausalLM`, and `LlamaForCausalLM`, so you can instantly deploy them by passing a specific Hugging Face model name to the `max serve` command (explore [our model repo](https://builds.modular.com/?category=models)). You can also use MAX to serve a custom model architecture with the `max serve` command, which provides an [OpenAI-compatible API](/max/api/serve). In this tutorial, you'll implement a custom architecture based on the Qwen2 model by extending MAX's existing Llama3 implementation. This approach demonstrates how to leverage MAX's built-in architectures to quickly support new models with similar structures. By the end of this tutorial, you'll understand how to: - Set up the required file structure for custom architectures. - Extend existing MAX model implementations. - Register your model architecture with MAX. - Serve your model and make inference requests. ## Set up your environment Create a Python project and install the necessary dependencies: ## Understand the architecture structure Before creating your custom architecture, let's understand how to organize your custom model project. Create the following structure in your project directory: ```text qwen2/ ├── __init__.py ├── arch.py ├── model.py └── model_config.py ``` Here's what each file does: - **`__init__.py`**: Makes your architecture discoverable by MAX. - **`arch.py`**: Registers your model with MAX, specifying supported encodings, capabilities, and which existing components to reuse. - **`model.py`**: Contains your model implementation that extends an existing MAX model class. 
- **`model_config.py`**: Contains the model configuration class that can be instantiated from a PipelineConfig. When extending an existing architecture, you can often reuse configuration handling and weight adapters from the parent model, significantly reducing the amount of code you need to write. ## Implement the main model class When your model is similar to an existing architecture, you can extend that model class instead of building from scratch. In this example, we'll extend the `Llama3Model` class to implement the `Qwen2Model` class: ```python title="model.py" from __future__ import annotations from max.driver import Device from max.engine import InferenceSession from max.graph.weights import Weights, WeightsAdapter from max.nn.legacy import ReturnLogits from max.pipelines.architectures.llama3.model import Llama3Model from max.pipelines.lib import KVCacheConfig, PipelineConfig, SupportedEncoding from transformers import AutoConfig class Qwen2Model(Llama3Model): """Qwen2 pipeline model implementation.""" attention_bias: bool = True """Whether to use attention bias.""" def __init__( self, pipeline_config: PipelineConfig, session: InferenceSession, huggingface_config: AutoConfig, encoding: SupportedEncoding, devices: list[Device], kv_cache_config: KVCacheConfig, weights: Weights, adapter: WeightsAdapter | None = None, return_logits: ReturnLogits = ReturnLogits.LAST_TOKEN, ) -> None: super().__init__( pipeline_config, session, huggingface_config, encoding, devices, kv_cache_config, weights, adapter, return_logits, ) ``` By inheriting from `Llama3Model`, the Qwen2 implementation automatically gets: - The [`execute`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.execute), [`prepare_initial_token_inputs`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.prepare_initial_token_inputs), and 
[`prepare_next_token_inputs`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.prepare_next_token_inputs) methods required by MAX. - Graph building logic for transformer architectures. - Configuration handling from Hugging Face models. - Weight loading and conversion capabilities. The only modification needed is setting `attention_bias = True` to match Qwen2's architecture specifics. This approach works because Qwen2 and Llama3 share similar transformer architectures. ## Implement the model config class Qwen2's config can inherit most of its behavior from Llama3's config; it only needs to override the attention bias. ```python title="model_config.py" from dataclasses import dataclass from typing import Literal from max.graph.weights import WeightData from max.nn.legacy.transformer import ReturnHiddenStates, ReturnLogits from max.pipelines.architectures.llama3.model_config import Llama3Config from transformers import AutoConfig @dataclass(kw_only=True) class Qwen2Config(Llama3Config): """Model configuration for Qwen2 graph construction/execution.""" def finalize( self, huggingface_config: AutoConfig, state_dict: dict[str, WeightData], return_logits: ReturnLogits, return_hidden_states: ReturnHiddenStates = ReturnHiddenStates.NONE, norm_method: Literal["rms_norm"] | Literal["layer_norm"] = "rms_norm", attention_bias: bool = False, ) -> None: super().finalize( huggingface_config=huggingface_config, state_dict=state_dict, return_logits=return_logits, return_hidden_states=return_hidden_states, norm_method=norm_method, attention_bias=True, # Qwen2 uses attention bias ) ``` ## Define your architecture registration The `arch.py` file tells MAX about your model's capabilities. 
When extending an existing architecture, you can reuse many components: ```python title="arch.py" from max.graph.weights import WeightsFormat from max.interfaces import PipelineTask from max.nn.legacy.kv_cache import KVCacheStrategy from max.pipelines.architectures.llama3 import weight_adapters from max.pipelines.lib import ( RopeType, SupportedArchitecture, SupportedEncoding, TextTokenizer, ) from .model import Qwen2Model from .model_config import Qwen2Config qwen2_arch = SupportedArchitecture( name="Qwen2ForCausalLM", task=PipelineTask.TEXT_GENERATION, example_repo_ids=["Qwen/Qwen2.5-7B-Instruct", "Qwen/QwQ-32B"], default_weights_format=WeightsFormat.safetensors, default_encoding=SupportedEncoding.bfloat16, supported_encodings={ SupportedEncoding.float32: [KVCacheStrategy.PAGED], SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED], }, pipeline_model=Qwen2Model, tokenizer=TextTokenizer, rope_type=RopeType.normal, weight_adapters={ WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict, WeightsFormat.gguf: weight_adapters.convert_gguf_state_dict, }, config=Qwen2Config, ) ``` This configuration demonstrates several key features of MAX's architecture system. The [`name`](/max/api/python/pipelines/registry/#max.pipelines.lib.registry.SupportedArchitecture) parameter must match the model class name in Hugging Face configs, while `task` specifies the pipeline task type using `PipelineTask` from `max.interfaces`. The `rope_type` parameter specifies the type of rotary position embeddings used by the model. One of the significant advantages of extending existing architectures is the ability to reuse components. In this case, we're reusing Llama3's weight adapters instead of creating custom ones, which handles the conversion between different weight formats like SafeTensors and GGUF. This reuse pattern is common when extending existing architectures—you can often leverage adapters, configuration handling, and other utilities from the parent model. 
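To make the name-matching rule above concrete: the `name` you register is compared against the `architectures` field of the model's Hugging Face `config.json`. The following is a plain-Python sketch of that check, not MAX's actual registry code:

```python
import json

# Excerpt of a Hugging Face config.json; its "architectures" field names the
# model class that SupportedArchitecture.name must match exactly.
config_json = '{"architectures": ["Qwen2ForCausalLM"], "model_type": "qwen2"}'
config = json.loads(config_json)

def name_matches(config: dict, registered_name: str) -> bool:
    """Check whether a registered architecture name appears in the config."""
    return registered_name in config.get("architectures", [])

print(name_matches(config, "Qwen2ForCausalLM"))  # True
print(name_matches(config, "LlamaForCausalLM"))  # False
```

If the names don't match, the registry can't associate your implementation with the model's weights, which is why the tutorial reuses the exact `Qwen2ForCausalLM` string from the Hugging Face config.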
## Load your architecture Create an `__init__.py` file to make your architecture discoverable by MAX: ```python title="__init__.py" from .arch import qwen2_arch ARCHITECTURES = [qwen2_arch] __all__ = ["qwen2_arch", "ARCHITECTURES"] ``` MAX automatically loads any architectures listed in the `ARCHITECTURES` variable when you specify your module with the [`--custom-architectures`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig.custom_architectures) flag. ## Test your custom architecture You can now test your custom architecture using the `--custom-architectures` flag. From your project directory, run the following command: ```bash max serve \ --model Qwen/Qwen2.5-7B-Instruct \ --custom-architectures qwen2 ``` The `--model` flag tells MAX which model to serve: you can pass a Hugging Face repo ID or a path to a local directory that contains the model. The `--custom-architectures` flag tells MAX to load custom architectures from the Python module you just built. :::caution Trust remote code Some models require executing custom code from their repository. If you encounter an error about "trust_remote_code", add the `--trust-remote-code` flag: ```bash max serve \ --model Qwen/Qwen2.5-7B-Instruct \ --custom-architectures qwen2 \ --trust-remote-code ``` Only use `--trust-remote-code` with models you trust, as it allows executing arbitrary code from the model repository. ::: The server is ready when you see this message: ```output Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` Now you can test your custom architecture. If you implemented an architecture to do [text generation](/max/api/python/pipelines/core#max.pipelines.core.PipelineTask.TEXT_GENERATION), you can send a request to that endpoint. For example: ```bash curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen2.5-7B-Instruct", "messages": [ {"role": "user", "content": "Hello! 
Can you help me with a simple task?"} ], "max_tokens": 100 }' ``` ```python from openai import OpenAI client = OpenAI( base_url="http://localhost:8000/v1", api_key="EMPTY", # Required by API but not used by MAX ) response = client.chat.completions.create( model="Qwen/Qwen2.5-7B-Instruct", messages=[ {"role": "user", "content": "Hello! Can you help me with a simple task?"} ], max_tokens=100, ) print(response.choices[0].message.content) ``` ## Next steps Congratulations! You've successfully created a custom architecture for MAX pipelines and served it with the `max serve` command. While this tutorial showed the simplified approach of extending an existing architecture, you may need to implement a model from scratch if your architecture differs significantly from MAX's built-in models. In that case, you would: 1. Implement the full [`PipelineModel`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel) interface including [`execute`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.execute), [`prepare_initial_token_inputs`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.prepare_initial_token_inputs), and [`prepare_next_token_inputs`](/max/api/python/pipelines/pipeline/#max.pipelines.lib.interfaces.pipeline_model.PipelineModel.prepare_next_token_inputs) methods. 2. Create custom configuration classes to handle model parameters. 3. Write custom weight adapters for converting between different formats. 4. Build the computation graph using MAX's graph API. For implementation details, explore the existing [supported model architectures](https://github.com/modular/modular/tree/main/max/python/max/pipelines/architectures) on GitHub. Each subdirectory represents a different model family with its own implementation. You can examine these architectures to understand different approaches and find the best base for your custom architecture. 
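As background for the `--custom-architectures` flag used throughout this tutorial: MAX imports your module and reads its module-level `ARCHITECTURES` list. That discovery step can be sketched in plain Python (illustrative only, not MAX's actual loader code):

```python
import importlib

def discover_architectures(module_name: str) -> list:
    """Import a module by name and return its ARCHITECTURES list, if any."""
    module = importlib.import_module(module_name)
    # Custom architecture packages export a module-level ARCHITECTURES list;
    # modules without one contribute nothing.
    return getattr(module, "ARCHITECTURES", [])

# A module with no ARCHITECTURES attribute yields an empty result.
print(discover_architectures("json"))  # []
```

This is why the `__init__.py` shown earlier only needs to export `ARCHITECTURES`: the list is the single discovery point, and everything else in your package is reached through the objects it contains.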
For a hands-on guide to implementing model architectures from the ground up, explore our [Build an LLM from scratch with MAX](https://llm.modular.com/) tutorial, which walks you through building a transformer architecture step-by-step. Here are some areas to explore further: export const docs = [ '../../graph/quantize.mdx', '../../develop/custom-ops.mdx', '../../develop/build-an-mlp-block.mdx', ]; --- ## Tensor fundamentals Tensors are the fundamental building blocks of neural network models in MAX. They represent multi-dimensional arrays of numbers used for model inputs, outputs, and parameters. If you're coming from NumPy, think of a MAX [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) like a NumPy [`ndarray`](https://numpy.org/devdocs/reference/arrays.ndarray.html#arrays-ndarray). A tensor is a multi-dimensional array of numbers. You can think of tensors as: - Rank 0 (scalar): a single number (like `72.5`) - Rank 1 (vector): a 1-D tensor or list of numbers (like `[1, 2, 3, 4]`) - Rank 2 (matrix): a 2-D tensor or table of numbers (like a spreadsheet with rows and columns) - Rank 3+ (higher-dimensional arrays): a 3-or-more dimensional tensor (stacks of matrices or more complex structures) :::tip Coming from PyTorch or NumPy? There are some key differences between tensors in MAX and tensors in PyTorch or NumPy. For example, tensors in MAX are created using the `Tensor.constant()` function instead of direct construction, many math operations use the functional API (`F.sqrt()` instead of `.sqrt()`), and reduction operations keep dimensions by default. 
::: ## Create a tensor You can create a tensor from a Python list using the [`Tensor.constant()`](/max/api/python/tensor#max.tensor.Tensor.constant) function: {/* @sync: _examples/tensors/create_first_tensor.py */} ```python from max.tensor import Tensor # Create a simple 1-D tensor (a vector) x = Tensor.constant([1, 2, 3, 4, 5]) print(x) ``` This imports the [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) class and creates a 1-D tensor (a vector) with the values `[1, 2, 3, 4, 5]`. The `constant()` function creates a tensor from any tensor-like object, such as a list, a NumPy array, or a PyTorch tensor. The expected output is: ```output TensorType(dtype=float32, shape=[Dim(5)], device=cpu:0): [1.0, 2.0, 3.0, 4.0, 5.0] ``` :::note Performance note `Tensor.constant()` performance optimizations are still being improved. In the meantime, if you need better performance when creating tensors on accelerators, use: ```python import numpy as np from max.driver import Accelerator from max.tensor import Tensor np_array = np.array([1.0, 2.0, 3.0]) tensor = Tensor.from_numpy(np_array).to(Accelerator()) ``` ::: :::note About dtype and device Notice that when creating the tensors, the output shows the dtype and device. The dtype is the data type of the tensor, and the device is where the tensor is stored. We'll cover dtypes and devices in more detail in the [Data Types (dtype)](/max/develop/dtypes) and [Devices](/max/develop/tensors#device) sections. 
::: You can also create tensors with any number of dimensions by nesting lists: {/* @sync: _examples/tensors/create_multi_dim.py */} ```python from max.tensor import Tensor # Create a 2-D tensor (a matrix) matrix = Tensor.constant([[1, 2, 3], [4, 5, 6]]) print(matrix) # Create a 3-D tensor (a cube of numbers) cube = Tensor.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) print(cube) ``` The expected output is: ```output TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] TensorType(dtype=float32, shape=[Dim(2), Dim(2), Dim(2)], device=cpu:0): [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] ``` MAX provides convenient functions for creating tensors with specific patterns, such as [`ones()`](/max/api/python/tensor#max.tensor.Tensor.ones) and [`zeros()`](/max/api/python/tensor#max.tensor.Tensor.zeros): {/* @sync: _examples/tensors/create_factory.py */} ```python from max.driver import CPU from max.dtype import DType from max.tensor import Tensor # Tensor filled with ones ones = Tensor.ones([3, 4], dtype=DType.float32, device=CPU()) print(ones) # Tensor filled with zeros zeros = Tensor.zeros([2, 3], dtype=DType.float32, device=CPU()) print(zeros) ``` The expected output is: ```output TensorType(dtype=float32, shape=[Dim(3), Dim(4)], device=cpu:0): [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] TensorType(dtype=float32, shape=[Dim(2), Dim(3)], device=cpu:0): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ``` For random data, you can use one of several functions from the [`random`](/max/api/python/random) module, such as [`random.normal()`](/max/api/python/random#max.random.normal): {/* @sync: _examples/tensors/create_random.py */} ```python from max import random # Random values from a normal distribution random_tensor = random.normal([3, 3]) print(random_tensor) ``` The expected output is: ```output TensorType(dtype=float32, shape=[Dim(3), Dim(3)], device=cpu:0): [1.6810914278030396, 2.3331382274627686, -0.25120288133621216, 
0.8896129131317139, 1.6362168788909912, -1.9282348155975342, -0.4372555911540985, -0.8747910261154175, 0.5068135857582092] ``` You can adjust the mean and standard deviation of the normal distribution to create tensors with different values: ```python random_tensor = random.normal([3, 3], mean=0.0, std=1.0) print(random_tensor) ``` The expected output will vary (since it's random), but will look similar to this: ```output TensorType(dtype=float32, shape=[Dim(3), Dim(3)], device=cpu:0): [-0.5, 1.2, -0.8, 0.3, -1.1, 0.7, 0.2, -0.4, 0.9] ``` :::note The actual values will be different each time you run the code, as they are randomly generated from a normal (Gaussian) distribution with mean 0.0 and standard deviation 1.0. ::: ## Tensor properties Every tensor has several key properties that tell you about its structure and contents. ### Shape The *shape* tells you the size of each dimension of the tensor. It's a list where each number represents the size along that dimension: {/* @sync: _examples/tensors/properties_shape.py */} ```python from max.tensor import Tensor # 1-D tensor x = Tensor.constant([1, 2, 3, 4]) print(x.shape) # 2-D tensor matrix = Tensor.constant([[1, 2, 3], [4, 5, 6]]) print(matrix.shape) # 3-D tensor cube = Tensor.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) print(cube.shape) ``` The expected output is: ```output [Dim(4)] [Dim(2), Dim(3)] [Dim(2), Dim(2), Dim(2)] ``` In this example, the [`shape`](/max/api/python/tensor#max.tensor.Tensor.shape) property returns a list of `Dim` objects representing the size of each dimension. 
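The shape determines the other structural properties covered below: the rank is the number of entries in the shape, and the total element count is their product. In plain Python (not the MAX API):

```python
import math

shape = [2, 3]  # the shape of a 2x3 matrix

rank = len(shape)                # number of dimensions
num_elements = math.prod(shape)  # product of the dimension sizes

print(rank)          # 2
print(num_elements)  # 6
```

Keeping this shape-first mental model makes the `rank` and `num_elements()` properties in the following sections easy to predict.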
### Rank The *rank* (also called number of dimensions) tells you how many axes the tensor has: {/* @sync: _examples/tensors/properties_rank.py */} ```python from max.tensor import Tensor scalar = Tensor.constant([42]) # Rank 1 (it's a 1-element vector) vector = Tensor.constant([1, 2, 3]) # Rank 1 matrix = Tensor.constant([[1, 2], [3, 4]]) # Rank 2 cube = Tensor.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # Rank 3 print(vector.rank) # 1 print(matrix.rank) # 2 print(cube.rank) # 3 ``` ### Data type (dtype) The *dtype*, or data type, tells you the type of the data stored in the tensor. For example, a tensor can contain integers (`int32`), floating-point numbers (`float32`), or booleans (`bool`). {/* @sync: _examples/tensors/properties_dtype.py */} ```python from max.dtype import DType from max.tensor import Tensor # Float tensor (default for most operations) floats = Tensor.ones([2, 2], dtype=DType.float32) # Integer tensor integers = Tensor.ones([2, 2], dtype=DType.int32) ``` MAX supports a wide range of dtypes, including all standard NumPy and PyTorch dtypes. You will learn more about dtypes in the [Data Types (dtype)](/max/develop/dtypes) section. ### Device The *device* tells you where the tensor is stored and where operations on it run: {/* @sync: _examples/tensors/properties_device.py */} ```python from max.driver import CPU from max.tensor import Tensor # Tensor on CPU cpu_tensor = Tensor.ones([2, 2], device=CPU()) ``` When you perform operations on a tensor, you specify the device on which the tensor is stored. MAX currently supports CPU and GPU devices, specified with either the `CPU()` or `Accelerator()` class. 
### Total elements You can check how many numbers are stored in the tensor: {/* @sync: _examples/tensors/properties_num_elements.py */} ```python from max.tensor import Tensor t = Tensor.constant([[1, 2, 3], [4, 5, 6]]) print(t.num_elements()) # 6 (2 rows × 3 columns) ``` In this example, the expected output is: ```output 6 ``` The [`num_elements()`](/max/api/python/tensor#max.tensor.Tensor.num_elements) function returns the total number of elements in the tensor, which is the product of the dimensions. ## Next steps Now that you understand tensor basics, continue to [Data Types (dtype)](/max/develop/dtypes) to learn how to control tensor precision and memory usage. --- ## Environment variables This page documents all environment variables you can use to configure MAX behavior. These variables control server settings, logging, telemetry, performance, and integrations. ## How to set environment variables You can set environment variables in several ways: ```bash # Export in your shell export MAX_SERVE_HOST="0.0.0.0" # Pass to Docker container docker run --env "MAX_SERVE_HOST=0.0.0.0" modular/max-nvidia-full:latest ... # Use a .env file in your working directory echo "MAX_SERVE_HOST=0.0.0.0" >> .env ``` ### Configuration precedence When the same setting is configured in multiple places, the following precedence applies (highest to lowest): 1. **CLI flags or direct Python initialization**: For example, `--port 8080` or `Settings(MAX_SERVE_PORT=8080)`. CLI flags are passed directly to the `Settings` constructor, so they have the same precedence as direct Python initialization. 2. **Environment variables**: `export MAX_SERVE_HOST="0.0.0.0"` 3. **`.env` file values**: Values defined in a `.env` file in your working directory :::note Environment variables must use the exact names documented below. ::: ## Serving These variables configure the MAX model serving behavior. 
For more information on serving a model with MAX, explore the [text to text](/max/inference/text-to-text) and [image to text](/max/inference/image-to-text) guides. | Variable | Description | Values | Default | |----------|-------------|--------|---------| | `MAX_SERVE_HOST` | Hostname for the MAX server | String | `0.0.0.0` | | `MAX_SERVE_PORT` | Port for serving MAX | Integer | `8000` | | `MAX_SERVE_METRICS_ENDPOINT_PORT` | Port for the Prometheus metrics endpoint | Integer | `8001` | | `MAX_SERVE_ALLOWED_IMAGE_ROOTS` | Allowed root directories for `file://` URI access | Comma-separated paths | Empty | | `MAX_SERVE_MAX_LOCAL_IMAGE_BYTES` | Maximum size in bytes for local image files | Integer | `20971520` (20 MiB) | ## Logging These variables control logging behavior and verbosity. You can read more about logs when using the [MAX container](/max/container#logs). | Variable | Description | Values | Default | |----------|-------------|--------|---------| | `MAX_SERVE_LOGS_CONSOLE_LEVEL` | Console log verbosity level | `CRITICAL`, `ERROR`, `WARNING`, `INFO`, `DEBUG` | `INFO` | | `MODULAR_STRUCTURED_LOGGING` | Enable JSON-formatted structured logging for deployed services | `0`, `1` | `1` | | `MAX_SERVE_LOGS_FILE_PATH` | Path to write log files | File path | None | | `MAX_SERVE_LOG_PREFIX` | Prefix to prepend to all log messages | String | None | ## Telemetry and metrics These variables control telemetry collection and metrics reporting. For more information, read about [MAX container telemetry](/max/container#telemetry). 
| Variable | Description | Values | Default | |----------|-------------|--------|---------| | `MAX_SERVE_DISABLE_TELEMETRY` | Disable remote telemetry collection | `0`, `1` | `0` | | `MODULAR_USER_ID` | User identifier for telemetry (e.g., your company name) | String | None | | `MAX_SERVE_DEPLOYMENT_ID` | Deployment identifier for telemetry (e.g., your application name) | String | None | | `MAX_SERVE_METRIC_LEVEL` | Level of detail in metrics emitted | `NONE`, `BASIC`, `DETAILED` | `BASIC` | ## Debugging and profiling These variables enable debugging and profiling features. For more information about profiling, see [GPU profiling with Nsight Systems](/max/gpu-system-profiling). | Variable | Description | Values | Default | |----------|-------------|--------|---------| | `MODULAR_MAX_DEBUG` | Enable stack traces for MAX graph compilation errors | `True`, `False` | `False` | | `MODULAR_ENABLE_PROFILING` | Enable runtime profiling and tracing | `off`, `on`, `detailed` | `off` | ## Performance and caching The following variables configure caching and memory behavior. | Variable | Description | Values | Default | |----------|-------------|--------|---------| | `MODULAR_MAX_CACHE_DIR` | Directory to save MAX model cache for reuse | Path | See note below | | `MODULAR_CACHE_DIR` | Configure cache directory for all Modular filesystems | Path | See note below | | `MODULAR_MAX_SHM_WATERMARK` | Percentage of `/dev/shm` to allocate for shared memory. Set to `0.0` to disable shared memory. 
| Float (0.0–1.0) | `0.9` | ## Hugging Face Configure your Hugging Face integration with the following environment variable: | Variable | Description | Values | Default | |----------|-------------|--------|---------| | `HF_TOKEN` | Hugging Face authentication token for accessing gated models | String | None | ## Related resources - [MAX container](/max/container) - Deploy MAX with Docker - [`max serve` CLI](/max/cli/serve) - Command-line options for serving --- ## FAQ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; If this page doesn't answer your question, please ask us on our [Modular forum](https://forum.modular.com) or [Discord channel](https://www.discord.gg/modular). ## Distribution ### What operating systems do you support? {#system-requirements} You can install `modular` on Mac and Linux operating systems. For more details, see the [system requirements](/max/packages#system-requirements). ### What are the GPU requirements? {#gpu-requirements} The Modular Platform supports both CPUs and GPUs, so you don't always need a GPU to serve a model—although some larger models do require a GPU. For details about GPU support, see our [list of compatible GPUs](/max/packages/#gpu-compatibility). ### Will MAX be open-sourced? We want to contribute a lot to open source, but we also want to do it right. Our team has decades of experience building open-source projects, and we believe it's very important to create an inclusive and vibrant community, which takes a lot of work. We've already begun open-sourcing parts of the MAX framework, including our [Python serving library](https://github.com/modular/modular/tree/main/max/python/max/serve), [MAX model architectures](https://github.com/modular/modular/tree/main/max/python/max/pipelines/architectures), and [GPU kernels](https://github.com/modular/modular/tree/main/max/kernels/src/nn). To get the latest updates, [join our community](https://www.modular.com/community). 
## Functionality ### What clouds and services can I deploy MAX onto? You can deploy our MAX container across a variety of VM and Kubernetes-based cloud services, including AWS, GCP, and Azure. To get started with any of them, check out our [tutorials using MAX Serve](/max/tutorials?filterByTags&tag=serve). ### Can I run MAX locally? Yes. MAX supports macOS and ARM hardware, so you can run it on your local laptop for exploration and testing purposes. ### Will MAX support distributed inference of large models? Yes, it will support executing large models that do not fit into the memory of a single device. This isn't available yet, so stay tuned! ## Installation ### Can I install both stable and nightly builds? Yes, it's safe and easy to use the stable and nightly builds for different projects, each with their own virtual environment and package dependencies. For more information, read the [packages guide](/max/packages). ### Does the MAX SDK collect telemetry? Yes, the MAX SDK collects basic system information, session durations, compiler events, and crash reports that enable us to identify, analyze, and prioritize issues. The MAX container for model serving also collects performance metrics such as time to first token and input processing time. This telemetry is crucial to help us quickly identify problems and improve our products for you. Without it, we would rely solely on user-submitted bug reports, which would severely limit our insight into performance issues. To disable serving telemetry, see the [MAX container documentation](/max/container#metrics). 
--- ## Quickstart import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import ModelDropdownTabs from '@site/src/components/ModelDropdownTabs'; import ContactSection from '@site/src/components/ContactSection'; import Requirements from '@site/src/components/Requirements'; import { requirementsNoMacWithGPU } from './requirements'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; import { ModelSelector, DynamicCode, ConditionalContent } from '@site/src/components/ModelSelector'; A major component of the Modular Platform is MAX, our developer framework that abstracts away the complexity of building and serving high-performance GenAI models on a wide range of hardware, including NVIDIA and AMD GPUs. In this quickstart, you'll create an endpoint for an open-source LLM using MAX, run an inference from a Python client, and then benchmark the endpoint. :::caution GPU required We recommend using an **NVIDIA B200 / H200 / H100** or **AMD MI355X / MI325X / MI300X** GPU. MAX can serve models on a wide range of CPUs and GPUs, but the LLMs most customers want require a lot of memory, which is why we suggest production-grade GPUs. This guide does offer some smaller models, but to use the latest LLMs, you'll still need a [compatible GPU](/max/packages#gpu-compatibility). ::: System requirements: If you'd rather create an endpoint with Docker, see our [tutorial to benchmark MAX](/max/deploy/benchmark). ## Set up your project First, install the `max` CLI that you'll use to start the model endpoint. :::tip For the most reliable experience, we recommend installing with `pixi`. 
::: export const textModelOptions = [ { label: 'gemma-3-4b-it', value: 'google/gemma-3-4b-it', description: 'Requires >8 GiB of GPU RAM — works on most compatible GPUs' }, { label: 'gemma-3-12b-it', value: 'google/gemma-3-12b-it', description: 'Requires >24 GiB of GPU RAM — we suggest an A100, MI300, or better' }, { label: 'gemma-3-27b-it', value: 'google/gemma-3-27b-it', description: 'Requires >60 GiB of GPU RAM — we suggest an H100, MI300, or better', default: true }, { label: 'Llama-3.1-8B-Instruct', value: 'meta-llama/Llama-3.1-8B-Instruct', description: 'Requires >15 GiB of RAM (GPU or CPU) — use this if you\'re on a Mac' }, ]; export const imageModelOptions = [ { label: 'gemma-3-12b-it', value: 'google/gemma-3-12b-it', description: 'Requires >24 GiB of GPU RAM — we suggest an A100, MI300, or better' }, { label: 'gemma-3-27b-it', value: 'google/gemma-3-27b-it', description: 'Requires >60 GiB of GPU RAM — we suggest an H100, MI300, or better', default: true }, { label: 'InternVL3-8B-Instruct', value: 'OpenGVLab/InternVL3-8B-Instruct', description: 'Requires >22 GiB of GPU RAM — we suggest an A100, MI300, or better', }, { label: 'InternVL3-14B-Instruct', value: 'OpenGVLab/InternVL3-14B-Instruct', description: 'Requires >36 GiB of GPU RAM — we suggest an H100, MI300, or better', }, ]; export const allModelOptions = { text: textModelOptions, image: imageModelOptions } ## Start a model endpoint Now you'll serve an LLM from a local endpoint using [`max serve`](/max/cli/serve). First, pick whether you want to perform text-to-text inference or image-to-text (multimodal) inference, and then select a model size. We've included a small number of model options to keep it simple, but you can explore more models in our [model repository](https://builds.modular.com/?category=models). Select a model to change the code below. Google's [Gemma 3 models](https://builds.modular.com/models/gemma-3-it/27B) are multimodal. 
MAX supports text input for all available sizes and image input for the 12B and 27B models. All sizes require a [compatible GPU](/max/packages#gpu-compatibility). Meta's [Llama 3.1 models](https://builds.modular.com/models/Llama-3.1-Instruct/8B) are text-only LLMs. You can pick any model in the family, but we suggest the smaller 8B model because it works on a wide range of CPUs, including on Macs. Start the endpoint with the `max` CLI: 1. Add your [HF Access Token](https://huggingface.co/settings/tokens) as an environment variable: ```sh export HF_TOKEN="hf_..." ``` 2. Agree to the model's license: the [Gemma 3 license](https://huggingface.co/google/gemma-3-27b-it) for Gemma models, or the [Llama 3.1 license](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) for Llama models. 3. Start the endpoint, replacing `<model>` with the model you selected above: ```sh max serve --model <model> ``` For image-to-text inference, select a model to change the code below. Google's [Gemma 3 models](https://builds.modular.com/models/gemma-3-it/27B) are multimodal. MAX supports text input for all available sizes and image input for the 12B and 27B models. All sizes require a [compatible GPU](/max/packages#gpu-compatibility). OpenGVLab's multimodal [InternVL3 models](https://builds.modular.com/models/InternVL3/14B) come in many sizes, but they all require a [compatible GPU](/max/packages#gpu-compatibility). They aren't gated on Hugging Face, so you don't need to provide a Hugging Face access token to start the endpoint. For Gemma models, agree to the [Gemma 3 license](https://huggingface.co/google/gemma-3-27b-it) and add your [HF Access Token](https://huggingface.co/settings/tokens) as an environment variable: ```bash export HF_TOKEN="hf_..." ``` Start the endpoint with the `max` CLI, replacing `<model>` with the model you selected: ```sh max serve --model <model> --trust-remote-code ``` It will take some time to download the model, compile it, and start the server. 
While that's working, you can get started on the next step. ## Run inference with the endpoint Open a new terminal and send an inference request using the `openai` Python API: 1. Navigate to the project you created above and then install the `openai` package: ```bash pixi add openai ``` ```bash uv add openai ``` ```bash pip install openai ``` ```bash conda install -c conda-forge openai ``` 2. Activate the virtual environment: ```bash pixi shell ``` ```bash source .venv/bin/activate ``` ```bash source .venv/quickstart/bin/activate ``` ```bash conda init ``` Or if you're on a Mac, use: ```bash conda init zsh ``` 3. Create a new file that sends an inference request: { `from openai import OpenAI client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY") completion = client.chat.completions.create( model="{text}", messages=[ { "role": "user", "content": "Who won the world series in 2020?" }, ], ) print(completion.choices[0].message.content)` } Notice that the `OpenAI` API requires the `api_key` argument, but you don't need that with MAX. 4. Wait until the model server is ready—when it is, you'll see this message in your first terminal: ```output 🚀 Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` Then run the Python script from your second terminal, and you should see results like this (your results may vary, especially for different model sizes): ```sh python generate-text.py ``` ```output The **Los Angeles Dodgers** won the World Series in 2020! They defeated the Tampa Bay Rays 4 games to 2. It was their first World Series title since 1988. ``` 1. Navigate to the project you created above and then install the `openai` package: ```bash pixi add openai ``` ```bash uv add openai ``` ```bash pip install openai ``` ```bash conda install -c conda-forge openai ``` 2. 
Activate the virtual environment: ```bash pixi shell ``` ```bash source .venv/bin/activate ``` ```bash source .venv/quickstart/bin/activate ``` ```bash conda init ``` Or if you're on a Mac, use: ```bash conda init zsh ``` 3. Create a new file that sends an inference request: { `from openai import OpenAI client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY") completion = client.chat.completions.create( model="{image}", messages=[ { "role": "user", "content": [ { "type": "text", "text": "Write a caption for this image" }, { "type": "image_url", "image_url": { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" } } ] } ], max_tokens=300 ) print(completion.choices[0].message.content)` } Notice that the `OpenAI` API requires the `api_key` argument, but you don't need that with MAX. 4. Wait until the model server is ready—when it is, you'll see this message in your first terminal: ```output 🚀 Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` Then run the Python script from your second terminal, and you should see results like this (your results will always be different): ```sh python generate-image-caption.py ``` ```output In a charming English countryside setting, Mr. Bun, dressed elegantly in a tweed outfit, stands proudly on a dirt path, surrounded by lush greenery and blooming wildflowers. 
``` ## Benchmark the endpoint While still in your second terminal, run the following command to benchmark your endpoint: { `max benchmark \\ --model {text} \\ --backend modular \\ --endpoint /v1/chat/completions \\ --dataset-name sonnet \\ --num-prompts 500 \\ --sonnet-input-len 550 \\ --output-lengths 256 \\ --sonnet-prefix-len 200` } { `max benchmark \\ --model {image} \\ --backend modular \\ --endpoint /v1/chat/completions \\ --dataset-name random \\ --num-prompts 500 \\ --random-input-len 40 \\ --random-output-len 150 \\ --random-image-size 512,512 \\ --random-coefficient-of-variation 0.1,0.6` } When it's done, you'll see the results printed to the terminal. If you want to save the results, add the `--save-result` flag and it'll save a JSON file in the local directory. You can specify the file name with `--result-filename` and change the directory with `--result-dir`. For example: ```sh max benchmark \ ... --save-result \ --result-filename "quickstart-benchmark.json" \ --result-dir "results" ``` The benchmark options above are just a starting point. When you want to save your own benchmark configurations, you can define them in a YAML file and pass it to the `--config-file` option. For example configurations, see our [benchmark config files on GitHub](https://github.com/modular/modular/tree/main/max/python/max/benchmark/configs). For more details about the tool, including other datasets and configuration options, see the [`max benchmark` documentation](/max/cli/benchmark). :::caution GPU ran out of memory If the server log says, `GPU ran out of memory during model execution`, try reducing the benchmark input length with the option corresponding to your dataset (`--sonnet-input-len` or `--random-input-len`). Also consider restarting `max serve` and adding `--device-memory-utilization` with a value as low as `0.5` (the default is `0.9`). 
::: ## Next steps Now that you have an endpoint, connect to it with our [Agentic Cookbook](https://modul.ar/cookbook)—an open-source project for building React-based interfaces for any model endpoint. Just clone the repo, run it with npm, and pick a recipe, such as a chat interface or a drag-and-drop image captioning tool, or build your own. To get started, see the [project README](https://modul.ar/cookbook). {/* ### Keep reading Here are some features you can try with your endpoint: import SmallCards from '@site/src/components/SmallCards'; export const docs = [ { title: 'Function calling', link: '/max/serve/function-calling', description: `Learn how to use LLMs that support function calling and tool use, such as to retrieve data and execute external tasks.`, }, { title: 'Structured output', link: '/max/serve/structured-output', description: `Learn how to use structured output (constrained decoding) to enforce the output format from a model using an input schema.`, }, ]; */} ## Stay in touch --- ## GPU profiling with Nsight Systems Nsight Systems (`nsys`) is a system-wide performance analysis tool for visualizing application behavior on NVIDIA GPUs. It captures the full execution timeline including CPU activity, CUDA API calls, GPU kernels, and memory operations. This page describes how to use `nsys` to profile systems built with the MAX framework and view the results with the Nsight Systems application. You should use `nsys` when you need to understand _where_ time is being spent across your application, such as which kernels are slow. You can then deep-dive into specific kernel performance with [Nsight Compute](https://github.com/modular/modular/blob/main/max/docs/kernel-profiling.md), or use our purpose-built [`kbench` tool](https://github.com/modular/modular/tree/main/max/kernels/benchmarks/autotune#readme) that helps you benchmark, autotune, and analyze Mojo kernel performance.
## Requirements - [NVIDIA GPU compatible with MAX](/max/packages#gpu-compatibility) - [NVIDIA driver version compatible with MAX](/max/packages#gpu-software-requirements) ### Verify your `nsys` installation Nsight Systems (`nsys`) is included with the CUDA Toolkit, but you can also install it separately. Log into your machine with an NVIDIA GPU and check if `nsys` is already available: ```bash which nsys ``` If not found, but you have CUDA installed, try adding the CUDA bin directory to your PATH. For example: ```bash export PATH=/usr/local/cuda/bin:$PATH ``` Otherwise, you can download it from [NVIDIA's Nsight Systems page](https://developer.nvidia.com/nsight-systems/get-started) or install it with your package manager. For example, on Ubuntu/Debian: ```bash sudo apt-get install nsight-systems ``` ## Profile a MAX model When profiling code built with MAX, you have to explicitly enable profiling by setting the `MODULAR_ENABLE_PROFILING` environment variable to `detailed` or by calling [`InferenceSession.gpu_profiling()`](/max/api/python/engine#max.engine.InferenceSession.gpu_profiling) before you load your model. In most cases, you'll probably use `MODULAR_ENABLE_PROFILING` so you don't have to modify your code. For example, you can profile `my_program.py` with this command: ```bash MODULAR_ENABLE_PROFILING=detailed \ nsys profile --trace=cuda,osrt,nvtx \ --cuda-memory-usage=true \ --output=profile \ --force-overwrite=true \ python my_program.py ``` If you don't have a MAX model to profile, you can try it with one of our code examples (requires [`pixi`](https://pixi.prefix.dev/latest/installation/)): 1. Clone the Modular repo and navigate to the examples: ```bash git clone https://github.com/modular/modular.git cd modular/max/examples/custom_ops ``` 2. 
Profile the top-k custom op example: ```bash MODULAR_ENABLE_PROFILING=detailed \ nsys profile --trace=cuda,osrt,nvtx \ --cuda-memory-usage=true \ --output=profile \ --force-overwrite=true \ pixi run top_k ``` This creates a `profile.nsys-rep` file in the current directory (you can change the filename and path with the `--output` option). To view the results, skip to the [view the profile](#view-the-profile) section. For details about the `nsys` command options, see the [Nsight Systems User Guide](https://docs.nvidia.com/nsight-systems/UserGuide/index.html). ## Profile a MAX endpoint To profile a MAX model endpoint, use `nsys launch` instead of `nsys profile`. This allows you to start the server and start profiling only when you're ready to run inference. For example, here's how to profile a benchmark workload: 1. Make sure you're in a project environment with [the `modular` package installed](/max/packages#install). 2. The next step deploys Google Gemma 3, so in order for MAX to download the model weights, you must first agree to the [Gemma 3 license](https://huggingface.co/google/gemma-3-12b-it) and set your [HF Access Token](https://huggingface.co/settings/tokens): ```bash export HF_TOKEN="hf_..." ``` 3. Start the MAX server ([`max serve`](/max/cli/serve)) with `nsys launch` (this doesn't start profiling yet): ```bash MODULAR_ENABLE_PROFILING=detailed \ numactl --cpunodebind=2 --membind=2 \ nsys launch \ --trace=cuda,nvtx,osrt \ --cuda-memory-usage=true \ --trace-fork-before-exec=true \ max serve --model google/gemma-3-12b-it ``` 4. Wait until the server is ready—you'll see a message like this in your terminal: ```output 🚀 Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` 5. 
Open a second terminal in the same environment and start profiling: ```bash nsys start \ --force-overwrite=true \ --output=server_profile \ --session=$(nsys sessions list -p false | awk '{print $1}') ``` You should see this message in the first terminal: ```output Collecting data... ``` 6. In the second terminal, use [`max benchmark`](/max/cli/benchmark) to run a benchmark workload: ```bash max benchmark \ --model google/gemma-3-12b-it \ --backend modular \ --endpoint /v1/chat/completions \ --dataset-name sonnet \ --num-prompts 500 \ --sonnet-input-len 550 \ --output-lengths 256 \ --sonnet-prefix-len 200 ``` 7. Once the benchmark completes, stop profiling from the second terminal: ```bash nsys stop --session=$(nsys sessions list -p false | awk '{print $1}') ``` This creates a `server_profile.nsys-rep` file in the current directory (the filename and path are specified by the `nsys start --output` option). ## View the profile The best way to inspect the profile results is to open the `.nsys-rep` file in the [Nsight Systems GUI application](https://developer.nvidia.com/nsight-systems). If you generated the profile on a local desktop environment, you can open the application with this command: ```bash nsys-ui profile.nsys-rep ``` If you generated the profile on a remote system, copy the `.nsys-rep` file to your local machine and open Nsight Systems locally: 1. [Install NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems/get-started) on your local system (available for Windows, macOS, and Linux). 2. Copy the `profile.nsys-rep` file to your local machine (such as via `scp`). 3. Double-click the `.nsys-rep` file to open it in Nsight Systems. You should see a timeline showing the profile results, similar to the example in figure 1.
Figure 1. Nsight Systems timeline showing profile results for a benchmark workload.
Alternatively, you can print a summary of the profile results in the terminal: ```bash nsys stats profile.nsys-rep ``` For detailed instructions for analyzing the profile results, see the [Nsight Systems Post-Collection Analysis Guide](https://docs.nvidia.com/nsight-systems/AnalysisGuide/index.html). ## Enable MAX profiling markers To enable MAX profiling markers, set the `MODULAR_ENABLE_PROFILING` environment variable to `detailed`, call the [`InferenceSession.gpu_profiling()`](/max/api/python/engine#max.engine.InferenceSession.gpu_profiling) method, or add the `--gpu-profiling` option to [`max serve`](/max/cli/serve). ### Enable profiling with an environment variable Here's an example of how to enable MAX profiling markers with the `MODULAR_ENABLE_PROFILING` environment variable: ```bash MODULAR_ENABLE_PROFILING=detailed \ nsys profile --trace=cuda,nvtx,osrt \ python my_program.py ``` `MODULAR_ENABLE_PROFILING` accepts the following values: - `off` (default): Disables MAX profiling markers. - `on`: Enables MAX profiling markers with NVTX markers for kernel correlation. - `detailed`: Enables MAX profiling markers with additional Python-level NVTX markers. ### Enable profiling with Python You can enable MAX profiling markers from Python by calling the [`InferenceSession.gpu_profiling()`](/max/api/python/engine#max.engine.InferenceSession.gpu_profiling) method before you load your model: ```python from max.engine import InferenceSession, GPUProfilingMode session = InferenceSession(devices=[GPU()]) session.gpu_profiling(GPUProfilingMode.DETAILED) ``` Beware that `gpu_profiling()` overrides the `MODULAR_ENABLE_PROFILING` environment variable if also used, and you must call it before `load()`. ### Enable profiling with a `max` CLI option You can also enable MAX profiling markers by adding the `--gpu-profiling` option to [`max serve`](/max/cli/serve). 
For example: ```bash max serve --gpu-profiling detailed \ --model google/gemma-3-12b-it ``` Beware that `--gpu-profiling` overrides the `MODULAR_ENABLE_PROFILING` environment variable if also used. ## NUMA binding with `numactl` Multi-socket systems (systems with 2 or more physical CPUs) use a non-uniform memory access (NUMA) architecture, in which each CPU can access memory attached to the other CPUs, but accessing its own locally attached memory is always faster. So if you're profiling a program on a multi-socket system, you should bind the process to a specific CPU and memory node using [`numactl`](https://linux.die.net/man/8/numactl) to get more consistent profiling results. :::note It's safe to assume that any datacenter-grade server has a multi-socket system, with at least 2 physical CPUs, but you can confirm it with `numactl -H` or `lscpu | grep "Socket(s)"`. ::: First, install `numactl` if it's not already available: ```bash # Ubuntu/Debian sudo apt-get install numactl # RHEL/CentOS/Fedora sudo dnf install numactl ``` For example, you can bind your process to NUMA node 2 with this command: ```bash numactl --cpunodebind=2 --membind=2 \ nsys profile --trace=cuda,nvtx \ python my_program.py ``` The `--cpunodebind=2` option restricts the process to run only on CPUs in NUMA node 2, while `--membind=2` allocates memory only from that same node. You should adjust `--cpunodebind` and `--membind` based on your system topology—use `numactl -H` to view the system's NUMA topology. On single-CPU-socket systems, you can omit `numactl` entirely. ## See also - [Nsight Systems Post-Collection Analysis Guide](https://docs.nvidia.com/nsight-systems/AnalysisGuide/index.html): Detailed instructions for analyzing the `nsys` profile results. - [Kernel profiling with Nsight Compute](https://github.com/modular/modular/blob/main/max/docs/kernel-profiling.md): A companion to Nsight Systems for kernel-level profiling on NVIDIA GPUs.
- [`kbench` tool](https://github.com/modular/modular/tree/main/max/kernels/benchmarks/autotune#readme): A Python-based toolkit that builds and executes Mojo kernel benchmarks across a grid of parameter combinations to help you autotune and analyze Mojo kernel performance on any hardware. --- ## Quantization (Graph) MAX allows you to load and run pre-quantized models through both its Python API and CLI. This guide explains quantization concepts and how to work with quantized models in your applications. For a complete list of supported quantization encodings, see the [quantization](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding) documentation. ## Understanding quantization Quantization reduces the numeric precision of model weights to decrease memory usage and increase inference speed. For example, models originally trained with `float32` weights can be represented using lower precision types like `int8` or `int4`, reducing each scalar value from 32 bits to 8 or 4 bits. When used properly, quantization does not significantly affect the model accuracy. There are several different quantization encodings that provide different levels of precision and encoding formats, each with its own trade-offs that may work well for some models or graph operations ("ops") but not others. Some models also work great with a mixture of quantization types, so that only certain ops perform low-precision calculations while others retain high precision. 
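To make the memory savings concrete, you can do the arithmetic directly. The sketch below is illustrative (the 8-billion-parameter size is an assumption, and real checkpoints add overhead for scale factors and metadata), but it shows why dropping from 32 bits to 4 bits per weight matters:

```python
# Approximate weight storage for an 8B-parameter model at various precisions.
# Illustrative only: real quantized formats also store per-block scales.
PARAMS = 8_000_000_000

def weight_gib(bits_per_param: int) -> float:
    """Total weight storage in GiB at the given precision."""
    return PARAMS * bits_per_param / 8 / (1024 ** 3)

for name, bits in [("float32", 32), ("bfloat16", 16), ("int8", 8), ("int4", 4)]:
    print(f"{name:>8}: {weight_gib(bits):6.1f} GiB")
```

This is roughly 29.8 GiB at `float32` versus about 3.7 GiB at 4 bits, which is the difference between needing a large datacenter GPU and fitting on a modest one.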
## How to load pre-quantized models with MAX You can load pre-quantized models using two primary approaches: - By specifying a path to a quantized weight file - By specifying the quantization encoding format for compatible models When you have a quantized weight file, you can load it directly using the `--weight-path` argument: ```bash max serve --model meta-llama/Llama-3.1-8B-Instruct \ --weight-path=bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf ``` MAX automatically detects the quantization format from the weight file. This approach works for models with standard quantization formats like GGUF and AWQ. For models that have been quantized using specific techniques but don't use a separate weight file format, you can specify the quantization encoding directly with the `--quantization-encoding` flag: ```bash max generate --model hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \ --quantization-encoding=gptq \ --prompt "What is the meaning of life?" ``` The `--quantization-encoding` flag accepts the following values: - `float32`: Full precision 32-bit floating point. - `bfloat16`: Brain floating point 16-bit format. - `q4_0`: 4-bit quantization format. - `q4_k`: 4-bit quantization with K-means clustering. - `q6_k`: 6-bit quantization with K-means clustering. - `float8_e4m3fn`: 8-bit quantization with e4m3fn encoding. - `gptq`: Specialized quantization optimized for transformer-based models. For more information on the `max` CLI, see the [MAX CLI](/max/cli) documentation or the [MAX Serve API reference](/max/api/serve). ## Quantized layer implementation If you're building custom models with the MAX Graph API, you can implement custom quantized layers.
This is useful when: - You're building a model from scratch using the MAX Graph API - You need precise control over how quantization is implemented - You're implementing specialized model architectures that require custom quantized operations To implement a quantized layer in Python, you'll need to make a few key changes compared to a standard linear layer. Let's look at the differences. A standard linear layer in MAX might look like this: ```python from max import nn from max.dtype import DType from max.graph import DeviceRef, Weight class Linear(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.weight = Weight( name="weight", dtype=DType.float32, shape=[in_dim, out_dim], device=DeviceRef.CPU(), ) self.bias = Weight(name="bias", dtype=DType.float32, shape=[out_dim]) def __call__(self, x): return x @ self.weight.T.to(x.device) + self.bias.to(x.device) ``` To enable support for GGUF quantization like [`Q4_0`](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.Q4_0), [`Q4_K`](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.Q4_K), or other encodings, you need to: 1. Load weights from the quantized model checkpoint as `uint8` with the appropriate shape. 2. Replace the standard matrix multiplication `(@)` with the [`qmatmul`](/max/api/python/graph/ops#max.graph.ops.qmatmul) operation. 3. Specify the quantization encoding to use. Here's how you might implement a quantized linear layer: ```python from max import nn from max.dtype import DType from max.graph import DeviceRef, Weight, ops from max.graph.quantization import QuantizationEncoding class QuantizedLinear(nn.Module): def __init__(self, in_dim, out_dim, quantization_encoding): super().__init__() self.weight = Weight( name="weight", # The DType must be uint8. 
dtype=DType.uint8, # This shape must be updated to match the quantized shape shape=[in_dim, out_dim], device=DeviceRef.CPU(), quantization_encoding=quantization_encoding, ) self.bias = Weight(name="bias", dtype=DType.float32, shape=[out_dim]) def __call__(self, x): return ops.qmatmul( self.weight.quantization_encoding, None, x, self.weight.to(x.device) ) + self.bias.to(x.device) quantized_linear = QuantizedLinear(in_dim, out_dim, QuantizationEncoding.Q4_0) ``` The [MAX graph quantization](/max/api/python/graph/quantization) class defines the available quantization formats supported by MAX. These encodings include: - [Q4_0](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.Q4_0): 4-bit quantization format - [Q4_K](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.Q4_K): 4-bit quantization with K-means clustering - [Q5_K](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.Q5_K): 5-bit quantization with K-means clustering - [Q6_K](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.Q6_K): 6-bit quantization with K-means clustering - [GPTQ](/max/api/python/graph/quantization#max.graph.quantization.QuantizationEncoding.GPTQ): Specialized quantization optimized for transformer-based models With this implementation, you can add quantized weights into your MAX models. The [`qmatmul`](/max/api/python/graph/ops#max.graph.ops.qmatmul) operation handles the dequantization process during inference, giving you the performance benefits of quantization without having to manage the low-level details.
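As a concrete example of why the `uint8` weight's shape differs from the logical shape, consider GGUF's `Q4_0` encoding: as defined by the GGUF format, weights are stored in blocks of 32, each holding one `float16` scale (2 bytes) plus 32 packed 4-bit values (16 bytes), so every 32 weights occupy 18 bytes instead of 128 bytes in `float32`. A quick sketch of that arithmetic (block constants are from the GGUF spec; verify against your own checkpoint):

```python
# GGUF Q4_0 block layout: one fp16 scale + 32 packed 4-bit values per block.
BLOCK_SIZE = 32                         # weights per block
BYTES_PER_BLOCK = 2 + BLOCK_SIZE // 2   # 2-byte scale + 16 bytes of nibbles = 18

def q4_0_bytes(num_weights: int) -> int:
    """Storage for Q4_0-quantized weights (count must be a multiple of 32)."""
    assert num_weights % BLOCK_SIZE == 0
    return (num_weights // BLOCK_SIZE) * BYTES_PER_BLOCK

n = 4096 * 4096  # e.g. one square linear layer
print(q4_0_bytes(n), "bytes, vs", n * 4, "bytes in float32")
print(q4_0_bytes(n) * 8 / n, "bits per weight")  # 4.5, not exactly 4
```

This is why the effective rate is 4.5 bits per weight rather than 4: the per-block scale is part of the payload, and the `uint8` buffer you declare for the `Weight` must account for it.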
--- ## Embeddings import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import SmallCards from '@site/src/components/SmallCards'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; import InstallOpenAI from '@site/docs/_includes/install-openai.mdx'; import MDXListing from '@site/src/components/Listing/MDXListing'; import Requirements from '@site/src/components/Requirements'; import { requirementsNoGPU } from '@site/docs/max/requirements'; Text embeddings are rich numerical representations of text. They capture semantic meaning in a way that allows computers to compare, cluster, and search text effectively. Use embeddings whenever you need to measure similarity between pieces of text, perform semantic search, build recommendation systems, or cluster documents. They are foundational for many modern NLP tasks. In contemporary GenAI applications, embeddings are especially powerful in agentic workflows, including: - **Retrieval-Augmented Generation (RAG):** Embeddings make it possible to store and search large collections of documents, grounding model responses in your own data instead of relying only on a model's training knowledge. - **Context injection for agents:** Embeddings help agents decide which pieces of external knowledge (APIs, tools, or documents) are most relevant to the current query. - **Personalization and recommendations:** By embedding both user data and content, systems can deliver more tailored results. - **Clustering and analytics:** Embeddings allow grouping similar inputs for downstream tasks like summarization, deduplication, and insight extraction. ## Endpoint MAX supports the [`v1/embeddings`](/max/api/serve#operation/createEmbedding) endpoint, which is fully compatible with the OpenAI API. To use the endpoint, provide the ID of an embedding model along with the text to embed. The API returns numerical embeddings that capture the semantic meaning of each input. 
The request payload should look similar to the following: ```json { "model": "sentence-transformers/all-mpnet-base-v2", "input": "The food was delicious and the service was excellent." } ``` ## Quickstart Serve and interact with an embedding model using an OpenAI-compatible endpoint. Specifically, we'll use MAX to serve the [all-mpnet-base-v2](https://builds.modular.com/models/all-mpnet-base-v2/5B) model, which is a powerful transformer that excels at capturing semantic relationships in text. System requirements: ### Set up your environment Create a Python project to install our APIs and CLI tools: ### Serve your model Use the [`max serve`](/max/cli/serve) command to start a local model server with the [all-mpnet-base-v2](https://builds.modular.com/models/all-mpnet-base-v2/5B) model: ```sh max serve \ --model sentence-transformers/all-mpnet-base-v2 ``` This will create a server running the `all-mpnet-base-v2` embedding model on `http://localhost:8000/v1/embeddings`, an [OpenAI compatible endpoint](https://platform.openai.com/docs/api-reference/embeddings). The endpoint is ready when you see this message printed in your terminal: ```output Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` For a complete list of `max` CLI commands and options, refer to the [MAX CLI reference](/max/cli). ### Interact with your model MAX supports OpenAI's REST APIs and you can interact with the model using either the OpenAI Python SDK or curl: You can use OpenAI's Python client to interact with the model. 
First, install the OpenAI API: Then, create a client and make a request to the model: ```python from openai import OpenAI client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY") # Create embeddings response = client.embeddings.create( model="sentence-transformers/all-mpnet-base-v2", input="Run an embedding model with MAX Serve!", ) print(f"{response.data[0].embedding[:5]}") ``` You should receive a response similar to this: ```json {"data":[{"index":0,"embedding":[-0.06595132499933243,0.005941616836935282,0.021467769518494606,0.23037832975387573, ``` The text has been shortened for brevity. This returns a numerical representation of the input text that can be used for semantic comparisons. The following `curl` command sends an embeddings request to the model: ```sh curl http://localhost:8000/v1/embeddings \ -H "Content-Type: application/json" \ -d '{ "input": "Run an embedding model with MAX Serve!", "model": "sentence-transformers/all-mpnet-base-v2" }' ``` You should receive a response similar to this: ```json {"data":[{"index":0,"embedding":[-0.06595132499933243,0.005941616836935282,0.021467769518494606,0.23037832975387573, ``` The text has been shortened for brevity. This returns a numerical representation of the input text that can be used for semantic comparisons. For complete details on all available API endpoints and options, see the [REST API documentation](/max/api/serve). 
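Once you have embedding vectors back from the endpoint, the usual way to compare them is cosine similarity. Here's a minimal, dependency-free sketch; the three-element vectors are made up for illustration (real `all-mpnet-base-v2` embeddings have 768 dimensions), and in practice they'd come from `response.data[i].embedding`:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors: values near 1.0 mean similar."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Toy stand-ins for embeddings returned by the /v1/embeddings endpoint.
v1 = [0.1, 0.3, -0.2]
v2 = [0.1, 0.3, -0.2]
v3 = [-0.3, 0.1, 0.4]

print(cosine_similarity(v1, v2))  # identical vectors score ~1.0
print(cosine_similarity(v1, v3))  # dissimilar direction scores lower
```

For large document collections you'd typically offload this comparison to a vector database, but the underlying scoring is the same.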
## Next steps Now that you have successfully set up MAX with an OpenAI-compatible embeddings endpoint, check out these other tutorials: export const docs = [ '../../develop/build-an-mlp-block.mdx', '../../develop/build-custom-ops.mdx', '../../develop/custom-ops-matmul.mdx', ]; --- ## Image to text import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import SmallCards from '@site/src/components/SmallCards'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; import InstallOpenAI from '@site/docs/_includes/install-openai.mdx'; import MDXListing from '@site/src/components/Listing/MDXListing'; import Requirements from '@site/src/components/Requirements'; import { requirementsWithGPU } from '@site/docs/max/requirements'; Multimodal large language models are capable of processing images and text together in a single request. They can describe visual content, answer questions about images, and support tasks such as image captioning, document analysis, chart interpretation, optical character recognition (OCR), and content moderation. ## Endpoint You can interact with a multimodal LLM through the [`v1/chat/completions`](/max/api/serve#operation/createChatCompletion) endpoint by including image inputs alongside text in the request. This allows you to provide an image URL or base64-encoded image as part of the conversation, enabling use cases such as image captioning, asking questions about a photo, requesting a chart summary, or combining text prompts with visual context. ### URL input Within the `v1/chat/completions` request body, the `"messages"` array accepts inline image URLs. For example: ```json "messages": [ { "role": "user", "content": [ { "type": "text", "text": "What is in this image?" }, { "type": "image_url", "image_url": { "url": "https://example.com/path/to/image.jpg" } } ] } ] ``` ### Local file input To use local images, you must configure allowed directories before starting the server.
This prevents unauthorized file access by restricting which paths the server can read from. Set the `MAX_SERVE_ALLOWED_IMAGE_ROOTS` environment variable to a JSON-formatted list of allowed directories: ```bash export MAX_SERVE_ALLOWED_IMAGE_ROOTS='["/path/to/images"]' ``` Then reference files with an absolute path: ```json "messages": [ { "role": "user", "content": [ { "type": "text", "text": "What is in this image?" }, { "type": "image_url", "image_url": { "url": "file:///path/to/images/image.jpg" } } ] } ] ``` The file path must be within a directory listed in `MAX_SERVE_ALLOWED_IMAGE_ROOTS`. If no allowed roots are configured, all `file:///` requests return a 400 error. The maximum file size is 20 MiB by default, which you can adjust by setting the `MAX_SERVE_MAX_LOCAL_IMAGE_BYTES` environment variable to a value in bytes. ## Quickstart In this quickstart, learn how to set up and run [Gemma 3 27B Instruct](https://builds.modular.com/models/gemma-3-it/27B), which excels at tasks such as image captioning and visual question answering. :::caution GPU required To run Gemma 3 27B Instruct, your system must have a [compatible GPU](/max/packages#gpu-compatibility) with >60 GiB of GPU RAM. Due to the model's memory requirements, we recommend an NVIDIA B200, H200, or AMD MI355X. ::: System requirements: ### Set up your environment Create a Python project to install our APIs and CLI tools: ### Serve your model Agree to the [Gemma 3 license](https://huggingface.co/google/gemma-3-27b-it) and make your Hugging Face [access token](https://huggingface.co/settings/tokens) available in your environment: ```bash export HF_TOKEN="hf_..." ``` Then, use the [`max serve`](/max/cli/serve) command to start a local model server with the Gemma 3 27B Instruct model: ```bash max serve \ --model google/gemma-3-27b-it ``` :::note You may need to specify the `--max-length` and `--max-batch-size` parameters depending on the amount of memory you have access to. 
::: This will create a server running the `google/gemma-3-27b-it` multimodal model on `http://localhost:8000/v1/chat/completions`, an [OpenAI compatible endpoint](https://platform.openai.com/docs/guides/images-vision). While this example uses the Gemma 3 27B Instruct model, you can replace it with any of the vision models listed in our [model repository](https://builds.modular.com/?category=models&modality=Vision). The endpoint is ready when you see this message printed in your terminal: ```bash Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` For a complete list of `max` CLI commands and options, refer to the [MAX CLI reference](/max/cli). ### Interact with your model Open a new terminal window, navigate to your project directory, and activate your virtual environment. MAX supports OpenAI's REST APIs and you can interact with the model using either the OpenAI Python SDK or curl: You can use OpenAI's Python client to interact with the vision model. First, install the OpenAI API: Then, create a client and make a request to the model: ```python title="generate-image-description.py" from openai import OpenAI client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY") response = client.chat.completions.create( model="google/gemma-3-27b-it", messages=[ { "role": "user", "content": [ { "type": "text", "text": "What is in this image?" }, { "type": "image_url", "image_url": { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" } } ] } ], max_tokens=300 ) print(response.choices[0].message.content) ``` In this example, you're using the OpenAI Python client to interact with the MAX endpoint running on localhost port `8000`. The `client` object is initialized with the base URL `http://localhost:8000/v1`, and the API key is ignored.
When you run this code, the model should respond with information about the image: ```sh python generate-image-description.py ``` ```output Here's a breakdown of what's in the image: * **Peter Rabbit:** The main focus is a realistic-looking depiction of Peter Rabbit, the character from Beatrix Potter's stories... ``` You can send requests to the local endpoint using `curl`. The following request includes an image URL and a question to answer about the provided image: ```bash curl -N http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-3-27b-it", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "What is in this image?" }, { "type": "image_url", "image_url": { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" } } ] } ], "max_tokens": 300 }' | grep -o '"content":"[^"]*"' | sed 's/"content":"//g' | sed 's/"//g' | tr -d '\n' | sed 's/\\n/\n/g' ``` This sends the URL of an image along with a text prompt to the model. You should receive a response similar to this: ```output Here's a breakdown of what's in the image: * **Peter Rabbit:** The main focus is a realistic, anthropomorphic (human-like) rabbit character... ``` :::note When making requests with `max serve`, you do not need to include model-specific image tags within your prompt. ::: For complete details on all available API endpoints and options, see the [MAX Serve API documentation](/max/api/serve). ## Next steps Now that you can analyze images, try adding structured output to get consistent, formatted responses. You can also explore other endpoints and deployment options.
export const docs = [ '../../serve/structured-output.mdx', '../../inference/embeddings.mdx', '../../deploy/local-to-cloud.mdx', ]; --- ## Text to text import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; import InstallOpenAI from '@site/docs/_includes/install-openai.mdx'; import Requirements from '@site/src/components/Requirements'; import { requirementsNoGPU } from '@site/docs/max/requirements'; MAX makes it easy to generate text with large language models, whether for conversational applications, single-turn prompts, or offline inference workflows. MAX text completion endpoints are fully compatible with the OpenAI API, so you can use familiar tools and libraries. Text completions let you instruct a model to produce new text based on a prompt or an ongoing conversation. They can be used for a wide range of tasks, including writing content, generating synthetic data, building chatbots, or powering multi-turn assistants. MAX provides two main endpoints for text completions: [`v1/chat/completions`](/max/inference/text-to-text#v1chatcompletions) and [`v1/completions`](/max/inference/text-to-text#v1completions). ## Endpoints The [`v1/chat/completions`](/max/api/serve#operation/createChatCompletion) endpoint is recommended as the default for most text use cases. It supports both single-turn and multi-turn scenarios. The `v1/completions` endpoint is also supported for traditional single-turn text generation tasks, which is useful for offline inference or generating text from a prompt without conversational context. ### `v1/chat/completions` The [`v1/chat/completions`](/max/api/serve#operation/createChatCompletion) endpoint is designed for chat-based models and supports both single-turn and multi-turn interactions. 
You provide a sequence of structured messages with roles (`system`, `user`, `assistant`), and the model generates a response. For example, within the `v1/chat/completions` request body, the `"messages"` array might look similar to the following: ```json "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "Who won the world series in 2020?" } ] ``` Use a combination of roles to give the model the context it needs. A `system` message can define overall model response behavior, `user` messages represent instructions or prompts from the end-user interacting with the model, and `assistant` messages are a way to incorporate past model responses into the message context. Use this endpoint whenever you want conversational interaction, such as: - Building chatbots or assistants - Implementing Q&A systems - Supporting multi-turn dialogue in applications It's also fully compatible with single-turn use cases, making it versatile enough for general text generation workflows. ### `v1/completions` The [`v1/completions`](/max/api/serve#operation/createCompletion) endpoint supports traditional text completions. You provide a prompt, and the model returns generated text. This endpoint is ideal when you only need a single response per request, such as: - Offline inference workflows - Synthetic text generation - One-off text generation tasks ## Quickstart Get started quickly serving Gemma 3 locally with the `max` CLI and interact with it through the MAX REST and Python APIs. You'll learn to configure the server and make requests using the OpenAI client libraries as a drop-in replacement. 
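Before diving into the quickstart, it helps to see the request-body difference between the two endpoints described above: `v1/chat/completions` takes a `messages` array of role-tagged turns, while `v1/completions` takes a bare `prompt` string. A minimal sketch of the two bodies (the prompt text is illustrative):

```python
import json

# v1/chat/completions: role-tagged conversation turns.
chat_body = {
    "model": "google/gemma-3-12b-it",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about the ocean."},
    ],
    "max_tokens": 100,
}

# v1/completions: a single bare prompt, no conversational structure.
completion_body = {
    "model": "google/gemma-3-12b-it",
    "prompt": "Write a haiku about the ocean.",
    "max_tokens": 100,
}

# Either body is sent as JSON to its respective endpoint.
print(json.dumps(completion_body, indent=2))
```

Everything else (headers, base URL, response parsing) follows the same OpenAI-compatible conventions shown in the examples that follow.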
System requirements: ### Set up your environment Create a Python project to install our APIs and CLI tools: ### Serve your model Use the [`max serve`](/max/cli/serve) command to start a local server with the Gemma 3 model: ```bash max serve --model google/gemma-3-12b-it ``` This creates a server running the `google/gemma-3-12b-it` large language model on `http://localhost:8000/v1/chat/completions`, an [OpenAI compatible endpoint](https://platform.openai.com/docs/api-reference/chat). While this example uses the Gemma 3 model, you can replace it with any of the models listed in the [MAX Builds](https://builds.modular.com/?category=models) site. :::note When searching for a model using the MAX Builds site, ensure that the model can fit into your machine's memory. You can filter and sort models by hardware type and model size. For more information and to learn how to use the MAX Builds site, see [MAX Builds in 60 seconds](https://www.youtube.com/watch?v=EqM1TB1GgCc). ::: The endpoint is ready when you see this message printed in your terminal: ```output Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` For a complete list of `max` CLI commands and options, refer to the [MAX CLI reference](/max/cli). ### Generate a text chat completion MAX supports OpenAI's REST APIs and you can interact with the model using either the OpenAI Python SDK or curl: You can use OpenAI's Python client to interact with the model.
First, install the OpenAI Python client: Then, create a client and make a request to the model: ```python title="generate-text.py" from openai import OpenAI client = OpenAI( base_url = 'http://0.0.0.0:8000/v1', api_key='EMPTY', # required by the API, but not used by MAX ) response = client.chat.completions.create( model="google/gemma-3-12b-it", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}, {"role": "assistant", "content": "The LA Dodgers won in 2020."}, {"role": "user", "content": "Where was it played?"} ] ) print(response.choices[0].message.content) ``` In this example, you're using the OpenAI Python client to interact with the MAX endpoint running locally on port `8000`. The `client` object is initialized with the base URL `http://0.0.0.0:8000/v1`, and the API key is required by the client but not used by MAX. When you run this code, the model should respond with information about the 2020 World Series location: ```sh python generate-text.py ``` ```output The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was a neutral site due to the COVID-19 pandemic. ``` The following `curl` command sends a chat request to the model's chat completions endpoint: ```bash curl http://0.0.0.0:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-3-12b-it", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "Hello, how are you?" } ], "max_tokens": 100 }' ``` You should receive a response similar to this: ```json { "id": "18b0abd2d2fd463ea43efe2c147bcac0", "choices": [ { "finish_reason": "stop", "index": 0, "message": { "content": " I'm doing well, thank you for asking.
How can I assist you today?", "refusal": "", "tool_calls": null, "role": "assistant", "function_call": null }, "logprobs": { "content": [], "refusal": [] } } ], "created": 1743543698, "model": "google/gemma-3-12b-it", "service_tier": null, "system_fingerprint": null, "object": "chat.completion", "usage": { "completion_tokens": 17, "prompt_tokens": null, "total_tokens": 17 } } ``` For complete details on all available API endpoints and options, see the [REST API documentation](/max/api/serve). ## Next steps Now that you have successfully set up MAX with an OpenAI-compatible chat endpoint, check out additional serving optimizations specific to your use case. export const docs = [ '../../serve/structured-output.mdx', '../../serve/prefix-caching.mdx', '../../serve/offline-inference.mdx' ] --- ## What is Modular import { Button } from '@mantine/core'; import DocLink from '@site/src/components/DocLink'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import ContactSection from '@site/src/components/ContactSection'; The Modular Platform is an open and fully-integrated suite of AI libraries and tools that accelerates model serving and scales GenAI deployments. It abstracts away hardware complexity so you can run the most popular open models with industry-leading GPU and CPU performance without any code changes. Our ready-to-deploy Docker container removes the complexity of deploying your own GenAI endpoint. And unlike other serving solutions, Modular enables customization across the entire stack. You can customize everything from the serving pipeline and model architecture all the way down to the metal by writing custom ops and GPU kernels in Mojo. Most importantly, Modular is hardware-agnostic and free from vendor lock-in—no CUDA required—so your code runs seamlessly across diverse systems. 
It takes only a moment to start an OpenAI-compatible endpoint with a model from Hugging Face: ```sh max serve --model google/gemma-3-27b-it ``` ```sh docker run --gpus=1 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/max_cache:/opt/venv/share/max/.max_cache \ -p 8000:8000 \ modular/max-nvidia-full:latest \ --model-path google/gemma-3-27b-it ``` ```python from openai import OpenAI client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY") completion = client.chat.completions.create( model="google/gemma-3-27b-it", messages=[ { "role": "user", "content": "Write a one-sentence bedtime story about a unicorn.", }, ], ) print(completion.choices[0].message.content) ``` ## Capabilities - [x] **High-performance, portable serving**: Serve 500+ AI models from Hugging Face using our OpenAI-compatible REST API with industry-leading performance across GPUs and CPUs. - [x] **Large-scale, GenAI deployment**: Scale massive GenAI inference services across thousands of GPU nodes. Modular intelligently routes workloads across models and hardware types to maximize throughput and minimize latency. - [x] **Flexible, faster development**: Deploy with a single Docker container that's under 1GB across multiple hardware types, compile in seconds rather than hours, and develop faster with a slim toolchain that makes versioning and dependency nightmares disappear. - [x] **Customize everywhere**: Customize at any layer of the stack by writing hardware-agnostic GPU and CPU kernels, porting models into Modular's optimized graph format, or programming hardware directly with Mojo (no hardware-specific libraries). ## Components Modular is a vertically integrated AI infrastructure stack that spans from the hardware all the way up to Kubernetes, and it provides entry points for users at every level.
Figure 1. A simplified diagram of how the Modular Platform scales your GenAI deployment.
- 🦣 **Mammoth**: A Kubernetes-native control plane, router, and substrate specially designed for large-scale distributed AI serving. It supports multi-model management, prefill-aware routing, disaggregated compute and cache, and other advanced AI optimizations. - 🧑🏻‍🚀 **MAX**: A high-performance AI serving framework that includes advanced model serving optimizations like speculative decoding, and graph compiler optimizations like op-level fusions. It provides an OpenAI-compatible serving endpoint, executes native MAX and PyTorch models across GPUs and CPUs, and can be customized at the model and kernel level. - 🔥 **Mojo**: A kernel-focused systems programming language that enables high-performance GPU and CPU programming, blending Pythonic syntax with the performance of C/C++ and the safety of Rust. All the kernels in MAX are written in Mojo, and it can be used to extend MAX Models with novel algorithms. ## Get started You can create an OpenAI-compatible REST endpoint using our `max` CLI or our Docker container: - [**Start with pip**](/max/get-started): Install MAX with `pip` and run inference with Python or a REST endpoint. - [**Start with Docker**](/max/container): Run our Docker container to create a REST endpoint. In either case, you can select from hundreds of GenAI models in our [Model repository](https://builds.modular.com/?category=models). You can also load weights from Hugging Face or load your own fine-tuned weights. For performance optimization, you can port models from PyTorch to MAX using the [MAX Graph API](/max/develop/get-started-with-max-graph-in-python). For deeper customization, you can extend MAX Models with [custom operations](/max/develop/build-custom-ops) (ops) written in Mojo. Your custom ops are automatically analyzed and fused into the model graph, delivering low-level acceleration without sacrificing developer productivity.
:::note Mammoth powers advanced routing and scaling capabilities behind the scenes for Modular's Dedicated Endpoint and Enterprise [editions](https://www.modular.com/pricing). [Contact us](https://www.modular.com/request-demo) to learn how Mammoth can help scale your workloads. ::: ## Stay in touch --- ## Model support import MDXListing from '@site/src/components/Listing/MDXListing'; MAX allows you to pick the perfect GenAI model for your project from Hugging Face. You just provide the name of the model you want, and MAX takes care of the rest. It builds the model as a high-performance graph and starts a serving endpoint that runs the model on either a CPU or GPU. This page explains how this works out of the box with models from Hugging Face, and introduces how you can customize an existing model or create your own. :::note MAX model repo If you just want to browse some models, check out the [MAX model repository](https://builds.modular.com/?category=models&type=MAX+Model). ::: ## Model configs To understand how MAX accelerates hundreds of GenAI models from Hugging Face, you should first know a little about Hugging Face model configurations. Nowadays, the definitive place to find AI models is [Hugging Face Model Hub](https://huggingface.co/models). Although models on Hugging Face might be built and trained with different machine learning frameworks, they all include a `config.json` file, which is like a model blueprint. This file contains all the information you need to understand the model architecture and its configuration, such as the number of layers used, the embedding size, and other hyperparameters. By reading the model configuration, we can reconstruct any model from Hugging Face as a MAX model. ## MAX models {#max-graph} A MAX model is a high-performance inferencing model built with our [MAX Python API](/max/api/python/).
It's a unique model format that allows the MAX graph compiler to optimize the model for inference on a wide range of hardware and deliver state-of-the-art performance you normally see only from model-specific inference libraries written in C or C++. You can build these models yourself with our Python API, but you don't have to. All you have to do is specify the GenAI model you want from Hugging Face (such as [`meta-llama/Llama-3.2-1B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)), and MAX will programmatically reconstruct it as a MAX model. This works because we have already built a library of [base model architectures](https://github.com/modular/modular/tree/main/max/python/max/pipelines/architectures) with the MAX Python API. When you ask MAX to start an inference server with a Hugging Face model, MAX pulls the corresponding pre-built architecture from our library and makes the appropriate changes based on the configuration from Hugging Face. This all happens automatically when you start a serving endpoint with the [`max`](/max/cli) CLI or with the [MAX container](/max/container). For example, here's how to start an endpoint using Meta's Llama 3.2 Instruct model as a MAX model: ```sh max serve --model meta-llama/Llama-3.2-1B-Instruct ``` :::caution This model requires a GPU The command above will fail if your system doesn't have a [compatible GPU](/max/packages#gpu-compatibility). However, you can make it work if you instead [load quantized weights](#customize-a-model) as shown below. ::: When you run the `max serve` command, MAX pulls the model configuration and weights from Hugging Face and builds it as a MAX model. Then it starts up an endpoint to handle inference requests that you send using [our REST API](/max/api/serve). ### Customize a model If you want to load a different set of weights for a given model, you can pass them in GGUF or Safetensors format using the `--weight-path` argument. 
This accepts either a local path or a Hugging Face repo with the weights. For example, here's how to run `Llama-3.2-1B-Instruct` on a CPU with quantized weights ([from bartowski](https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF)): ```sh max serve --model meta-llama/Llama-3.2-1B-Instruct \ --weight-path=bartowski/Llama-3.2-1B-Instruct-GGUF/Llama-3.2-1B-Instruct-Q6_K.gguf ``` When using GGUF models, quantization encoding formats are automatically detected. When using the `max` command with a model from a Hugging Face repository, explicitly providing a quantization encoding is optional. ```sh max serve --model "modularai/Llama-3.1-8B-Instruct-GGUF" \ --quantization-encoding=q4_k ``` If no quantization encoding is specified, MAX Serve automatically detects and uses the first encoding option from the repository. If a quantization encoding is provided, it must align with the available encoding options in the repository. If the repository contains multiple quantization formats, be sure to specify which encoding type you want to use. For help creating your own weights in GGUF format, see the tutorial to [Bring your own fine-tuned model](/max/develop/max-pipeline-bring-your-own-model). For more information on quantization, see the [Quantization](/max/graph/quantize) documentation. ### Build your own model Although our model-building APIs are still under heavy development while we implement the most popular architectures, you can also build your own models with the MAX APIs today. To build your own inferencing model with the MAX APIs, the process generally looks like this: 1. Instantiate a [`Graph`](/max/api/python/graph/Graph) by specifying the input shape as a [`TensorType`](/max/api/python/graph/type#max.graph.type.TensorType). 2. Build the graph by chaining [`ops`](/max/api/python/graph/ops/) functions. Each function takes and returns a [`Value`](/max/api/python/graph/Value) object. 3.
Add the final `Value` to the graph using the [`output()`](/max/api/python/graph/Graph#max.graph.Graph.output) method. For more information, see our tutorial to [get started with MAX Graph in Python](/max/develop/get-started-with-max-graph-in-python). ## Get started export const docs = [ '../deploy/local-to-cloud.mdx', ]; --- ## Packages import InstallModular from '@site/docs/_includes/install-modular.mdx'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import Requirements from '@site/src/components/Requirements'; import { requirementsNoGPU } from '@site/docs/max/requirements'; All the Modular APIs and tools (including MAX and Mojo) are included in a single Python/Conda package named `modular`. This page explains what's in the package and shows how to install it. If you just want to get started with Modular, instead see our [quickstart guide](/max/get-started). :::note **New!** You can now install a standalone [`mojo` package](/mojo/manual/install). ::: ## Package managers You can install `modular` as a Python or Conda package. The install instructions below include commands for popular package managers such as `pixi`, `uv`, `pip`, and `conda`, but you can also use something else. Traditional tools like `pip` and `conda` might be necessary for existing production environments or Docker containers, but for your local development, we recommend using [Pixi](https://pixi.sh/latest/). Pixi is both a package manager and virtual environment manager, which alone makes development a lot easier, but it's also language agnostic, extremely fast, and includes lock files to easily reproduce your project environment. You'll notice that our [GitHub code examples](https://github.com/modular/modular/tree/main/max/examples) include a `pixi.toml` file. This file configures the environment to make sure we all use the same packages and get the same results—you just need to install `pixi`. 
So if you're not set on using a particular package manager, we suggest you try `pixi`. If you haven't used it before, check out our [Pixi basics guide](/pixi) or the [official Pixi docs](https://pixi.sh/latest/). ## Install To get the latest improvements and new features, we recommend installing our nightly build, which we release several times a week. If you want a better tested but older version, you can install a stable build. (Each release is described in the [changelog](/max/changelog).) System requirements: :::tip We recommend installing with `pixi` for the most reliable experience. ::: The `modular` package installs MAX, Mojo, and other package dependencies. :::note GitHub stable branch If you're using a stable build and want to clone the [Modular repo](https://github.com/modular/modular), make sure you checkout the `stable` branch (because the `main` branch includes the latest nightly code). For example: ```sh git clone -b stable https://github.com/modular/modular.git ``` ::: ## Uninstall You can uninstall `modular` from your virtual environment with this command: ```sh pixi remove modular ``` To deactivate your virtual environment, run: ```sh exit ``` You can uninstall `modular` from your virtual environment with the following command: ```sh uv pip uninstall modular ``` To deactivate your virtual environment, run: ```sh deactivate ``` You can uninstall `modular` from your virtual environment with the following command: ```sh pip uninstall modular ``` To deactivate your virtual environment, run: ```sh deactivate ``` You can uninstall `modular` from your virtual environment with this command: ```sh conda remove modular ``` To deactivate your virtual environment, run: ```sh deactivate ``` ## What's included Here's a summary of what's in the `modular` package. 
The `modular` Conda package installs the following: - MAX tools and libraries: - [`max` CLI](/max/cli) - [`max` Python library](/max/api/python/) - [`max` Mojo library](/mojo/lib) - [MAX Engine C API](/max/api/c/) - The `mojo` package: - [`mojo` CLI](/mojo/cli) (includes the Mojo compiler) - [Mojo standard library](/mojo/lib) - Mojo language server (LSP) for IDE/editor integration - [Mojo debugger](/mojo/tools/debugging) (includes LLDB) - [Mojo code formatter](/mojo/cli/format) - [Mojo REPL](/mojo/cli/repl) `pixi` known issues: - You might encounter issues if you invoke `pixi` within a `conda` virtual environment. It's best if you don't mix the two tools. The `modular` Python wheel installs the following: - MAX tools and libraries: - [`max` CLI](/max/cli) - [`max` Python library](/max/api/python/) - [`max` Mojo library](/mojo/lib) - The `mojo` package: - [`mojo` CLI](/mojo/cli) (includes the Mojo compiler) - [Mojo standard library](/mojo/lib) - Mojo language server (LSP) for IDE/editor integration - [Mojo debugger](/mojo/tools/debugging) (includes LLDB) - [Mojo code formatter](/mojo/cli/format) - [Mojo REPL](/mojo/cli/repl) :::note Note For information about the `mojo` package, see the [Mojo install guide](/mojo/manual/install). ::: ## System requirements To install `modular`, your system must meet these specifications. - macOS Sequoia (15) or later - Apple silicon (M1 - M5 processor) - Python 3.10 - 3.14 - Xcode or Xcode Command Line Tools 16 or later - We have [limited compatibility](#gpu-compatibility) for Apple Silicon GPUs. You may need to run `xcodebuild -downloadComponent MetalToolchain`, which downloads the Metal utilities required for GPU programming in later versions of Xcode. 
- Ubuntu 22.04 LTS - x86-64 CPU (with [SSE4.2 or newer](https://www.intel.com/content/www/us/en/support/articles/000057621/processors.html)) or AWS Graviton2/3 CPU - Minimum 8 GiB RAM (or much more, depending on the model you run) - Python 3.10 - 3.14 - g++ or clang++ C++ compiler - To use GPUs, see the [GPU compatibility](#gpu-compatibility) section Windows is not officially supported at this time. In the meantime, you can try MAX on Windows [with WSL](https://learn.microsoft.com/en-us/windows/wsl/install), using a compatible version of Ubuntu (see our requirements for Linux). ### GPU compatibility import GpuRequirements from '@site/docs/mojo/_includes/gpu-requirements.mdx'; The Modular Platform supports both CPUs and GPUs, so you don't always need a GPU to serve a model or program with Mojo. But if you do want to accelerate your model with GPUs or program for GPUs with Mojo, we support the following GPUs. :::note Notes - Many GPUs are available in variants with different amounts of memory, and each AI model has different memory requirements. So even if your GPU architecture is listed as compatible, you must confirm that the available memory is sufficient for the model you're using. - Modular can serve lots of models on either CPU or GPU, but some models do require one or more GPUs. When you browse our [model repository](https://builds.modular.com/?category=models), you can filter by models that support either CPU or GPU. ::: --- ## Function calling and tool use import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; import InstallOpenAI from '@site/docs/_includes/install-openai.mdx'; Function calling is a feature available with some large language models (LLMs) that allows them to call external program functions (or tools).
This allows the model to interact with external systems to retrieve new data for use as input or execute other tasks. This is a foundational building block for agentic AI applications, in which an LLM can chain together various functions to achieve complex objectives. Function calling is also called "tool use" because the manner in which you tell the LLM what functions are available is with a `tools` parameter in the request body. :::note Function calling is enabled by default with MAX, but its availability is model-dependent and will produce valid output only if the model is pretrained to return tool-use responses. ::: ## When to use function calling You should use function calling when you want your LLM to: - **Fetch data**: Such as fetch weather data, stock prices, or news updates from a database. The model will call a function to get information, and then incorporate that data into its final response. - **Perform actions**: Such as modify application states, invoke workflows, or call upon other AI systems. The model will call another tool to perform an action, effectively handing off the request after it determines what the user wants. ## How function calling works When you send an inference request to a model that supports function calling, you can specify which functions are available to the model using the `tools` body parameter. 
The `tools` parameter provides information that allows the LLM to understand: - What each function can do - How to call each function (the arguments it accepts/requires) For example, here's a request with the [chat completions API](/max/api/serve#operation/createChatCompletion) that declares an available function named `get_weather()`: ```python from openai import OpenAI def get_weather(location: str) -> None: print("Get weather:", location) client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY") tools = [{ "type": "function", "function": { "name": "get_weather", "description": "Get current temperature for a given location.", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "City and country e.g. Bogotá, Colombia" } }, "required": [ "location" ], "additionalProperties": False }, "strict": True } }] messages = [ { "role": "user", "content": "What's the weather like in San Francisco today?" } ] completion = client.chat.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", messages=messages, tools=tools ) ``` Let's take a closer look at each parameter shown in the `tools` property: - `type`: Currently this is always `function` - `function`: Definition of the function - `name`: The function name used by the LLM to call it - `description`: A function description that helps the LLM understand when to use it - `parameters`: Definition of the function parameters - `type`: Defines this as an object containing parameters - `properties`: Lists all possible function arguments and their types - `required`: Specifies which function arguments are required This format follows the [OpenAI function calling specification](https://platform.openai.com/docs/guides/function-calling) to specify functions as tools that a model can use. Using this information, the model will decide whether to call any functions specified in `tools`.
In this case, we expect the model to call `get_weather()` and incorporate that information into its final response. So, the initial `completion` response from above includes a `tool_calls` parameter like this: ```python print(completion.choices[0].message.tool_calls) ``` ```js [ChatCompletionMessageToolCall( id='call_a175692d9ff54554', function=Function( arguments='{ "location": "San Francisco, USA" }', name='get_weather' ), type='function' )] ``` From here, you must parse the `tool_calls` body and execute the function as appropriate. For example: ```py import json tool_call = completion.choices[0].message.tool_calls[0] args = json.loads(tool_call.function.arguments) result = get_weather(args["location"]) ``` If the function is designed to **fetch data** for the model, you should call the function and then call the model again with the function results appended as a message using the `tool` role. If the function is designed to **perform an action**, then you don't need to call the model again. For detail about how to execute the function and feed the results back to the model, see the [OpenAI docs about handling function calls](https://platform.openai.com/docs/guides/function-calling?api-mode=chat&example=get-weather#handling-function-calls). The OpenAI function calling spec is compatible with multiple agent frameworks, such as [AutoGen](https://github.com/microsoft/autogen), [CrewAI](https://github.com/crewAIInc/crewAI), and more. :::caution MAX currently doesn't support [streaming with function calling](https://platform.openai.com/docs/guides/function-calling?api-mode=chat&example=get-weather#streaming). If using a model that provides streaming, be sure to set the `stream` parameter to `False` when making requests with function calling. ::: ## Supported models Function calling is model-dependent and will produce valid output only if the model is pretrained to return tool use responses. 
Here are just a few that we've verified work with function calling: - [Meta's Llama 3.1 models & evals](https://huggingface.co/collections/meta-llama/metas-llama-31-models-and-evals-675bfd70e574a62dd0e40565) collection - [Meta's Llama 3.2 language models & evals](https://huggingface.co/collections/meta-llama/metas-llama-32-language-models-and-evals-675bfd70e574a62dd0e40586) collection :::note The Meta Llama 3 models are hosted in gated repositories on Hugging Face. You must have a Hugging Face account with access to these repositories and an [access token](https://huggingface.co/settings/tokens) configured in your environment to deploy these models. ::: ## Quickstart Here's how you can quickly try the example code from above using a locally-hosted endpoint: 1. Create a virtual environment and install the `max` CLI: 2. Start an endpoint with a model that supports function calling: ```sh max serve --model meta-llama/Llama-3.1-8B-Instruct ``` 3. Wait until you see this message: ```output Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` Then open a new terminal and send a request with the `tools` parameter: First, install the `openai` package (make sure your current working directory is still the `function-calling` directory): Then, create a program to send a request specifying the available `get_weather()` function: ```python title="function-calling.py" from openai import OpenAI import json def get_weather(city: str) -> str: print("Get weather:", city) client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY") tools = [{ "type": "function", "function": { "name": "get_weather", "description": "Get current temperature for a given location.", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "City and country e.g. Bogotá, Colombia" } }, "required": [ "location" ], "additionalProperties": False }, "strict": True } }] messages = [ { "role": "user", "content": "What's the weather like in San Francisco today?" 
} ] completion = client.chat.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", messages=messages, tools=tools ) tool_call = completion.choices[0].message.tool_calls[0] args = json.loads(tool_call.function.arguments) result = get_weather(args["location"]) ``` Run it and the `get_weather()` function should print the argument received (make sure you're in the virtual environment—for example, first run `pixi shell`): ```sh python function-calling.py ``` ```output Get weather: San Francisco, USA ``` Use the following `curl` command to send a request specifying the available `get_weather()` function: ```bash curl -N http://0.0.0.0:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Llama-3.1-8B-Instruct", "stream": false, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the weather like in Boston today?"} ], "tools": [ { "type": "function", "function": { "name": "get_weather", "description": "Get the current weather in a given location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. Los Angeles, CA" } }, "required": ["location"] } } } ], "tool_choice": "auto" }' ``` You should receive a response similar to this: ```json "tool_calls": [ { "id": "call_ac73df14fe184349", "type": "function", "function": { "name": "get_weather", "arguments": "{\"location\": \"Boston, MA\"}" } } ] ``` For a more complete walkthrough of how to handle a `tool_calls` response and send the function results back to the LLM as input, see the [OpenAI docs about handling function calls](https://platform.openai.com/docs/guides/function-calling?api-mode=chat&example=get-weather#handling-function-calls). ## Next steps Now that you know the basics of function calling, you can get started with MAX on GPUs. 
export const docs = [ '../../serve/structured-output.mdx', '../../serve/offline-inference.mdx', '../../deploy/local-to-cloud.mdx', ]; --- ## Using LoRA adapters import InstallModular from '@site/docs/_includes/install-modular.mdx'; import MDXListing from '@site/src/components/Listing/MDXListing'; [LoRA (Low-Rank Adaptation)](https://arxiv.org/pdf/2106.09685) is a parameter-efficient fine-tuning (PEFT) technique that allows you to adapt a large model to new tasks or domains without modifying the original model weights. Instead of updating the full model, LoRA adds pairs of trainable low-rank decomposition matrices in parallel to the existing weight matrices; these matrices capture the task-specific behavior. These adapters are small, fast to train, and can be loaded at runtime, making them especially useful in production environments where model reuse, modularity, and memory efficiency are critical. MAX supports loading and switching between multiple LoRA adapters when serving a base model. :::note MAX supports model inference only. To serve a model with a LoRA adapter, you must provide a pre-trained adapter that is compatible with a specific base model. The adapter must use the `safetensors` weight format and be trained with [PEFT](https://github.com/huggingface/peft/tree/main). ::: ## When to use LoRA adapters LoRA adapters are ideal when you need to customize a foundation model for specific tasks or modalities without the overhead of full fine-tuning or maintaining multiple model variants. While prompt engineering can steer tone, format, or structure, LoRA adapters are better suited for cases where consistent, domain-specific behavior is required: - Text: Apply domain-specific fine-tuning. For example, using a `fingpt` LoRA adapter trained for financial jargon and reasoning. - Speech: Swap adapters to switch between different voice profiles in text-to-speech systems. 
- Vision: Use separate adapters for image style transfer or other workflows that involve changing visual characteristics. By encoding task-specific behavior into the model, LoRA adapters can reduce prompt length, eliminate the need for repeated context, and improve inference efficiency. LoRA adapters also enable you to serve a single base model with multiple specializations, minimizing memory usage and simplifying deployment. Adapters are especially effective at capturing specialized vocabulary, tone, or structure, and can help address model drift through targeted fine-tuning in production. ## How LoRA adapters work in MAX MAX loads LoRA adapters at model startup and applies them at inference time based on your input request. Each adapter is identified by a unique name and loaded from a local file path. :::note Currently MAX only supports LoRA adapters for Llama 3 models for query, key, value, and output (QKVO) layers. Follow our [release notes](/max/changelog) for updates on multi-modal model LoRA adapter support and support for additional layers. ::: ### MAX CLI argument You can statically or dynamically load LoRA adapters when serving a model with the [`max` CLI](/max/cli). To use LoRA adapters, configure the appropriate [`max serve`](/max/cli/serve) arguments for your use case: - `--lora-paths {name}={path} {name}={path}`: **(optional)** A space-separated list of `{name}={path}` mappings from each adapter's name to its local path. - `--max-lora-rank`: **(optional, `int`)** Any LoRA adapter loaded when serving a model must have a rank less than or equal to `--max-lora-rank`. Use this to limit resource usage or enforce consistency across adapters. - `--max-num-loras`: **(optional, `int`)** The maximum number of LoRA adapters to manage concurrently. This should be configured based on your available GPU memory. 
- `--enable-lora`: **(optional)** Allows LoRA adapter use in inference requests and enables the API for dynamic loading and unloading. For more information, see [dynamic serving](/max/serve/lora-adapters#dynamic-serving). - `--no-enable-lora`: **(optional)** Does not allow the use of LoRA adapters. Models served with the `max` CLI use `--no-enable-lora` by default. Any LoRA-related arguments in an inference request are ignored. LoRA dynamic serving APIs are unavailable. - `--no-enable-prefix-caching`: LoRA adapters are not compatible with [prefix caching](/max/serve/prefix-caching), which is enabled by default. You must disable prefix caching to use LoRA adapters. Each `{name}` is a user-defined identifier for an adapter. Each `{path}` is a local path to the LoRA adapter's weights. Multiple adapters can be specified in a single command. ### Dynamic serving To dynamically load and unload LoRA adapters, you must first serve your model with the `--enable-lora` argument: ```bash max serve \ --model meta-llama/Llama-3.1-8B-Instruct \ --enable-lora \ --no-enable-prefix-caching ``` To dynamically load a LoRA adapter, send a POST request to the `v1/load_lora_adapter` endpoint specifying the LoRA adapter name and path (note the double quotes around the request body so the shell expands `$HOME`): ```bash curl -X POST http://localhost:8000/v1/load_lora_adapter \ -H "Content-Type: application/json" \ -d "{ \"lora_name\": \"example\", \"lora_path\": \"$HOME/.cache/huggingface/hub/models--example--lora-adapter/snapshots/abc123\" }" ``` You should see the following response: ```output {"status":"success","message":"LoRA adapter 'example' loaded successfully"} ``` To unload a LoRA adapter, send a POST request to the `v1/unload_lora_adapter` endpoint specifying the name of the LoRA adapter to unload: ```bash curl -X POST http://localhost:8000/v1/unload_lora_adapter \ -H "Content-Type: application/json" \ -d '{"lora_name": "example"}' ``` You should see the following response if the adapter was unloaded successfully: ```output {"status":"success","message":"LoRA 
adapter 'example' unloaded successfully"} ``` ### Compatibility LoRA adapters must be saved in the `safetensors` format and trained using [PEFT](https://github.com/huggingface/peft/tree/main). At this time, only Llama 3 base models are supported. Only query, key, value, and output (QKVO) layer adapters are supported. Your adapter must only use the following layer projections: - `q_proj` - `k_proj` - `v_proj` - `o_proj` ## Quickstart We can quickly deploy Llama 3.1 8B Instruct using MAX as a backend with LoRA adapters. 1. Create a virtual environment and install the `max` CLI: 2. Find the path to your local LoRA adapter First, download an adapter that is trained on Llama 3.1 8B instruct and specifically fine-tunes QKVO layers. You can explore [available adapters](https://huggingface.co/models?other=base_model:adapter:meta-llama/Llama-3.1-8B-Instruct) on Hugging Face. ```bash pip install -U "huggingface_hub[cli]" hf download FinGPT/fingpt-mt_llama3-8b_lora ``` Copy the location of the downloaded snapshot. 3. Serve a model with a LoRA adapter available Change the `--lora-paths` path to the location of the downloaded LoRA adapter snapshot. :::note You can change `--max-num-loras` based on your available GPU memory and the number of LoRA adapters you want to enable. ::: ```bash max serve \ --model meta-llama/Llama-3.1-8B-Instruct \ --enable-lora \ --no-enable-prefix-caching \ --max-num-loras 10 \ --lora-paths finance=$HOME/.cache/huggingface/hub/models--FinGPT--fingpt-mt_llama3-8b_lora/snapshots/5b5850574ec13e4ce7c102e24f763205992711b7 ``` This command serves the base model and statically loads a LoRA adapter named `finance`. :::note You can optionally [dynamically load and unload LoRA adapters](/max/serve/lora-adapters#dynamic-serving). ::: 4. Run inference using a specific adapter When sending an inference request, specify the name of the adapter to apply. 
For example: ```bash curl http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Llama-3.1-8B-Instruct", "prompt": "What is an iron condor?", "max_tokens": 150, "lora": "finance" }' ``` This tells MAX to apply the `finance` adapter during inference. ## Next steps If you're using PEFT weights that have already been merged with the base model, check out our guide on [bringing your own model into MAX](/max/develop/max-pipeline-bring-your-own-model). If you're eager for LoRA support for a different base model, you can [check out the community](https://www.modular.com/community) to start contributing, or start a discussion in the [forum](https://forum.modular.com/). We'd love to hear from you! export const docs = [ '../../develop/max-pipeline-bring-your-own-model', '../../develop/serve-custom-model-architectures', ]; --- ## Offline inference import MDXListing from '@site/src/components/Listing/MDXListing'; import InstallModular from '@site/docs/_includes/install-modular.mdx'; Offline inference with MAX allows you to run large language models directly in Python without relying on external API endpoints. This is in contrast to online inference, where you would send requests to a remote service. ## When to use offline inference Use offline inference when you want to perform model inference without running a separate inference server, typically when you need to process a batch of inputs concurrently. This approach is beneficial for tasks that require high throughput and can be executed in a controlled environment, such as data preprocessing, model evaluation, or when working with large datasets that need to be processed in batches. ## How offline inference works The core of offline inference revolves around the [`LLM`](/max/api/python/entrypoints#max.entrypoints.llm.LLM) class, which provides a Python interface to load and run language models. 
Specify the model from a Hugging Face repository or a local path and MAX handles the process of downloading the model. The [`PipelineConfig`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig) class allows you to specify parameters related to the inference pipeline, such as [`max_length`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig.max_length) and [`max_num_steps`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig.max_num_steps). The [`generate()`](/max/api/python/entrypoints#max.entrypoints.llm.LLM.generate) function is used to generate text from the model. :::note The Python API for offline inference currently supports text-only input and does not support multi-modal models. If you need to work with vision capabilities, see [Image to text](/max/inference/image-to-text). ::: ## Quickstart This quickstart demonstrates how to use offline inference using a Hugging Face model with MAX in Python. 1. Set up your project: 2. Create a file named `main.py` with the following code: ```python title="main.py" from max.entrypoints.llm import LLM from max.pipelines import PipelineConfig def main(): model_path = "google/gemma-3-12b-it" pipeline_config = PipelineConfig(model_path=model_path) llm = LLM(pipeline_config) prompts = [ "In the beginning, there was", "I believe the meaning of life is", "The fastest way to learn python is", ] print("Generating responses...") responses = llm.generate(prompts, max_new_tokens=50) for i, (prompt, response) in enumerate( zip(prompts, responses, strict=True) ): print(f"========== Response {i} ==========") print(prompt + response) print() if __name__ == "__main__": main() ``` :::note You need both a valid Hugging Face token and model access approval to serve Gemma 3. To create a Hugging Face user access token, see [Access Tokens](https://huggingface.co/settings/tokens). 
You can request model access through the [Gemma 3 Hugging Face model repository](https://huggingface.co/google/gemma-3-12b-it). ::: For offline inference, specific configuration parameters might vary between models. Always refer to the model's documentation for compatibility details and optimal configuration settings. 3. Run the script: ```sh python main.py ``` You should see a response similar to the following: ```output Generating responses... ========== Response 0 ========== In the beginning, there was Andromeda. The Andromeda galaxy, that is. It's the closest major galaxy to our own Milky Way, and it's been a source of fascination for astronomers and space enthusiasts for centuries. But what if I told you that there's ========== Response 1 ========== I believe the meaning of life is to find your gift. The purpose of life is to give it away to others. I believe that the meaning of life is to find your gift. The purpose of life is to give it away to others. I believe that the meaning of life is ========== Response 2 ========== The fastest way to learn python is to practice with real-world projects. Here are some ideas for projects that you can use to learn Python: 1. **Command Line Calculator**: Create a command line calculator that can perform basic arithmetic operations like addition, subtraction, multiplication, and division. ``` This code downloads the [`google/gemma-3-12b-it`](https://huggingface.co/google/gemma-3-12b-it) model (if not already downloaded) and runs inference locally. If you'd like to use a different model, see our [Model repository](https://builds.modular.com/?category=models). 
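When the dataset is larger than a single call can comfortably handle, a common offline pattern is to chunk the prompts and invoke `generate()` once per chunk. Here's a minimal, self-contained sketch of that pattern; the chunking helper and the stand-in `fake_generate()` are illustrative assumptions, not part of the MAX API (with MAX installed, you would pass `llm.generate` instead):

```python
def chunked(items: list[str], size: int) -> list[list[str]]:
    """Split a list of prompts into fixed-size batches."""
    return [items[i : i + size] for i in range(0, len(items), size)]

def run_batches(prompts, generate, batch_size=4):
    """Call a generate() function (e.g. llm.generate) once per batch
    and collect the responses in their original order."""
    responses = []
    for batch in chunked(prompts, batch_size):
        responses.extend(generate(batch))
    return responses

# Stand-in for llm.generate so this sketch runs without MAX installed:
def fake_generate(batch):
    return [prompt + " ..." for prompt in batch]

results = run_batches([f"prompt {i}" for i in range(10)], fake_generate)
print(len(results))  # → 10
```

Keeping the batch size bounded lets you trade memory for throughput without changing the surrounding pipeline code.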
## Next steps export const docs = [ '../../inference/text-to-text.mdx', '../../inference/image-to-text.mdx', '../../inference/embeddings.mdx' ] --- ## Prefix caching with PagedAttention Prefix caching is a technique that caches the key-value (KV) cache of existing inference requests so that new queries can reuse the context encoded in the KV cache if they share the same prefix. This eliminates redundant computations and improves performance for workloads with repeated prefixes. Prefix caching is enabled by default when serving a model with the [`max serve`](/max/cli/serve) CLI command. It can be disabled using the `--no-enable-prefix-caching` flag. ## When to use prefix caching Prefix caching speeds up the pre-fill stage of inference, which reduces time to first token (TTFT). It can also reduce memory usage within the KV cache for all requests, which makes room for scheduling larger batches and yielding higher throughput. Prefix caching can provide significant performance improvements in the following scenarios: 1. **Similar queries**: When a user repeatedly makes similar queries that use the same system prompt instructions, the KV cache of the prefix can be stored in advance to reduce redundant computation. 2. **Multi-round conversations**: In chat applications, users often ask follow-up queries related to previous inputs. Since the server releases KV cache memory after each request, prefix caching preserves computation from past conversation turns without requiring an explicit session. Prefix caching won't result in performance degradation. However, it also does not provide additional benefit in the following cases: - **Unique queries**: If new queries do not share prefixes with previous queries, there is no opportunity to reuse cached KV values, making prefix caching ineffective. - **Long response generation**: Prefix caching only speeds up the pre-fill phase of a request. 
If most of the time is spent generating new tokens (decoding), caching will have little impact. ## How prefix caching works Prefix caching works by storing the key-value (KV) cache for a prefix and applying it to future prompts that include the same prefix, reducing redundant computation. By default, the following flags are applied: - `--cache-strategy` : Prefix caching requires PagedAttention. To use PagedAttention, the cache strategy must be `paged`. - `--enable-prefix-caching`: Enables prefix caching. - `--kv-cache-page-size`: PagedAttention currently requires a page size that is a multiple of 128. Prefix caching with PagedAttention works on both CPU and GPU. To deploy a model with prefix caching using the `max` CLI, you can use the flag `--devices cpu` for CPU or `--devices gpu` for GPU workloads. If no flag is provided, the model runs on the first available GPU, or on the first available CPU if no GPUs are available. ## Quickstart Prefix caching is enabled by default when serving a model with MAX. To install the `max` CLI, see the [installation guide](/max/packages#install). The following command serves Gemma 3 with prefix caching enabled. The default KV cache page size for Gemma 3 is `256`. ``` max serve --model google/gemma-3-27b-it ``` :::note MAX does not support prefix caching for multimodal models. If you explicitly enable prefix caching for an incompatible model, you will see a log similar to the following: ``` WARNING: Architecture 'MllamaForConditionalGeneration' requires \ KVCacheConfig.enable_prefix_caching=False, overriding current value True ``` In this case, MAX automatically disables prefix caching and reverts to `--no-enable-prefix-caching`. 
::: ### Disable prefix caching To disable prefix caching when serving a model, use the following command: ``` max serve --model google/gemma-3-27b-it --no-enable-prefix-caching ``` ### Enable prefix caching You can also explicitly enable prefix caching when serving your model with the [`max` CLI](/max/cli/serve). The following command represents the default MAX behavior: ``` max serve --model google/gemma-3-27b-it \ --cache-strategy paged \ --enable-prefix-caching \ --kv-cache-page-size 256 ``` ## Next steps Now that you know the basics of prefix caching and PagedAttention, you can get started with MAX on GPUs. MAX also includes a benchmarking script that allows you to evaluate throughput, latency, and GPU utilization metrics. You can use this script to track performance gains from prefix caching. For more detailed instructions on benchmarking, see the [`max benchmark` docs](/max/cli/benchmark). import MDXListing from '@site/src/components/Listing/MDXListing'; export const docs = [ '../../serve/speculative-decoding.mdx', '../../deploy/benchmark.mdx', '../../deploy/local-to-cloud.mdx', ]; --- ## Speculative decoding import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import MDXListing from '@site/src/components/Listing/MDXListing'; import InstallOpenAI from '@site/docs/_includes/install-openai.mdx'; Speculative decoding is an algorithm designed to accelerate the decoding process for large language models without sacrificing the quality of the generated text or requiring modifications to the models themselves. This technique employs a smaller, faster **draft model** to generate several potential next tokens in parallel, which are then efficiently validated against a larger, more powerful target model using a modified rejection sampling technique. This leads to reduced overall latency and improved throughput during token generation. 
By accepting correct predictions and only resampling when necessary, speculative decoding achieves a significant speedup in token generation, effectively bypassing memory bandwidth limitations often encountered during standard autoregressive decoding. :::caution Speculative decoding with MAX is still in preview and some aspects may change as we refine the implementation. Expect **ongoing** improvements and potential adjustments based on feedback and performance optimizations. ::: ## When to use speculative decoding You'll want to use speculative decoding when your primary goal is to accelerate the decoding process of large language models and reduce latency. For example, if you are using a 405 billion parameter model, you can use speculative decoding to reduce latency by using a 135 million parameter draft model. ## How speculative decoding works By default, speculative decoding is disabled in MAX. It can be enabled using the `--draft-model-path` flag. This flag takes a path to a model that will be used to generate speculative tokens. This is the model name as it appears on Hugging Face or as a path to a local directory containing a model. All model-specific parameters can be prefixed with `--draft-` to configure the draft model independently from the main model. For example: - `--draft-model-path`: Path to the draft model - `--draft-quantization-encoding`: Quantization encoding for the draft model - `--draft-weight-path`: Path to draft model weights The performance of speculative decoding primarily depends on two factors: - **Acceptance rate**: How often the target model confirms the draft model's predictions. - **Token generation pattern**: The system is optimized when more draft tokens can be evaluated in a single step of the target model. This is controlled by the `--max-num-steps` parameter, which sets the maximum number of tokens the draft model generates before verification by the target model. 
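The draft-and-verify loop described above can be illustrated with a toy greedy version. Real speculative decoding uses rejection sampling over the two models' probability distributions; in this simplified sketch, both "models" are deterministic next-token functions, and we accept the longest prefix of draft tokens that the target model agrees with, plus one token from the target:

```python
def speculative_step(prefix, draft_next, target_next, k=4):
    """One speculative decoding step (greedy toy version).

    draft_next/target_next: functions mapping a token sequence to the
    next token. The draft proposes k tokens; the target verifies them,
    and we keep the longest agreeing prefix plus one target token.
    """
    # 1. The draft model proposes k tokens autoregressively (cheap).
    proposed = []
    ctx = list(prefix)
    for _ in range(k):
        token = draft_next(ctx)
        proposed.append(token)
        ctx.append(token)

    # 2. The target model checks each position (batched into one pass
    #    in practice); accept until the first disagreement.
    accepted = []
    ctx = list(prefix)
    for token in proposed:
        want = target_next(ctx)
        if want != token:
            accepted.append(want)  # resample: take the target's token
            break
        accepted.append(token)
        ctx.append(token)
    else:
        # All k draft tokens accepted: the target's pass yields a bonus token.
        accepted.append(target_next(ctx))
    return accepted

# Toy "models": the target continues the alphabet; the draft agrees
# until "c", where it guesses wrong.
target = lambda ctx: chr(ord(ctx[-1]) + 1)
draft = lambda ctx: "x" if ctx[-1] == "c" else chr(ord(ctx[-1]) + 1)

print(speculative_step(["a"], draft, target, k=4))  # → ['b', 'c', 'd']
```

One verification pass of the target model here yields three tokens instead of one, which is the source of the speedup; the acceptance rate determines how often that happens.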
## Quickstart You can use speculative decoding with MAX to accelerate model inference by using a smaller draft model to predict tokens that are verified by the main model. Serve your model with MAX and specify the draft model path using the `--draft-model-path` flag: ```sh max serve --model HuggingFaceTB/SmolLM2-360M-Instruct \ --draft-model-path HuggingFaceTB/SmolLM2-135M-Instruct \ --device-memory-utilization=0.6 \ --max-num-steps=5 \ --no-enable-chunked-prefill ``` The endpoint is ready when you see the URI printed in your terminal: ```output Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` Once the model is served, you can make requests to the API endpoints. To interact with MAX's OpenAI-compatible endpoints, install the OpenAI Python API: Then create a new Python file and import the `openai` package: ```python from openai import OpenAI client = OpenAI( base_url="http://localhost:8000/v1", # Your MAX endpoint api_key="EMPTY", # API key can be any string when using MAX locally ) # Make a chat completion request response = client.chat.completions.create( model="HuggingFaceTB/SmolLM2-360M-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What are the benefits of speculative decoding?"}, ], max_tokens=500, ) # Print the response print(response.choices[0].message.content) ``` In a new terminal, make a chat completion request using curl: ```sh curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "HuggingFaceTB/SmolLM2-360M-Instruct", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What are the benefits of speculative decoding?"} ], "max_tokens": 500 }' ``` You can also use the `generate` command to generate text: ```sh max generate --model HuggingFaceTB/SmolLM2-360M-Instruct \ --draft-model-path HuggingFaceTB/SmolLM-135M \ --max-length=200 \ --prompt="What are the benefits of 
speculative decoding?" \ --device-memory-utilization=0.6 \ --devices=gpu \ --no-enable-chunked-prefill ``` ## Next steps Now that you know the basics of speculative decoding, you can get started with MAX on GPUs. export const docs = [ '../../serve/prefix-caching.mdx', '../../deploy/benchmark.mdx', '../../deploy/local-to-cloud.mdx', ]; --- ## Structured output import MDXListing from '@site/src/components/Listing/MDXListing'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; MAX supports the generation of structured output using [llguidance](https://github.com/guidance-ai/llguidance) as a backend. Structured output, also sometimes referred to as constrained decoding, allows users to enforce specific output formats, ensuring structured and predictable responses from a model. :::note Structured output is compatible with GPU deployments and MAX models only. Support for PyTorch models and CPU deployments is in progress. ::: ## When to use structured output If you want to structure a model's output when it responds to a user, then you should use a structured output `response_format`. If you are connecting a model to tools, functions, data, or other systems, then you should use [function calling](/max/serve/function-calling) instead of structured outputs. ## How structured output works To use structured output, include the `--enable-structured-output` flag when serving your model with the [`max` CLI](/max/cli). ```bash max serve \ --model "google/gemma-3-27b-it" \ --enable-structured-output ``` :::note The examples on this page use image input with a multimodal model. For environment setup and additional image input options (including local files), see [Image to text](/max/inference/image-to-text). ::: Both the [`/chat/completions`](/max/api/serve#operation/createChatCompletion) and [`/completions`](/max/api/serve#operation/createCompletion) API endpoints are compatible with structured output. 
You can define your structured output response format in two ways: 1. [JSON schema](#json-schema): Specify the schema directly in your request. 2. [Pydantic](#pydantic): Use Pydantic to define and validate your schema as a Python class. We recommend testing your structured output responses thoroughly as they are sensitive to the way the model was trained. ### JSON schema To specify structured output within your inference request, use the following format: :::note You can increase the accuracy of structured output responses by mentioning JSON output specifications in your system prompt. ::: ```python title="structured-image-description.py" from openai import OpenAI client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY") response = client.chat.completions.create( model="google/gemma-3-27b-it", messages=[ { "role": "system", "content": "You are an assistant that analyzes images and returns structured descriptions." }, { "role": "user", "content": [ { "type": "text", "text": "Analyze this image and describe what you see." 
}, { "type": "image_url", "image_url": { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" } } ] } ], max_tokens=300, response_format={ "type": "json_schema", "json_schema": { "name": "ImageAnalysis", "schema": { "type": "object", "properties": { "description": {"type": "string"}, "subjects": { "type": "array", "items": {"type": "string"} }, "colors": { "type": "array", "items": {"type": "string"} }, "setting": {"type": "string"}, "mood": {"type": "string"} }, "required": ["description", "subjects", "colors", "setting", "mood"], "additionalProperties": False } } } ) print(response.choices[0].message.content) ``` ```bash curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-3-27b-it", "messages": [ { "role": "system", "content": "You are an assistant that analyzes images and returns structured descriptions." }, { "role": "user", "content": [ { "type": "text", "text": "Analyze this image and describe what you see." 
}, { "type": "image_url", "image_url": { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" } } ] } ], "max_tokens": 300, "response_format": { "type": "json_schema", "json_schema": { "name": "ImageAnalysis", "schema": { "type": "object", "properties": { "description": { "type": "string" }, "subjects": { "type": "array", "items": { "type": "string" } }, "colors": { "type": "array", "items": { "type": "string" } }, "setting": { "type": "string" }, "mood": { "type": "string" } }, "required": ["description", "subjects", "colors", "setting", "mood"], "additionalProperties": false } } } }' ``` Instead of a typical text response from the model, the `response_format` schema defined above results in a JSON-formatted structured output such as the following: ```output { "description": "A full-body shot of Peter Rabbit, the fictional character, standing on a dirt path. He is dressed in a blue jacket with brass buttons over a white shirt and a small yellow tie. He also wears brown pants and appears to be holding a small basket. The background consists of a rustic stone house with a thatched roof, a winding dirt road, green fields, and rolling hills under a bright sky. Wildflowers in shades of purple and white line the path in the foreground.", "subjects": [ "rabbit", "house", "path", "fields", "hills", "flowers", "basket" ], "colors": [ "blue", "brown", "green", "white", "yellow", "purple" ], "setting": "Rural countryside", "mood": "Whimsical, charming, idyllic" } ``` ### Pydantic For production Python code, you can define your structured output using [Pydantic](https://docs.pydantic.dev/latest/install/). This gives you type-safe attribute access and automatic validation instead of manually parsing JSON strings. 
Here's an example using a Pydantic [`BaseModel`](https://docs.pydantic.dev/latest/api/base_model/) to analyze an image and return a validated response:

```python title="structured-image-analysis.py"
from pydantic import BaseModel
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

class ImageAnalysis(BaseModel):
    description: str
    subjects: list[str]
    colors: list[str]
    setting: str
    mood: str

completion = client.chat.completions.parse(
    model="google/gemma-3-27b-it",
    messages=[
        {
            "role": "system",
            "content": "You are an assistant that analyzes images and returns structured descriptions."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Analyze this image and describe what you see."
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
                    }
                }
            ]
        }
    ],
    max_tokens=300,
    response_format=ImageAnalysis,
)

analysis = completion.choices[0].message.parsed
print(analysis)
```

### Supported models

All text generation models support structured output with MAX. As new models are added, they will also be compatible with structured output. This functionality is implemented at the pipeline level, ensuring consistency across different models.

However, structured output currently doesn't support PyTorch models or CPU deployments—only [MAX models](/max/model-formats#max-graph) deployed on GPUs.

## Next steps

Next, try processing local image files, running batch inference offline, or deploying to the cloud.

---

## Mojo changelog

This is a list of changes to the Mojo language, standard library, and tools.

To check your current version, run `mojo --version`. To install or update Mojo, see the [install guide](/mojo/manual/install).
## Nightly: v0.26.2

This version is still a work in progress.

### Language enhancements

* `@register_passable("trivial")` is now deprecated; conform to the `TrivialRegisterPassable` trait instead. The decorator will be removed after the next release.

* `@register_passable` is now deprecated; conform to the `RegisterPassable` trait instead. The decorator will be removed after the next release.

* Mojo now supports more flexible default arguments and parameters, whose default values may mismatch the declared type when that type is parametric. This allows parameters to be inferred from the default value when it is used, for example:

```mojo
fn take_string_slice[o: ImmOrigin](str: StringSlice[o] = ""): ...

fn use_it():
    take_string_slice()  # Ok, defaults to empty string, inferring "o".
    # Explicit calls also work of course.
    take_string_slice(StaticString("hello"))

# Default value is checked for validity at the call site.
fn defaultArgumentBadType2[T: AnyType](a: T = 1.0):
    pass

fn callDefaultArgumentBadType2():
    # Ok!
    defaultArgumentBadType2[Float64]()
    # error: value passed to 'a' cannot be converted from 'FloatLiteral[1]' to 'Int'
    defaultArgumentBadType2[Int]()
```

### Language changes

* Slice literals in subscripts have changed to be more similar to collection literals. They now pass an empty tuple as a required `__slice_literal__` keyword argument to disambiguate slices. If you have defined your own range types, please add a `__slice_literal__: () = ()` argument to their constructors.

* `trait` declarations no longer automatically inherit from `ImplicitlyDestructible`. `struct` declarations are not changed, and continue to inherit from `ImplicitlyDestructible`. Previously, the `@explicit_destroy` annotation was required to opt out of `ImplicitlyDestructible` conformance. Now, if a trait's usage depends on implicit destructibility, it must opt in explicitly:

```mojo
# Before
trait Foo: ...

# After
trait Foo(ImplicitlyDestructible): ...
```

Conversely, if a trait wants to support non-implicitly-destructible types, it no longer needs to be annotated with `@explicit_destroy`:

```mojo
# Before
@explicit_destroy
trait Foo: ...

# After
trait Foo: ...
```

Having `struct` continue to inherit from `ImplicitlyDestructible` while `trait` does not is intended to balance usability and familiarity in the common case with the need to foster broad Mojo ecosystem support for explicitly destroyed types. It's not a problem if the majority of `struct` types are `ImplicitlyDestructible` in practice. However, if many ecosystem libraries were written with unnecessary `ImplicitlyDestructible` bounds, that would hamper the usability of any individual `struct` type that opts in to being explicitly destroyed. Libraries with generic algorithms and types should be written to accommodate linear types. Making `ImplicitlyDestructible` opt-in for traits encourages a default stance of support, with specific types and functions only opting in to the narrower `ImplicitlyDestructible` requirement if they truly need it. The majority of generic algorithms that take their inputs by reference should not be affected.

* The unstable `__comptime_assert` syntax is now finalized as `comptime assert`. A deprecation warning is emitted with a fixit for the old syntax.

### Library changes

* The `builtin.math` module has been merged into `math`. The traits `Absable`, `DivModable`, `Powable`, `Roundable` and the functions `abs()`, `divmod()`, `max()`, `min()`, `pow()`, and `round()` are now part of the `math` module and continue to be available in the prelude. Code that explicitly imported from `builtin.math` should update to import from `math` instead.

* The `ffi` module is now a top-level module in the standard library, rather than being nested under `sys`. This improves discoverability of FFI functionality. Update your imports from `from sys.ffi import ...` to `from ffi import ...`.
* The `itertools` module now includes three new iterator combinators:
  * `cycle(iterable)`: Creates an iterator that cycles through elements indefinitely
  * `take_while[predicate](iterable)`: Yields elements while the predicate returns True
  * `drop_while[predicate](iterable)`: Drops elements while the predicate returns True, then yields the rest

* Math functions in `std.math` (`exp`, `exp2`, `log2`, `erf`, `tanh`, `sin`, `cos`, `tan`, `acos`, `asin`, `atan`, `atan2`, `acosh`, `asinh`, `atanh`, `cosh`, `sinh`, `expm1`, `log10`, `log1p`, `logb`, `cbrt`, `erfc`, `j0`, `j1`, `y0`, `y1`) now use `where dtype.is_floating_point()` clauses on their signatures instead of `__comptime_assert` checks in their bodies. This provides better compile-time error messages at the call site. Callers using these functions with generic `dtype` parameters may need to add evidence (either a `where` clause or a `__comptime_assert`) proving that their type is floating point.

* Many kernels in `nn` have been migrated to use `TileTensor`. We will have more documentation on `TileTensor` and its uses over the coming weeks.

* `InlineArray` now requires explicitly using literals for construction. For example:

```mojo
var a: InlineArray[UInt8, 4] = [1, 2, 3, 4]  # instead of InlineArray[UInt8, 4](1, 2, 3, 4)
```

* The following types now conform to `Writable` and have custom implementations of `write_to` and `write_repr_to`:
  * `Tuple`
  * `Variant`

* The `testing` module now provides `assert_equal` and `assert_not_equal` overloads for `Tuple`, enabling direct tuple-to-tuple comparisons in tests instead of element-by-element assertions. Element types must conform to `Equatable & Writable`.

* The `__reversed__()` method on `String`, `StringSlice`, and `StringLiteral` has been deprecated in favor of the new `codepoints_reversed()` method.
The new method name makes it explicit that iteration is over Unicode codepoints in reverse order, maintaining consistency with the existing `codepoints()` and `codepoint_slices()` methods. The deprecated `__reversed__()` methods will continue to work but will emit deprecation warnings.

* The `StringSlice` constructor from `String` now propagates mutability. If you have a mutable reference to a `String`, `StringSlice(str)` returns a mutable `StringSlice`. The `String.as_string_slice()` method is now deprecated in favor of the `StringSlice(str)` constructor, and `String.as_string_slice_mut()` has been removed.

* `String.ljust`, `String.rjust`, and `String.center` have been renamed to `String.ascii_ljust`, `String.ascii_rjust`, and `String.ascii_center`. Likewise for their equivalents on `StringSlice` and `StaticString`.

* `String.resize` will now panic if the new length would truncate a codepoint. Previously it would result in a string with invalid UTF-8.

* `String.resize` will now panic if `fill_byte` is >= 128. Previously it would create invalid UTF-8.

* Subscripting into `String` and `StringSlice` will now panic if the index falls in the middle of a UTF-8 encoded codepoint. Previously they would return invalid UTF-8. This panic is unconditional. Use `.as_bytes()[...]` if you really want the previous behavior.

* `StringSlice[byte=]` subscripting now returns a `StringSlice` instead of a `String`. This is consistent with range-based subscripting.

* Subscripting `String` and `StringSlice` by byte position will now return an entire Unicode codepoint. Previously it would return a single byte, and produce invalid UTF-8 if the index fell on the starting byte of a multi-byte codepoint.

* The following types now correctly implement `write_repr_to`:
  * `List`, `Set`

* `assert_equal` and `assert_not_equal` now work with types implementing `Writable`.

* All traits and structs with the `@register_passable("trivial")` decorator now extend the `TrivialRegisterPassable` trait.
The decorator is removed from them. * `String`, `StringSlice`, and `StringLiteral`'s `.format()` method now require their arguments to be `Writable`. * Formatting compile-time format strings (`StringLiteral`s) no longer allocates memory! It uses `global_constant` to store what would be heap allocated parsed formatting data. * `Int.__truediv__` now performs truncating integer division, returning `Int` instead of the previously deprecated `Float64`. Use explicit `Float64` casts for floating-point division. ### Tooling changes * The Mojo compiler now accepts conjoined `-D` options in addition to the non-conjoined form as before. Now, both `-Dfoo` and `-D foo` are accepted. * `mojo build` now supports several `--print-*` options for discovering target configuration and supported architectures: * `--print-effective-target`: Shows the resolved target configuration after processing all command-line flags. * `--print-supported-targets`: Lists all available LLVM target architectures. * `--print-supported-cpus`: Lists valid CPU names for a given target triple (requires `--target-triple`). * `--print-supported-accelerators`: Lists all supported GPU and accelerator architectures (NVIDIA, AMD, Apple Metal). ### 🛠️ Fixed * [Issue #5845](https://github.com/modular/modular/issues/5845): Functions raising custom type with conversion fails when returning StringSlice * [Issue #5875](https://github.com/modular/modular/issues/5875): Storing `SIMD[DType.bool, N]` with width > 1 to a pointer and reading back element-wise now returns correct values. * `StringSlice.find`: Fixed integer overflow bug in SIMD string search that caused searches to fail when searching for strings longer than `simd_width_of[DType.bool]()` and haystacks larger than UInt16.MAX. 
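The new `itertools` combinators described in the library changes above mirror Python's `itertools`. As an illustrative cross-reference only (this sketch is Python, not Mojo), `cycle`, `takewhile`, and `dropwhile` behave as follows:

```python
# Python reference behavior for the combinators added to Mojo's itertools:
# cycle, take_while[predicate], and drop_while[predicate].
from itertools import cycle, dropwhile, islice, takewhile

data = [1, 2, 3, 10, 4]

# take_while: yield elements while the predicate holds, then stop.
print(list(takewhile(lambda x: x < 5, data)))  # [1, 2, 3]

# drop_while: skip elements while the predicate holds, yield the rest.
print(list(dropwhile(lambda x: x < 5, data)))  # [10, 4]

# cycle: repeat the sequence indefinitely (islice bounds it here).
print(list(islice(cycle([1, 2]), 5)))  # [1, 2, 1, 2, 1]
```

Note that `drop_while` does not re-apply the predicate after the first failing element: `4` is yielded even though it satisfies `x < 5`.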
## v0.26.1 (2026-01-29) ### ✨ Highlights {#26-1-highlights} * **Expanded reflection module.** The `reflection` module now provides extensive compile-time introspection: struct field enumeration, types, and byte offsets; source location tracking; and trait conformance checking on dynamically obtained types. These APIs enable advanced metaprogramming patterns like automatic serialization and debug formatting. See [Reflection and introspection](#26-1-reflection). * **Explicitly-destroyed types.** Mojo now has first-class support for explicitly-destroyed types (sometimes referred to as "linear types"). The `AnyType`, `Movable`, and `Copyable` traits no longer require a `__del__()` method; use `ImplicitlyDestructible` when you need implicit destruction. Explicitly-destroyed types let you encode powerful invariants requiring explicit resource handling. See [Language changes](#26-1-language-changes) and [Explicitly-destroyed types](#26-1-explicitly-destroyed-types). * **Typed errors.** Functions can now specify what type they raise instead of defaulting to `Error` (for example, `fn foo() raises CustomError -> Int`). Typed errors are highly efficient—they compile to an alternate return value with no stack unwinding—making them suitable for GPU and embedded targets. See [Language enhancements](#26-1-language-enhancements). * **Traits with default implementations.** The `Hashable`, `Writable`, and `Equatable` traits now provide default implementations that automatically derive behavior from struct fields using reflection. Simple structs can conform to these traits without writing any boilerplate—just ensure all fields conform to the same trait. See [Traits with default implementations](#26-1-trait-defaults). * **`UInt` type redesign.** The `UInt` struct has been replaced by a type alias to `Scalar[DType.uint]`, enabling more powerful generic programming over unsigned SIMD data types of machine word size. See [Other library changes](#26-1-other-library). 
* **String UTF-8 safety.** `String` now offers three explicit constructors for raw bytes: `from_utf8=` (validates and raises on error), `from_utf8_lossy=` (replaces invalid sequences with �), and `unsafe_from_utf8=` (no validation). This makes UTF-8 handling guarantees explicit at construction time. See [String and text](#26-1-string-and-text).

### Documentation {#26-1-documentation}

* The new [Mojo quickstart](/mojo/manual/quickstart) page provides a "5-minute" introduction to installing Mojo and exploring basic language features.

* The new [Reflection](/mojo/manual/reflection) page describes Mojo's new compile-time reflection capabilities, including example use cases.

* Added a section to the *Value destruction* page about how to implement and use [explicitly-destroyed types](/mojo/manual/lifecycle/death#explicitly-destroyed-types).

* The new [Jupyter Notebooks](/mojo/tools/notebooks) page provides step-by-step instructions for programming with Mojo in Google Colab and local JupyterLab environments, including GPU programming examples.

* The new [Materialization](/mojo/manual/metaprogramming/materialization) page describes how to make compile-time values available at run time.

### Language enhancements {#26-1-language-enhancements}

* Mojo now supports raising "typed errors", where a function can specify what type it raises instead of defaulting to the [`Error`](/mojo/std/builtin/error/Error) type. This is done by specifying the type after the `raises` keyword, for example, `fn foo() raises CustomError -> Int`.

  Raised errors in Mojo are very efficient: they work as an alternate return value. For example, a function like `fn () raises Int -> Float32:` compiles into code that returns either an `Int` or a `Float32`, using an implicit boolean result to determine which one is valid; there is no expensive stack unwinding or slow dynamic logic involved. This means that thrown errors work fine on GPUs and other embedded targets.
The 'caught' type in a `try` block is automatically inferred to be the first thrown type inside of the `try` body, for example: ```mojo try: print(foo()) except err: # "err" is typed as CustomError print(err) ``` Typed throws "just work" with generics, allowing the definition of higher order functions like: ```mojo fn parametric_raise_example[ErrorType: AnyType](fp: fn () raises ErrorType) raises ErrorType: # ... presumably some iteration or other exciting stuff happening here. fp() ``` This dovetails with other support to allow contextually generic thrown types, for example: ```mojo fn call_parametric_raise_example[GenTy: AnyType](func_ptr: fn () raises GenTy): fn raise_int() raises Int: pass try: parametric_raise_example(raise_int) except err_int: # Typed as Int ref x: Int = err_int fn raise_string() raises String: pass try: parametric_raise_example(raise_string) except err_string: # Typed as String ref s: String = err_string try: parametric_raise_example(func_ptr) except err_gen: # Typed as GenTy ref s: GenTy = err_gen # Non-raising functions infer an error type of `Never`, allowing these # functions to propagate non-raisability across generic higher-order # functions conveniently. fn doesnt_raise(): pass # Note this isn't in a try block. Mojo knows 'parametric_raise_example' # doesn't raise because the 'doesnt_raise' function doesn't. parametric_raise_example(doesnt_raise) ``` As part of this, context managers have been extended to support typed throws, and can also infer an error type if they need to handle it, for example: ```mojo struct MyGenericExitCtxtMgr: # Called on entry to the with block. fn __enter__(self): ... # Called on exit from the with block when no error is thrown. fn __exit__(self): ... # Called on exit from the with block if an error is thrown. fn __exit__[ErrType: AnyType](self, err: ErrType) -> Bool: ... ``` * Mojo now supports a [`Never`](/mojo/std/builtin/type_aliases/#never) type, which can never be instantiated. 
This type can be used for functions (like [`abort()`](/mojo/std/os/os/abort/)) which do not have a normal return value, and for functions that are guaranteed to raise without returning a normal value. Functions that are declared to raise `Never` (and generic functions instantiated with `Never` as their error type) compile into the same ABI as functions that don't `raise`. * Mojo now allows the use of a `comptime(x)` expression to force a subexpression to be evaluated at compile time. This can help make working with certain types more elegant when you can't (or don't want to) materialize them into a runtime value. For example, if you just want the size from a compile time layout: ```mojo fn takes_layout[a: Layout](): # materializes entire layout value just to get the size out of it print(a.size()) # Could already work around this with a comptime declaration, verbosely. comptime a_size = a.size() print(a_size) # Can now tell Mojo to evaluate the expression at comptime. print(comptime(a.size())) ``` * Mojo now differentiates between `...` and `pass` in trait methods. The use of `...` continues to denote no default implementation—`pass` now specifies a default do-nothing implementation. For example: ```mojo trait T: # No default implementation fn foo(self): ... # Default implementation that does nothing fn bar(self) : pass ``` The compiler will error on the use of `pass` to define a default implementation for a trait method with results: ```mojo trait T: foo.mojo:2:26: error: trait method has results but default implementation returns no value; did you mean '...'? fn foo(self) -> Int: pass ^ trait.mojo:2:8: note: in 'foo', declared here fn foo(self) -> Int: pass ^ ``` * Mojo now allows implicit conversions between function types from a non-raising function to a raising function. It also allows implicit conversions between function types whose result types are implicitly convertible. ```mojo fn takes_raising_float(a: fn () raises -> Float32): ... 
fn returns_int() -> Int: ... fn example(): # This is now ok. takes_raising_float(returns_int) ``` * Mojo now allows functions that return references to convert to functions that return values if the type is implicitly copyable or implicitly convertible to the destination type: ```mojo fn fn_returns_ref(x: SomeType) -> ref [x.field] Int: ... fn examples(): # OK, Int result from fn_returns_ref can be implicitly copied. var f1 : fn (x: SomeType) -> Int = fn_returns_ref # OK, Int result from fn_returns_ref implicitly converts to Float64. var f2 : fn (x: SomeType) -> Float64 = fn_returns_ref ``` * Context managers (used in `with` statements) can now define consuming exit methods—that is, `fn __exit__(var self)`—which can be useful for explicitly-destroyed context managers. This also works with `deinit`. * The `deinit` argument convention can now be applied to any argument of a struct method, but the argument type still must be of the enclosing struct type. * Mojo now supports the `...` expression. It is a logically empty value of [`EllipsisType`](/mojo/std/builtin/type_aliases/#ellipsistype). It can be used in overloaded functions, for example, `getitem()` calls: ```mojo struct YourType: fn __getitem__(self, idx: Int) -> Int: # ... behavior when passed x[i] fn __getitem__(self, idx: EllipsisType) -> Int: # ... behavior when passed x[...] ``` ### Language changes {#26-1-language-changes} * The Mojo language basic trait hierarchy has changed to expand first-class support for explicitly-destroyed types (sometimes referred to as "linear types"). The [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Movable`](/mojo/std/builtin/value/Movable), and [`Copyable`](/mojo/std/builtin/value/Copyable) traits no longer require that a type provide a `__del__()` method that may be called by the compiler implicitly whenever an owned value is unused. 
Instead, the [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) trait should be used in generic code to require that a type is implicitly destructible. Explicitly-destroyed types enable Mojo programs to encode powerful invariants in the type system, by modeling a type in such a way that a user is required to take an action "in the future", rather than simply implicitly dropping an instance "on the floor". Code using `T: AnyType` can change to use `T: ImplicitlyDestructible` to preserve its pre-existing behavior following this change. Relatedly, the [`UnknownDestructibility`](/mojo/std/builtin/anytype/#unknowndestructibility) trait is now no longer required, as it is equivalent to the new `AnyType` behavior. * Mojo no longer supports overloading functions on parameters alone: it will not try to disambiguate between `fn foo[a: Int8]():` and `fn foo[a: Int32]():` for example. Mojo never fully implemented the previous support in a reliable way, and removing this simplifies the language. It still supports overloading on function arguments of course. * The `__next_ref__()` method in for-each loops has been removed. Now you can implement the `__next__()` method of your iterator to return either a value or a reference. When directly using the collection, Mojo will use the ref-returning variant, but will allow it to conform to [`Iterator`](/mojo/std/iter/Iterator) for use with generic algorithms (which use a copied value). * The `origin_of(x)` operator now returns a value of type `Origin` instead of an internal MLIR type, and aliases like `ImmutOrigin` are now `Origin` type as well. * The `Origin.cast_from[x]` syntax has been replaced with a safe implicit conversion from any origin to an immutable origin (`ImmutOrigin(x)`) and an explicit unsafe conversion (`unsafe_origin_mutcast[origin, mut=m]`). * The `*_` and `**_` syntax for explicitly unpacked parameters has been replaced with a simplified `...` syntax. 
Instead of `T[4, 5, *_, **_]` you can now use `T[4, 5, ...]`. The `...` delays binding of both keyword and non-keyword parameters.

* The compiler will now warn on the use of the `alias` keyword and suggest `comptime` instead.

* The compiler will now warn on unqualified access to struct parameters, for example:

```mojo
@fieldwise_init
struct MyStuff[my_param: Int]:
    fn give_me_stuff(self) -> Int:
        # Warning: unqualified access to struct parameter 'my_param'; use 'Self.my_param' instead
        return my_param
```

* The Mojo compiler generates clearer error messages when diagnosing invalid calls: it mentions the argument name instead of "argument #4".

### Library changes {#26-1-library-changes}

#### Reflection and introspection {#26-1-reflection}

* The [`reflection`](/mojo/std/reflection/) module has been significantly expanded with new compile-time introspection capabilities. The module has moved from `compile.reflection` to a top-level `reflection` module (update imports from `from compile.reflection import ...` to `from reflection import ...`). Internally, the module is now organized into `type_info` and `struct_fields` submodules, though the public API via `from reflection import ...` remains unchanged.
**Struct field introspection** - New APIs for compile-time struct analysis: * `struct_field_count[T]()` - Returns the number of fields in a struct * `struct_field_names[T]()` - Returns field names as [`InlineArray`](/mojo/std/collections/inline_array/InlineArray) `[StaticString, N]` * `struct_field_types[T]()` - Returns a variadic of all field types * `struct_field_index_by_name[T, name]()` - Returns field index by name * `struct_field_type_by_name[T, name]()` - Returns field type wrapped in `ReflectedType` These APIs work with both concrete types and generic type parameters: ```mojo fn print_fields[T: AnyType](): comptime names = struct_field_names[T]() comptime types = struct_field_types[T]() @parameter for i in range(struct_field_count[T]()): print(names[i], get_type_name[types[i]]()) ``` **Field access by index** - Two new magic functions enable index-based field access without copying: * `__struct_field_type_at_index(T, idx)` - Returns field type at index * `__struct_field_ref(idx, ref s)` - Returns a reference to the field Unlike `kgen.struct.extract` which copies, `__struct_field_ref()` returns a reference, enabling reflection utilities to work with non-copyable types: ```mojo fn print_all_fields[T: AnyType](ref s: T): comptime names = struct_field_names[T]() @parameter for i in range(struct_field_count[T]()): print(names[i], "=", __struct_field_ref(i, s)) ``` **Field byte offsets** - `offset_of[T, name=field_name]()` returns the byte offset of a named field within a struct, enabling no-copy serialization and other low-level memory operations. The offset is computed at compile time using the target's data layout, correctly accounting for alignment padding. This is analogous to C/C++'s `offsetof` and Rust's `offset_of!` macro. An `offset_of[T, index=i]()` overload is also available to look up by field index. 
```mojo from reflection import offset_of struct Point: var x: Int # offset 0 var y: Float64 # offset 8 (aligned) fn main(): comptime x_off = offset_of[Point, name="x"]() # 0 comptime y_off = offset_of[Point, name="y"]() # 8 ``` **Type introspection utilities:** * `is_struct_type[T]()` - Returns `True` if `T` is a Mojo struct type. Useful for guarding reflection code that uses struct-specific APIs to avoid compiler errors on non-struct types (for example, MLIR primitive types). Use `@parameter if` since these APIs are evaluated at compile time. ([Issue #5734](https://github.com/modular/modular/issues/5734)) * `get_base_type_name[T]()` - Returns the unqualified name of a parameterized type's base type. For example, `get_base_type_name[List[Int]]()` returns `"List"`. Useful for identifying collection types regardless of element types. ([Issue #5735](https://github.com/modular/modular/issues/5735)) **Source location introspection:** * [`SourceLocation`](/mojo/std/reflection/location/SourceLocation) - A struct holding filename, line, and column information * [`source_location()`](/mojo/std/reflection/location/source_location) - Returns the location where it's called * `call_location()` - Returns the location where the caller was invoked (requires the caller to be `@always_inline`) These were previously internal APIs (`_SourceLocation`, `__source_location`, `__call_location`) in `builtin._location`. The old module has been removed. ```mojo from reflection import source_location, call_location, SourceLocation fn main(): var loc = source_location() print(loc) # main.mojo:5:15 @always_inline fn log_here(): var caller_loc = call_location() print("Called from:", caller_loc) ``` Note: These APIs do not work correctly in parameter expressions (comptime contexts return placeholder values). 
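As a cross-language reference only, the information `source_location()` and `call_location()` expose can be approximated at run time in Python with the `inspect` module. Mojo resolves these at compile time; this Python sketch (with invented helper names) merely illustrates the caller/callee distinction:

```python
# Run-time approximation of Mojo's source_location()/call_location()
# using Python's inspect module. Illustrative only, not the Mojo API.
import inspect

def call_location() -> str:
    # Location from which the *caller* of this helper was invoked,
    # analogous to Mojo's call_location() in an @always_inline function.
    frame = inspect.stack()[2]
    return f"{frame.filename}:{frame.lineno}"

def log_here() -> str:
    caller_loc = call_location()
    return f"Called from: {caller_loc}"

print(log_here())
```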
**Trait conformance checking** - The `conforms_to()` builtin now accepts types from reflection APIs like `struct_field_types[T]()`, enabling conformance checks on dynamically obtained field types: ```mojo @parameter for i in range(struct_field_count[MyStruct]()): comptime field_type = struct_field_types[MyStruct]()[i] @parameter if conforms_to(field_type, Copyable): print("Field", i, "is Copyable") ``` #### Traits with default implementations {#26-1-trait-defaults} * Several traits now have default implementations that use reflection to automatically derive behavior from struct fields. This means simple structs can conform to these traits without implementing any methods - all fields just need to conform to the same trait: **[`Hashable`](/mojo/std/hashlib/hash/Hashable)** - Default `__hash__()` hashes all fields: ```mojo @fieldwise_init struct Point(Hashable): var x: Float64 var y: Float64 hash(Point(1.5, 2.7)) # Works automatically ``` **[`Writable`](/mojo/std/format/Writable)** - Default `write_to()` formats all fields: ```mojo @fieldwise_init struct Point(Writable): var x: Float64 var y: Float64 print(Point(1.5, 2.7)) # Point(x=1.5, y=2.7) ``` **[`Equatable`](/mojo/std/builtin/comparable/Equatable)** - Default `__eq__()` compares all fields: ```mojo @fieldwise_init struct Point(Equatable): var x: Int var y: Int print(Point(1, 2) == Point(1, 2)) # True ``` Note: The default `Equatable` performs memberwise equality, which may not be appropriate for types with floating-point fields (due to NaN semantics). Override any of these methods for custom behavior. #### Collections and iterators {#26-1-collections} * Removed [`List`](/mojo/std/collections/list/List) variadic initializer. 
* Statements like: ```mojo var x = List[Int32](1, 2, 3) ``` can be updated to: ```mojo var x: List[Int32] = [1, 2, 3] ``` * Expressions like: ```mojo var x = foo(List[Float32](1, 2, 3)) ``` can be updated to move the explicit type "hint" around the first element: ```mojo var x = foo([Float32(1), 2, 3]) ``` * Expressions like: ```mojo var data = Span(List[Byte](1, 2, 3)) ``` can be updated to move the explicit element type to the `Span`: ```mojo var data = Span[Byte]([1, 2, 3]) ``` * `List` slicing without a stride now returns a [`Span`](/mojo/std/memory/span/Span), instead of a `List` and no longer allocates memory. * [`InlineArray`](/mojo/std/collections/inline_array/InlineArray) no longer conforms to [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable/). Users must explicitly copy arrays or take references. * [`IndexList`](/mojo/std/utils/index_/IndexList/) is no longer implicitly constructible from `Int`. Previously, the fill constructor (which broadcasts a single `Int` to all elements) was marked `@implicit`, allowing code like `var x: IndexList[3] = 5` which would create `(5, 5, 5)`. This implicit conversion has been removed to improve type safety. Use explicit construction instead: `IndexList[3](5)`. * [`Dict`](/mojo/std/collections/dict/Dict) now raises a custom [`DictKeyError`](/mojo/std/collections/dict/DictKeyError) type on failure, making lookup failures more efficient to handle. ```mojo var d = Dict[String, Int]() var key = "missing_key" try: _ = d[key] except e: print(e) # Prints: DictKeyError ``` * New [`ContiguousSlice`](/mojo/std/builtin/builtin_slice/ContiguousSlice) and [`StridedSlice`](/mojo/std/builtin/builtin_slice/StridedSlice) types were added to the `builtin_slice` module to support specialization for slicing without strides. * `Span` now conforms to [`Iterable`](/mojo/std/iter/Iterable). 
* [`any()`](/mojo/std/builtin/bool/any) and [`all()`](/mojo/std/builtin/bool/all) now work over `Iterable`s, which means they can act over the result of [`map()`](/mojo/std/iter/map).

* [`Tuple`](/mojo/std/builtin/tuple/Tuple)s have been improved:
  * Tuples can now be concatenated with `Tuple.concat()`.
  * Tuples can now be reversed with `Tuple.reverse()`.

* The [`peekable()`](/mojo/std/iter/peekable) function has been added to [`iter`](/mojo/std/iter/). This allows users to peek at the next element of an iterator without advancing it.

#### String and text {#26-1-string-and-text}

* [`String`](/mojo/std/collections/string/string/String/) has had its UTF-8 guarantees strengthened.
  * It now has three separate constructors when converting raw bytes (`Span[Byte]`) to a `String`:
    * `String(from_utf8=...)`: Raises an error if the bytes are invalid UTF-8.
    * `String(from_utf8_lossy=...)`: Converts invalid UTF-8 byte sequences into the replacement character (`U+FFFD`, �) and does not raise an error.
    * `String(unsafe_from_utf8=...)`: Unsafely assumes the input bytes are valid UTF-8 without any checks.
  * `append_byte()` has been deprecated and replaced with `append()`.

* The `strip()`, `lstrip()`, and `rstrip()` methods of `String` and [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice) now support stripping multi-byte Unicode codepoints. Additionally, `lstrip()` and `strip()` will no longer produce invalid UTF-8 if the "chars" string contains characters sharing their first byte with a character in the string to be stripped.

* [`StringLiteral.format()`](/mojo/std/builtin/string_literal/StringLiteral/#format) now emits a compile-time constraint error if the format string is invalid (instead of a runtime error).

```mojo
"Hello, {!invalid}".format("world")
# note: constraint failed: Conversion flag "invalid" not recognized.
```

* `StringSlice.char_length()` has been renamed to `count_codepoints()`. The same function was added to `String` and `StringLiteral`.
* Added a [`CStringSlice`](/mojo/std/sys/ffi/cstring/CStringSlice/) as a type-safe way to interact with NUL-terminated C-style strings (`const char*`). * Removed the `String.join(*Writable)` overload that takes a variadic sequence of arguments, as it could be ambiguous with the remaining `String.join(Span[Writable])` overload. * Removed the `Int.__init__(self, value: StringSlice, base: UInt)` constructor. Users should call [`atol()`](/mojo/std/collections/string/string/atol/) directly. #### Pointer and memory {#26-1-pointer-and-memory} * [`OwnedDLHandle.get_symbol()`](/mojo/std/sys/ffi/OwnedDLHandle/#get_symbol) now returns [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)`[T, MutAnyOrigin]` instead of `UnsafePointer[T, ImmutAnyOrigin]`. The vast majority of symbols loaded from shared libraries are meant to be used mutably, and it's safer to go from mutable → immutable (via `.as_immutable()`) than from immutable → mutable (via `.unsafe_mut_cast[True]()`). Users who need immutable pointers can now simply call `.as_immutable()` on the result. * The `LegacyUnsafePointer` type has been changed to take its mutability as a first inferred parameter without a default, rather than a later explicit parameter with a default value of `True`. We recommend moving off of this type as soon as possible, but to roughly emulate the prior behavior, try out: ```mojo from memory import LegacyUnsafePointer comptime UnsafePointer = LegacyUnsafePointer[mut=True, *_, **_] ``` * External origins are now expressed using type-level `{Mut,Immut,}ExternalOrigin` aliases instead of being spelled like `Origin[True].external`, improving consistency with other origin types. * `UnsafePointer` can now be initialized from a raw memory address using the `unsafe_from_address` initializer. * [`alloc()`](/mojo/std/memory/unsafe_pointer/alloc/) now has a [`debug_assert()`](/mojo/std/builtin/debug_assert/debug_assert) ensuring the count is non-negative.
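With the `Int` constructor removed, string-to-integer parsing goes through `atol()` instead. A hedged sketch; the optional `base` argument is assumed to mirror the removed constructor's `base` parameter:

```mojo
fn parse_ints() raises:
    # atol() raises if the string is not a valid integer.
    var n = atol("42")      # replaces Int("42")
    var h = atol("ff", 16)  # replaces Int("ff", base=16)
    print(n, h)
```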
#### Explicitly-destroyed types {#26-1-explicitly-destroyed-types} * Basic support for explicitly-destroyed types in the standard library is now available. Explicitly-destroyed types—sometimes referred to as "linear types"—are types that don't define a `__del__()` method that the compiler can call automatically to destroy an instance. Instead, an explicitly-destroyed type must provide a named method taking `deinit self` that the programmer is required to call explicitly whenever an owned instance is no longer used. The updated `AnyType` trait can be used in parameters to denote generic code that supports object instances that can't be implicitly destroyed. * `Span`, `UnsafePointer`, [`Pointer`](/mojo/std/memory/pointer/Pointer), and [`OwnedPointer`](/mojo/std/memory/owned_pointer/OwnedPointer) can point to explicitly-destroyed types. * Added `UnsafePointer.destroy_pointee_with()`, for destroying explicitly-destroyed types in-place using a function pointer to the type's destructor. * `List`, `InlineArray`, [`Optional`](/mojo/std/collections/optional/Optional), [`Variant`](/mojo/std/utils/variant/Variant), [`VariadicListMem`](/mojo/std/builtin/variadics/VariadicListMem), and [`VariadicPack`](/mojo/std/builtin/variadics/VariadicPack) can now contain explicitly-destroyed types. * `Variant.take()` now takes `deinit self` instead of `mut self`. * Added `Variant.destroy_with()` for destroying an explicitly-destroyed type in-place by passing in the type's destructor function. * The `*args` language syntax for arguments now supports explicitly-destroyed types. * `Iterator.Element` no longer requires `ImplicitlyDestructible`. * [`UnsafeMaybeUninitialized`](/mojo/std/memory/maybe_uninitialized/UnsafeMaybeUninitialized) can now contain explicitly-destroyed types.
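As a concrete illustration of the `deinit self` convention described above, here is a minimal sketch of an explicitly-destroyed type. The struct and its named destructor are hypothetical, and the transfer-operator call syntax for consuming methods is an assumption:

```mojo
struct Ticket:
    var id: Int

    fn __init__(out self, id: Int):
        self.id = id

    # No __del__() is defined, so the compiler can't destroy a Ticket
    # implicitly. The owner must call this named method explicitly.
    fn redeem(deinit self):
        print("redeemed ticket", self.id)

fn main():
    var t = Ticket(7)
    # Transfer ownership into the named destructor. Letting `t` go out
    # of scope without consuming it would be a compile-time error.
    t^.redeem()
```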
#### Type system and traits {#26-1-type-system} * Using a new 'unconditional conformances' technique leveraging `conforms_to()` and [`trait_downcast()`](/mojo/std/builtin/rebind/trait_downcast) to perform "late" element type conformance checking, some standard library types are now able to conform to traits that they could not previously: * `List`, `Dict`, [`Set`](/mojo/std/collections/set/Set), [`Deque`](/mojo/std/collections/deque/Deque), `InlineArray`, and [`LinkedList`](/mojo/std/collections/linked_list/LinkedList) now conform to `Writable`, [`Stringable`](/mojo/std/builtin/str/Stringable), and [`Representable`](/mojo/std/builtin/repr/Representable). `List` also conforms to `Equatable`. * `Pointer`, [`ArcPointer`](/mojo/std/memory/arc_pointer/ArcPointer), and `OwnedPointer` now conform to `Writable`. * `Iterator`, `Tuple`, `Variant`, and `Optional` no longer require their element types to be `Copyable`. * The `Iterator` trait and for-each loop have removed the `__has_next__()` method and now use a `__next__()` method that `raises` [`StopIteration`](/mojo/std/iter/StopIteration). This follows Python precedent better, is more convenient to implement, and can be a minor performance win in some cases. * The `ImplicitlyBoolable` trait has been removed. This trait enabled types to implicitly convert to [`Bool`](/mojo/std/builtin/bool/Bool). This behavior was rarely used, and could lead to subtle bugs, for example mistakenly passing types like `Int` or `UnsafePointer` to an argument expecting a `Bool` would silently compile successfully. * The `Error` type no longer conforms to [`Boolable`](/mojo/std/builtin/bool/Boolable) or [`Defaultable`](/mojo/std/builtin/value/Defaultable). Errors must now be constructed with meaningful context, and optionality should be expressed through `Optional[Error]` rather than treating errors as boolean values. * The `Copyable` trait now refines the `Movable` trait. 
This means that structs and generic algorithms that already require `Copyable` don't need to also mention that they require `Movable`. * The `EqualityComparable` trait has been deprecated in favor of `Equatable`, which has identical functionality. * We have removed [`Identifiable`](/mojo/std/builtin/identifiable/Identifiable) from enum-like types (such as [`DType`](/mojo/std/builtin/dtype/DType) and [`AddressSpace`](/mojo/std/memory/pointer/AddressSpace/)). This reflects that `Identifiable` is intended for comparing memory addresses, which isn't meaningful for enum-like values. #### Python interoperability {#26-1-python-interop} * The `ConvertibleFromPython` trait and associated initializers now have a required keyword argument. Before: `Int(pyObj)`. After: `Int(py=pyObj)`. This avoids ambiguities in cases where multiple overloads could apply, or where implicit conversions to `PythonObject` could mask that a Python operation was happening. * [`PythonObject`](/mojo/std/python/python_object/PythonObject) now supports implicit conversion from `None`, allowing more natural Python-like code: ```mojo var obj: PythonObject = None # Now works without explicit PythonObject(None) fn returns_none() -> PythonObject: return None # Implicit conversion ``` #### File I/O and OS {#26-1-file-io} * Basic file I/O operations in the `io` module are now implemented natively in Mojo using direct `libc` system calls (`open()`, `close()`, `read()`, `write()`, `lseek()`). The [`FileHandle`](/mojo/std/io/file/FileHandle) type no longer depends on CompilerRT functions, providing better performance and transparency. Error handling now includes errno-based messages for improved diagnostics. * The `os.process` submodule has been added with utilities to spawn and wait on processes. These use `posix_spawn()` and do not go through the system shell. * Various OS wrapper functions now include the value of `errno` in the raised error message.
* The `os` module now exposes a [`link()`](/mojo/std/os/os/link) function, wrapping the unix `link(2)` system call. * The `os` module now exposes a [`symlink()`](/mojo/std/os/os/symlink) function, wrapping the unix `symlink(2)` syscall. #### GPU programming {#26-1-gpu} * [`Layout`](/mojo/kernels/layout/layout/Layout) no longer conforms to [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable/). The motivation for this change is that it was easy to accidentally materialize a `Layout` at runtime using methods such as `Layout.size()`. There are a few downstream changes users will have to adapt to with this change. Below are some common patterns which will help with this transition: ```mojo from layout import Layout fn foo(a: Layout) -> Layout: return a.copy() # `a` needs to be explicitly copied. fn bar(var a: Layout) -> Layout: return a^ # If a is taken by value, one can use the transfer operator. fn baz(): comptime a = Layout.row_major(4, 4) # `a` needs to be materialized because `foo` returns a type that # is not ImplicitlyCopyable (`Layout`). var b = foo(materialize[a]()) # `bar` moves `b` by value, but `foo` also takes `b` by reference, # so we need to explicitly copy. _ = bar(b.copy()) _ = foo(b) # Since we are no longer using `b`, it's fine to take it by value here. _ = bar(b^) # Since `Layout.size()` returns an `Int`, we can use a `comptime` # expression to compute the return value without materializing `a`. for i in range(comptime (a.size())): ... ``` * `DeviceContext.enqueue_function_checked()` and `DeviceStream.enqueue_function_checked()` have been renamed to `enqueue_function()`. Similarly, `DeviceContext.compile_function_checked()` has been renamed to `compile_function()`. 
* `DeviceContext.enqueue_function()` and `DeviceContext.enqueue_function_experimental()` now automatically infer `func_attribute` to `FuncAttribute.MAX_DYNAMIC_SHARED_SIZE_BYTES(shared_mem_bytes)` when `shared_mem_bytes` is specified but `func_attribute` is not, for NVIDIA GPUs with allocations > 48KB. This eliminates the need to specify the same shared memory size twice in many cases, reducing boilerplate and preventing mismatched values. On AMD GPUs or for allocations ≤ 48KB, explicit `func_attribute` values should be provided when needed. * The `inlined_assembly()` function is now publicly exported from the `sys` module, allowing users to embed raw assembly instructions directly into Mojo code. This provides fine-grained control over hardware operations using LLVM-style inline assembly syntax. Example: ```mojo from sys import inlined_assembly # Convert bfloat16 to float32 on NVIDIA GPU using PTX assembly. var result = inlined_assembly[ "cvt.f32.bf16 $0, $1;", Float32, constraints="=f,h", has_side_effect=False, ](my_bf16_as_int16) ``` #### Formatting and output {#26-1-formatting} * [`Writer`](/mojo/std/format/Writer) has been reworked to support only UTF-8 data instead of arbitrary `Byte` sequences. The `write_bytes()` method has been replaced with `write_string()`. * In line with these changes, `String`'s `write_bytes()` method has also been deprecated, and its initializer `__init__(out self, *, bytes: Span[Byte])` has had its keyword argument renamed to `unsafe_from_utf8`. This brings it more in line with the existing `StringSlice` constructors and explicitly states that construction from arbitrary bytes is inherently unsafe. * `Writer` and `Writable` have been moved into a new `format` module and out of `io`. These traits are not directly related to binary i/o, but are rather closely tied to type/value string formatting. 
* The `Writable` trait now supports debug formatting through an optional `write_repr_to()` method, called by [`repr()`](/mojo/std/builtin/repr/repr) and the `{!r}` format specifier. Additionally, `repr()` and string formatting methods (`.format()` on `String`, `StringSlice`, and `StringLiteral`) now accept `Writable` types, enabling efficient formatting without intermediate string allocations. To preserve existing behavior, types implementing both `Stringable & Representable` and `Writable` will continue using the `Stringable & Representable` methods; only types implementing `Writable` alone will use the new code paths. * [`Counter`](/mojo/std/collections/counter/Counter) now conforms to `Writable`, `Stringable`, and `Representable`. * [`black_box()`](/mojo/std/benchmark/compiler/black_box) has been added to the `benchmark` utilities as a way to prevent the compiler from aggressively optimizing out values. It is similar to [`keep()`](/mojo/std/benchmark/compiler/keep), but unlike `keep()`, it returns its argument. #### Other library changes {#26-1-other-library} * The [`UInt`](/mojo/std/builtin/simd/#uint) struct has been replaced by a new `UInt` type alias to `Scalar[DType.uint]`. This is a major change that enables more powerful generic programming by abstracting over unsigned SIMD data types of machine word size. This change will likely break code that relies on implicit conversions to/from `UInt` and `Int`/`SIMD`. The `SIMD` type is also slightly less foldable at compile time, which can cause some code in where clauses, comptime expressions, and other parametric contexts to fail or crash. These shortcomings will be addressed in subsequent patches as needed. * Implicit conversion between `Int` and `UInt` has been removed. * The [`random`](/mojo/std/random/) module now uses a pure Mojo implementation based on the Philox algorithm (via an internal wrapper), replacing the previous `CompilerRT` C++ dependency.
The Philox algorithm provides excellent statistical quality, works on both CPU and GPU, and makes random number generation fully transparent and source-available. Note that this changes the random number sequence for a given seed value, which may affect tests or code relying on reproducible sequences. * `StringableRaising` has been deprecated and its usages in the stdlib have been removed. * `Variadic` now has `zip_types()`, `zip_values()`, and `slice_types()`. ### Tooling changes {#26-1-tooling-changes} * The Mojo compiler now supports the `-Werror` flag, which treats all warnings as compilation errors. This is useful for enforcing stricter code quality standards, particularly in CI/CD pipelines. The flag works with the Mojo compiler tools (`mojo run`, `mojo build`, `mojo package`, `mojo doc`). When used with `--disable-warnings`, warnings are promoted to errors first, so the errors are not suppressed. * The counterpart `-Wno-error` flag disables treating warnings as errors. When both flags are specified, the last one wins. * Specifying CUDA architectures with `--target-accelerator` now expects an sm version string rather than just a compute capability. For example, `--target-accelerator=nvidia:80` should be changed to `--target-accelerator=nvidia:sm_80`. If an incorrect format is used for the version, the compiler will default to the lowest supported sm version. * The Mojo LSP server now debounces document updates to reduce CPU usage during rapid typing. Previously, every keystroke triggered a full document parse; now updates are coalesced with a 150ms delay, reducing parse frequency by 10-50x during active editing. * The Mojo compiler now "diffs" very long types in error messages to explain what is going on in an easier-to-understand way. * Elaboration error printing now offers different levels of verbosity, giving control over how parameter values are displayed in elaboration errors when function instantiation fails.
`--elaboration-error-verbose=value` now takes a value, where: * `no-params` means don't display any concrete parameter values. This is helpful to collapse recursion-related error messages into shorter blobs. * `simple-params` means display concretized parameter values for simple types, including numeric types and strings, in a user-friendly format (the default). * `all-params` means show all concrete parameter values. This is for advanced programmers who don't mind reading MLIR attributes but want more visibility into parameter values. * `--elaboration-max-depth` has been added to control the maximum elaborator instantiation depth. This (unsigned) value helps to detect compile-time recursion. The default is `std::numeric_limits<unsigned>::max()`. * Docstring validation with `--validate-doc-strings` now emits an error when an `fn` function is declared to raise an error (`raises`) but is missing a [`Raises` docstring](https://github.com/modular/modular/blob/main/mojo/stdlib/docs/docstring-style-guide.md#errors) (previously it emitted only a warning). Because Mojo automatically treats all `def` functions as [raising functions](/mojo/manual/functions#raising-and-non-raising-functions), we don't enforce `Raises` docs for `def` functions (to avoid noisy false positives). * Docstring validation now includes `comptime` aliases. The `--diagnose-missing-doc-strings` flag now checks that public aliases have properly formatted docstrings (the summary starts with a capital letter and ends with a period). Parametric aliases are also checked for proper `Parameters:` sections. * The `--validate-doc-strings` flag has been deprecated for `mojo doc` and removed from other tools (`mojo build`, `mojo run`, `mojo package`). Use `-Werror` instead to treat warnings as errors. * The Mojo compiler now supports the `-Xlinker` flag to pass options directly to the linker, for example: ```console mojo build -Xlinker -lfoo main.mojo ``` Note: this option only has an effect with `mojo build`.
With `mojo run`, the arguments are ignored and a warning is issued. * The Mojo compiler now supports the `--experimental-export-fixit` flag for `mojo build`, `mojo run`, and `mojo package`. This flag exports fix-its to a YAML file compatible with `clang-apply-replacements`, instead of applying them directly. This is useful for integrating Mojo's fix-it suggestions into external tooling workflows. The flag is mutually exclusive with `--experimental-fixit` (which applies fix-its directly). * The Mojo Debugger `mojo break-on-raise` feature now works correctly with multiple targets in a debugger instance. The setting is per-target. ### Experimental changes {#26-1-experimental-changes} Changes described in this section are experimental and may be changed, replaced, or removed in future releases. * Mojo now supports compile-time trait conformance checking (via `conforms_to()`) and downcasting (via `trait_downcast()`). This allows users to implement features like static dispatching based on trait conformance.
For example: ```mojo fn maybe_print[T : AnyType](maybe_printable : T): @parameter if conforms_to(T, Writable): print(trait_downcast[Writable](maybe_printable)) else: print("[UNPRINTABLE]") ``` * Added support for [`DType`](/mojo/std/builtin/dtype/DType) expressions in `where` clauses: ```mojo fn foo[dt: DType]() -> Int where dt == DType.int32: return 42 ``` Currently, the following expressions are supported: * equality and inequality * `is_signed()`, `is_unsigned()`, `is_numeric()`, `is_integral()`, `is_floating_point()`, `is_float8()`, `is_half_float()` * Added support for [`SIMD`](/mojo/std/builtin/simd/SIMD) expressions in `where` clauses: ```mojo fn foo[dt: DType, x: Int]() -> Int where SIMD[dt, 4](x) + 2 > SIMD[dt, 4](0): return 42 ``` Currently, the following expressions are supported: * default construction and construction from `Int` and [`IntLiteral`](/mojo/std/builtin/int_literal/IntLiteral/) * equality, inequality, and other comparison operators * addition, subtraction, and multiplication * bitwise logical operations, excluding shifts ### Removed {#26-1-removed} * The following deprecated GPU compatibility modules have been removed: * `gpu.id` - Use `from gpu import block_idx, thread_idx, ...` instead * `gpu.block` - Use `from gpu.primitives.block import ...` instead * `gpu.warp` - Use `from gpu.primitives.warp import ...` instead * `gpu.cluster` - Use `from gpu.primitives.cluster import ...` instead * `gpu.grid_controls` - Use `from gpu.primitives.grid_controls import ...` instead * `gpu.mma` - Use `from gpu.compute.mma import ...` instead * `gpu.mma_operand_descriptor` - Use `from gpu.compute.mma_operand_descriptor import ...` instead * `gpu.mma_util` - Use `from gpu.compute.mma_util import ...` instead * `gpu.mma_sm100` - Use `from gpu.compute.arch.mma_nvidia_sm100 import ...` instead * `gpu.semaphore` - Use `from gpu.sync.semaphore import ...` instead * `gpu.tcgen05` - Use `from gpu.compute.arch.tcgen05 import ...` instead * The DeviceContext 
`enqueue_function_unchecked()` and `compile_function_unchecked()` have been removed. Please migrate the code to use `enqueue_function()` and `compile_function()`. * `NDBuffer` has been removed. Use [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor/) instead. * The `UnsafePointer.offset()` method is now deprecated. Use pointer arithmetic instead: ```mojo # Before new_ptr = ptr.offset(n) # After new_ptr = ptr + n ``` ### Fixed {#26-1-fixed} * Several reflection-related compiler crashes have been fixed: * [Issue #5731](https://github.com/modular/modular/issues/5731): Reflection functions now work correctly on builtin types like `Int`, `NoneType`, and `Origin`. * [Issue #5732](https://github.com/modular/modular/issues/5732): `get_type_name()` now handles types with constructor calls in their parameters (like `A[B(True)]`) when extracted via `struct_field_types()`. * [Issue #5723](https://github.com/modular/modular/issues/5723): `get_type_name()` now handles nested parametric types from `struct_field_types()`. * [Issue #5754](https://github.com/modular/modular/issues/5754): `struct_field_type_by_name()` now works correctly when using `ReflectedType.T` as a type annotation. * [Issue #5808](https://github.com/modular/modular/issues/5808): `rebind()` and `rebind_var()` now accept downcasted types from `struct_field_types()`, allowing patterns like `rebind_var[types[i]](downcast[types[i], Trait]()^)`. * [Issue #5618](https://github.com/modular/modular/issues/5618): Compiler crash when should be implicit conversion error. * [Issue #5361](https://github.com/modular/modular/issues/5361): mojo doc crashes on alias of parametrized function with origin. * [Issue #5137](https://github.com/modular/modular/issues/5137): Tail call optimization doesn't happen for tail recursive functions with raises. * [Issue #5138](https://github.com/modular/modular/issues/5138): Tail call optimization doesn't happen for functions with local stack temporaries. 
* Mojo no longer complains about "cannot infer parameter X" when unrelated type checking errors happen in complex parametric code. It now gives much more useful and actionable error messages in these cases. * [`time.sleep()`](/mojo/std/time/time/sleep) now works correctly for durations longer than 1 millisecond on NVIDIA GPUs. Previously, sleep durations were silently capped at 1ms due to a hardware limitation in the underlying `nanosleep` intrinsic. AMD GPUs now have basic sleep support using the `s_sleep` instruction, which is sufficient for spin-wait backoff operations though it doesn't provide accurate wall-clock timing. Additionally, `global_perf_counter_ns()` is now exported from the `time` package for GPU code that needs nanosecond-resolution timing. * [Issue #5578](https://github.com/modular/modular/issues/5578): ownership overloading not working when used with `ref`. * [Issue #1850](https://github.com/modular/modular/issues/1850): Mojo assumes string literal at start of a function is a doc comment * [Issue #4501](https://github.com/modular/modular/issues/4501): Incorrect parsing of incomplete assignment * [Issue #4765](https://github.com/modular/modular/issues/4765): Parser accepts pointless var ref a = n binding form * [`Codepoint.unsafe_decode_utf8_codepoint()`](/mojo/std/collections/string/codepoint/Codepoint/#unsafe_decode_utf8_codepoint) no longer returns `Codepoint(0)` (NUL) when passed an empty span. Instead, a `debug_assert()` now enforces the requirement that the input span be non-empty, consistent with the function's existing safety contract. * [Issue #5635](https://github.com/modular/modular/issues/5635): [`Deque`](/mojo/std/collections/deque/Deque) shrink reallocation incorrectly handled empty deque with `capacity > min_capacity`. 
### Special thanks Special thanks to our community contributors: Alex Maldonado ([@aalexmmaldonado](https://github.com/aalexmmaldonado)), Bernhard Merkle ([@bmerkle](https://github.com/bmerkle)), Brian Grenier ([@bgreni](https://github.com/bgreni)), Christoph Schlumpf ([@christoph-schlumpf](https://github.com/christoph-schlumpf)), Duba Sirisha ([@SirishaDuba](https://github.com/SirishaDuba)), Ethan Wu ([@YichengDWu](https://github.com/YichengDWu)), Gunasekar ([@sdgunaa](https://github.com/sdgunaa)), Hristo (Izo) G. ([@izo0x90](https://github.com/izo0x90)), Johannes Laute ([@jaidmin](https://github.com/jaidmin)), Jordan Rinder ([@jrinder42](https://github.com/jrinder42)), josiahls ([@josiahls](https://github.com/josiahls)), Krish Gupta ([@KrxGu](https://github.com/KrxGu)), Magi Sharma J ([@magi8101](https://github.com/magi8101)), Manuel Saelices ([@msaelices](https://github.com/msaelices)), martinvuyk ([@martinvuyk](https://github.com/martinvuyk)), Richard Johnsson ([@jmikaelr](https://github.com/jmikaelr)), Ritesh Goru ([@BlackWingedKing](https://github.com/BlackWingedKing)), Ross Campbell ([@RossCampbellDev](https://github.com/RossCampbellDev)), RWayne93 ([@RWayne93](https://github.com/RWayne93)), soraros ([@soraros](https://github.com/soraros)), Sören Brunk ([@sbrunk](https://github.com/sbrunk)), turakz ([@turakz](https://github.com/turakz)), Valentin Erokhin ([@saviorand](https://github.com/saviorand)), YeonguChoe ([@YeonguChoe](https://github.com/YeonguChoe)) ## v0.25.7 (2025-11-20) ### ✨ Highlights * A new, improved [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) type has been added, and the old version renamed to [`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer). The new version fixes several issues with the old `LegacyUnsafePointer`. * `LegacyUnsafePointer` had constructors that allowed unsafe implicit mutability and origin casts, making it easy to make unsafe changes by accident. 
The new `UnsafePointer` eliminates these. * The new `UnsafePointer` now has an inferred mutability parameter, simplifying the API. * `UnsafePointer` does not have a default value for its origin parameter, so it must be explicitly specified or unbound. `LegacyUnsafePointer` defaults to `MutAnyOrigin`, which results in the lifetimes of many values being extended needlessly. For more information, see [Pointer and memory changes](#25-7-pointer-and-memory-changes). * Many enhancements to compile-time errors, including suppressing extraneous messages and preserving more human-readable display for many aliases and parameters. For details, see [Tooling changes](#25-7-tooling-changes). * Added a new document on [GPU block and warp operations and synchronization](/mojo/manual/gpu/block-and-warp). ### Language enhancements {#25-7-language-enhancements} * Mojo now supports the `comptime` keyword as a synonym for `alias`. The `comptime` keyword can be used interchangeably with `alias` for compile-time declarations. Both keywords are fully supported and produce identical behavior. For example: ```mojo comptime x = 5 # New preferred syntax alias y = 10 # Still fully supported comptime MyType[T: AnyType] = T # Works with parametric declarations ``` Note: Future updates will migrate error messages and internal terminology to use "comptime". The `alias` keyword will remain supported for backward compatibility for now. * Mojo now supports unpacking an alias/comptime tuple with a single statement when it is not inside a struct or trait. For example: ```mojo comptime i, f = (1, 3.0) comptime q, v = divmod(4, 5) ``` * [Issue #3925](https://github.com/modular/modular/issues/3925): Mojo now allows methods to be overloaded based on "owned" vs "by-ref" argument conventions, selecting the owned overload when given an owned value, and selecting the by-ref version otherwise. This allows somewhat more efficient algorithms, e.g. 
consuming vs borrowing iterators: ```mojo struct MyCollection: fn __iter__(var self) -> Self.ConsumingIterator: ... fn __iter__(self) -> Self.BorrowingIterator: ... ``` * Collection literals now have a default type. For example, you can now bind `[1,2,3]` to `T` in a call to a function defined as `fn zip[T: Iterable](impl:T)` because it will default to the standard library's `List` type. * Mojo now has a `__functions_in_module()` experimental intrinsic that allows reflection over the functions declared in the module where it is called. For example: ```mojo fn foo(): pass def bar(x: Int): pass def main(): alias funcs = __functions_in_module() # equivalent to: alias same_funcs = Tuple(foo, bar) ``` The intrinsic is currently limited for use from within the `main()` function. For an example of using `__functions_in_module()` in a test suite, see [Running tests with TestSuite](/mojo/tools/testing/#running-tests-with-testsuite) * The `@implicit` decorator now accepts an optional `deprecated` keyword argument. This can be used to phase out implicit conversions instead of just removing the decorator (which can result in another, unintended implicit conversion path). For example, the compiler now warns about the following: ```mojo struct MyStuff: @implicit(deprecated=True) fn __init__(out self, value: Int): pass fn deprecated_implicit_conversion(): # warning: deprecated implicit conversion from 'IntLiteral[1]' to 'MyStuff' _: MyStuff = 1 _ = MyStuff(1) # this is okay, because the conversion is already explicit. ``` * The `@deprecated` decorator can now take a target symbol with the `use` keyword argument. This is mutually exclusive with the existing positional string argument. A deprecation warning will be automatically generated. 
```mojo @deprecated(use=new) fn old(): pass fn new(): pass fn main(): old() # 'old' is deprecated, use 'new' instead ``` * In struct instances that declare a parametric `__call__()` method, but not one of the subscript methods (`__getitem__()`, `__setitem__()`, or `__getattr__()`), the `__call__()` method can now be invoked with parameters: ```mojo struct Callable: fn __init__(out self): pass fn __call__[x: Int](self, y: Int) -> Int: return x + y fn main(): var c = Callable() print(c[1](2)) # 3 ``` Previously you would have needed to explicitly look up `__call__()`: ```mojo print(c.__call__[1](2)) ``` * Added `DType.float4_e2m1fn` as the 4-bit float `e2m1` format. This `Float4_e2m1` type is defined by the [Open Compute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). * `deinit` methods may now transfer all of `self` to another `deinit` method. * Mojo now uses system allocators in programs built with `mojo build --sanitize address`. This means ASan can see Mojo heap allocations and should now be able to detect many more heap memory errors. ### Language changes {#25-7-language-changes} * The `__type_of()` magic function has been renamed to `type_of()`. Using the old spelling will yield an error. Similarly, `__origin_of()` has been renamed to `origin_of()`. * An expression like `(Int, Float64)` is no longer syntax sugar for a tuple type like `Tuple[Int, Float64]`. ### Library changes {#25-7-library-changes} #### Pointer and memory changes {#25-7-pointer-and-memory-changes} * `UnsafePointer` has been renamed to [`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer/) and a new [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer/) has [taken its place](https://forum.modular.com/t/proposal-unsafepointer-v2/2411?u=nate).
Similarly, `OpaquePointer` has been renamed to [`LegacyOpaquePointer`](/mojo/std/memory/legacy_unsafe_pointer/#legacyopaquepointer) and a new [`OpaquePointer`](/mojo/std/memory/unsafe_pointer/#opaquepointer) has taken its place. The primary difference is the ordering of parameters, which now looks like this: ```mojo struct UnsafePointer[ mut: Bool, //, # Inferred mutability type: AnyType, origin: Origin[mut], # Non-defaulted origin *, address_space: AddressSpace = AddressSpace.GENERIC, ] alias OpaquePointer[ mut: Bool, //, # Inferred mutability origin: Origin[mut], # Non-defaulted origin *, address_space: AddressSpace = AddressSpace.GENERIC, ] = UnsafePointer[NoneType, origin, address_space=address_space] ``` Its implicit constructors no longer allow unsafe casting between pointers with different mutability and origin values. Code will need to be updated to the new `UnsafePointer`; in the interim, users can find-and-replace their current usages of `UnsafePointer` with `LegacyUnsafePointer`. Alternatively, users can add the following import statement to the beginning of any files relying on the old pointer type: ```mojo from memory import LegacyUnsafePointer as UnsafePointer # and/or if you use OpaquePointer from memory import LegacyOpaquePointer as OpaquePointer ``` Users can also use the `as_legacy_pointer()` and `as_unsafe_pointer()` conversion methods to convert between the two pointer types during this migration period. *Note*: `LegacyUnsafePointer` and `LegacyOpaquePointer` will eventually be deprecated and removed in a future version of Mojo. There are a few new helpful type aliases for `UnsafePointer`: * `MutUnsafePointer` * `ImmutUnsafePointer` * `MutOpaquePointer` * `ImmutOpaquePointer` Lastly, [`alloc()`](/mojo/std/memory/unsafe_pointer/alloc) has been moved from a static method on `UnsafePointer` to a free-standing `alloc()` function.
  Therefore, code that was written as:

  ```mojo
  var ptr = UnsafePointer[Int].alloc(3)
  ```

  must be changed to:

  ```mojo
  var ptr = alloc[Int](3)
  ```

  For more details on migrating to the new `UnsafePointer`, see the migration guide provided in the [UnsafePointer v2 proposal](https://github.com/modular/modular/blob/main/mojo/proposals/unsafe-pointer-v2.md#migration-guide-from-legacyunsafepointer-to-the-new-unsafepointer).

* Added a `swap_pointees()` function to `UnsafePointer` as an alternative to `swap()` for when the pointers may potentially alias each other.

* `origin_cast()` for [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor/), `NDBuffer` and `UnsafePointer` has been deprecated and removed. `LayoutTensor` and `NDBuffer` now support a safer `as_any_origin()` origin casting method. `UnsafePointer` has the same safe alternative; in addition, it has a safe `as_immutable()` casting function and the explicitly unsafe `unsafe_mut_cast()` and `unsafe_origin_cast()` casting methods.

* The `empty` origin has been renamed to `external`. This origin represents a value that's not tracked by the Mojo lifetime checker. Newly-allocated memory from `alloc()` is returned with the origin `MutOrigin.external`.

* Renamed `MutableOrigin` to `MutOrigin` and `ImmutableOrigin` to `ImmutOrigin`.

* Renamed `MutableAnyOrigin` to `MutAnyOrigin` and `ImmutableAnyOrigin` to `ImmutAnyOrigin`.

* [`memcpy()`](/mojo/std/memory/memory/memcpy/) and [`parallel_memcpy()`](/mojo/std/algorithm/memory/parallel_memcpy/) without keyword arguments are deprecated.

#### Collections and iterators

* [`Span`](/mojo/std/memory/span/Span/) and [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice/) constructors now accept `Int` for length parameters instead of `UInt`. This change makes these types more ergonomic to use with integer literals and other `Int`-based APIs.

* Added `Span.binary_search_by()`, which allows binary searching with a custom comparator function.
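  As an illustration only — the exact comparator shape and return convention of `binary_search_by()` are assumptions here (modeled loosely on Rust's `binary_search_by`), not confirmed by these notes — usage might look something like:

  ```mojo
  fn main():
      var data = [10, 20, 30, 40, 50]
      var span = Span(data)

      # Hypothetical comparator: negative if the probed element sorts before
      # the target (30), zero on a match, positive otherwise.
      fn cmp(x: Int) -> Int:
          return x - 30

      # Assumed to return the position of a matching element, if any.
      var found = span.binary_search_by(cmp)
  ```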
* Added `unsafe_get()`, `unsafe_swap_elements()` and `unsafe_subspan()` to the `Span` struct.

* [`Optional`](/mojo/std/collections/optional/Optional/) now conforms to `Iterable` and `Iterator`, acting as a collection of size 1 or 0.

* New [`ContiguousSlice`](/mojo/std/builtin/builtin_slice/ContiguousSlice) and [`StridedSlice`](/mojo/std/builtin/builtin_slice/StridedSlice) types were added to the `builtin_slice` module to support specialization for slicing without strides.

* [`List`](/mojo/std/collections/list/List/) slicing without a stride now returns a `Span` instead of a `List`, and no longer allocates memory.

* Several standard library APIs have been updated to use `Int` instead of `UInt` for improved ergonomics, eliminating the need for explicit casts when using `Int` values (the default type for integer literals and loop indices):

  * `BitSet[size: Int]` - Changed parameter from `UInt` to `Int`
  * `BitSet.set(idx: Int)`, `BitSet.clear(idx: Int)`, `BitSet.toggle(idx: Int)`, `BitSet.test(idx: Int)` - Changed from `UInt` to `Int`
  * `String(unsafe_uninit_length: Int)` - Changed from `UInt` to `Int`
  * `String.capacity() -> Int` - Changed return type from `UInt` to `Int`
  * `String.reserve(new_capacity: Int)` - Changed from `UInt` to `Int`
  * `List(length: Int, fill: T)` - Changed from `UInt` to `Int`
  * `Codepoint.unsafe_write_utf8() -> Int` - Changed return type from `UInt` to `Int`
  * `Codepoint.utf8_byte_length() -> Int` - Changed return type from `UInt` to `Int`

* Added a [`repeat()`](/mojo/std/itertools/itertools/repeat/) function to the `itertools` module that creates an iterator which repeats an element a specified number of times. Unlike Python's `itertools.repeat()`, infinite iteration is not currently supported: the `times` parameter is required. Example usage:

  ```mojo
  from itertools import repeat

  fn main():
      for val in repeat(42, times=3):
          print(val)  # Prints: 42, 42, 42
  ```

* Tuples now support comparison operations if the element types are also comparable.
  For example, one can now write `(1, "a") == (1, "a")` or `(1, "a") < (1, "b")`.

#### String and text

* [`Codepoint`](/mojo/std/collections/string/codepoint/Codepoint/) now conforms to [`Writable`](/mojo/std/format/Writable/).

* `Codepoint` now conforms to [`Comparable`](/mojo/std/builtin/comparable/Comparable/), adding `__le__()`, `__lt__()`, `__ge__()`, and `__gt__()` implementations.

#### Math and numeric types

* The deprecated `DType.index` is now removed in favor of [`DType.int`](/mojo/std/builtin/dtype/DType/#int).

* Added `DType.float4_e2m1fn` as the 4-bit float `e2m1` format. This `Float4_e2m1` type is defined by the [Open Compute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

* `math.isqrt()` has been renamed to [`rsqrt()`](/mojo/std/math/math/rsqrt/) since it performs reciprocal square root functionality.

* The [`math`](/mojo/std/math/math/) package now has a Mojo native implementation of `acos()`, `asin()`, `cbrt()`, and `erfc()`.

* [`SIMD`](/mojo/std/builtin/simd/SIMD/) now implements the [`DivModable`](/mojo/std/builtin/math/DivModable/) trait.

* Implicit conversions between `Int` and `UInt` are now deprecated. The `@implicit` decorator on `Int.__init__(UInt)` and `UInt.__init__(Int)` will be removed in a future version of Mojo. Code that currently performs implicit conversions between `Int` and `UInt` will issue a deprecation warning, and should be updated to explicitly read `Int(uint_val)` or `UInt(int_val)` respectively.

* The `ImplicitlyIntable` trait has been removed. Types implementing this trait could be implicitly converted to `Int`. `Bool` was the only Mojo standard library type to implement `ImplicitlyIntable`. Conversions from `Bool` to `Int` can now be performed explicitly, using `Int(bool_val)` (via the remaining `Intable` trait, which only supports *explicit* conversions).

#### GPU support

* Added support for NVIDIA GeForce GTX 970.
* [`gpu.sync.syncwarp()`](/mojo/std/gpu/sync/sync/syncwarp/) now supports Apple GPUs via a `SIMDGROUP` barrier implementation. On Apple GPUs, this provides execution synchronization for all active lanes using a `SIMDGROUP` barrier with no memory fence. For threadgroup memory ordering, use [`barrier()`](/mojo/std/gpu/sync/sync/barrier/) instead. Note that lane masks are not supported on Apple GPUs, so the mask argument is ignored.

* [`gpu.primitives.warp`](/mojo/std/gpu/primitives/warp/) now supports Apple GPUs with native SIMD-group shuffle operations. This enables `shuffle_idx()`, `shuffle_up()`, `shuffle_down()`, and `shuffle_xor()` on Apple hardware by mapping Metal `simd_shuffle*` intrinsics to AIR (`llvm.air.simd_shuffle[_up/_down/_xor]`) instructions, achieving feature parity with the NVIDIA and AMD backends.

* [`gpu.intrinsics.store_release()`](/mojo/std/gpu/intrinsics/store_release/) and [`gpu.intrinsics.load_acquire()`](/mojo/std/gpu/intrinsics/load_acquire/) now support Apple silicon GPUs, expanding support for proper memory synchronization on these devices.

* The [`gpu`](/mojo/std/gpu/) package has been reorganized into logical subdirectories for better code organization:

  * `gpu/primitives/` - Low-level GPU execution primitives (warp, block, cluster, id, grid\_controls)
  * `gpu/memory/` - Memory operations (async\_copy, TMA, address spaces)
  * `gpu/sync/` - Synchronization primitives (barriers, semaphores)
  * `gpu/compute/` - Compute operations (mma, tensor cores, tcgen05)

  **Backward compatibility**: All existing imports continue to work unchanged. Deprecated import paths (`gpu.id`, `gpu.mma`, `gpu.cluster`, `gpu.grid_controls`, `gpu.warp`, `gpu.semaphore`, `gpu.mma_sm100`, `gpu.tcgen05`, `gpu.mma_util`, `gpu.mma_operand_descriptor`, and `gpu.tensor_ops`) are preserved as re-export wrappers with deprecation notices.
  Users can migrate to the new recommended import patterns at their own pace:

  ```mojo
  # Old (deprecated but still works):
  from gpu.id import block_idx, thread_idx
  from gpu.mma import mma
  from gpu.mma_sm100 import UMMAKind
  from gpu.tcgen05 import tcgen05_alloc
  from gpu.semaphore import Semaphore
  from gpu.cluster import cluster_sync

  # New (recommended):
  from gpu import block_idx, thread_idx, cluster_sync
  from gpu.compute.mma import mma
  from gpu.compute.mma_sm100 import UMMAKind
  from gpu.compute.tcgen05 import tcgen05_alloc
  from gpu.sync.semaphore import Semaphore
  ```

* The `_GPUAddressSpace` type has been removed and consolidated into [`AddressSpace`](/mojo/std/memory/pointer/AddressSpace/). GPU-specific address space constants (`GLOBAL`, `SHARED`, `CONSTANT`, `LOCAL`, `SHARED_CLUSTER`) are now available as aliases on the unified `AddressSpace` type. The `GPUAddressSpace` alias has also been removed in favor of using `AddressSpace` directly. Since `AddressSpace` is part of the prelude, it no longer needs to be explicitly imported in most code.

* TMA (Tensor Memory Accelerator) types have been moved to a dedicated module. The types `TMADescriptor`, `TensorMapSwizzle`, `TensorMapDataType`, `TensorMapInterleave`, `TensorMapL2Promotion`, `TensorMapFloatOOBFill`, and the functions `create_tma_descriptor()` and `prefetch_tma_descriptor()` are now available from [`gpu.host.nvidia.tma`](/mojo/std/gpu/host/nvidia/tma/) instead of `gpu.host._nvidia_cuda`.

#### Testing {#25-7-testing}

* [`assert_equal()`](/mojo/std/testing/testing/assert_equal) now displays colored character-by-character diffs when string comparisons fail, making it easier to spot differences. Differing characters are highlighted in red for the left string and green for the right string.

* The `mojo test` command has been removed. The recommended testing strategy is to define test functions, call them explicitly from `main()` (or use the new `TestSuite` framework), and run with `mojo run`.
* [`TestSuite`](/mojo/std/testing/suite/TestSuite/) can now generate test reports with `.generate_report()`. Also, `TestReport` and `TestSuiteReport` structs were added.

* `TestSuite` now allows explicitly skipping registered tests using the `TestSuite.skip()` method.

* `TestSuite` now allows basic control from CLI arguments. Tests can be skipped from the CLI by passing test function names after a `--skip` flag, e.g.

  ```console
  mojo run test_my_stuff.mojo --skip test_currently_failing test_also_failing
  ```

  Similarly, the `--only` flag enables the specification of an allowlist, e.g. the following will skip any other registered test cases:

  ```console
  mojo run test_my_stuff.mojo --only test_only_this test_this_as_well
  ```

  The `--skip-all` flag will skip all registered test cases in the suite. Note that `--only` respects skipped tests, i.e. it does not run tests that are skipped using `TestSuite.skip()`.

#### System and OS

* Added an [`os.isatty()`](/mojo/std/os/os/isatty/) function to check whether a file descriptor refers to a terminal. This function accepts an `Int` file descriptor. If you have a [`FileDescriptor`](/mojo/std/io/file_descriptor/FileDescriptor/) object, use its `isatty()` method instead.

* Added [`sys.compile.SanitizeAddress`](/mojo/std/sys/compile/#sanitizeaddress), providing a way for Mojo code to detect `--sanitize address` at compile time.

* `DLHandle` is no longer part of the public API. To access a dynamically-linked library from Mojo, use [`OwnedDLHandle`](/mojo/std/sys/ffi/OwnedDLHandle/) instead. The new type provides RAII-based automatic resource management for dynamically-linked libraries. `DLHandle` has been renamed to `_DLHandle` and remains available internally for use by the standard library.

#### Performance optimizations

* Optimized float-to-string formatting performance by eliminating unnecessary stack allocations.
  Internal lookup tables used for float formatting (`cache_f32` and `cache_f64`) are now stored as global constants instead of being materialized on the stack for each conversion. This reduces stack overhead by \~10KB for `Float64` and \~600 bytes for `Float32` operations, improving performance for all float formatting operations including `print()`, string interpolation, and `str()` conversions.

* Optimized number parsing performance by eliminating stack allocations for large lookup tables.

  Internal lookup tables used for number parsing (`powers_of_5_table` and `POWERS_OF_10`) are now stored as global constants using the [`global_constant`](/mojo/std/builtin/globals/global_constant/) function instead of being materialized on the stack for each parsing operation. This reduces stack overhead by \~10.6KB for number parsing operations, improving performance for string-to-number conversions including `atof()` and related float parsing operations.

#### Other library changes

* The [`Hasher`](/mojo/std/hashlib/hasher/Hasher/) trait's `_update_with_bytes()` method now takes `Span[Byte]` instead of `UnsafePointer[UInt8]` and a separate length parameter. This change applies to all hasher implementations, including `AHasher` and `Fnv1a`.

* The Philox random number generator (`Random` and `NormalRandom`) has been moved from `gpu.random` to [`random.philox`](/mojo/std/random/philox/). These types now work on both CPU and GPU. Import them using `from random import Random, NormalRandom` or `from random.philox import Random, NormalRandom`.

### Tooling changes {#25-7-tooling-changes}

* Error and warning messages now preserve `comptime`/alias names in many cases, to prevent extremely long type names for complex types.
  The compiler will expand these when necessary to understand the type, based on a simple heuristic. For example:

  ```mojo
  struct Dep[T: AnyType, v: T]:
      pass

  alias MyDep[T: AnyType, v: T] = Dep[T, v]

  alias MyDepGetAlias0 = MyDep.hello
  ```

  produces:

  ```console
  $ mojo t.mojo
  t.mojo:10:29: error: 'MyDep' needs more parameters bound before accessing attributes
  alias MyDepGetAlias0 = MyDep.hello
                              ^
  t.mojo:10:29: note: 'MyDep' is aka 'alias[T: AnyType, v: T] Dep[T, v]'
  ```

  Please file issues in cases where more information needs to be exposed.

* Error messages now preserve symbolic calls to `always_inline("builtin")` functions rather than inlining them into the error message.

* This release includes a number of improvements to elaboration errors. The *elaborator* is the compiler pass where final parameter values are determined and code is transformed from parametric to concrete. Common errors from elaboration include "function instantiation failed" and "call expansion failed" messages, when it's determined that the actual parameter values don't match any viable function overload. The elaborator also checks constraints (defined using the [`constrained()`](/mojo/std/builtin/constrained/constrained/) function), so constraint failures are generated during this pass.

* Elaboration errors now report the full call instantiation failure path. For this Mojo file:

  ```mojo
  fn fn1[T: ImplicitlyCopyable, //](a: T):
      constrained[False]()

  fn fn2[T: ImplicitlyCopyable, //](a: T):
      return fn1(a)

  fn main():
      fn2(1)
  ```

  Now the error prints the path of `main -> fn2 -> fn1 -> constrained[False]` instead of just `constrained[False]`.

* Elaboration errors now print out trivial parameter values with call expansion failures.
  For this simple Mojo program:

  ```mojo
  fn fn1[a: Int, b: Int]():
      constrained[a < b]()

  fn fn2[a: Int, b: Int]():
      fn1[a, b]()

  fn main():
      fn2[4, 2]()
  ```

  Now the error message shows `parameter value(s): ("a": 4, "b": 2)`:

  ```console
  test.mojo:6:14: note: call expansion failed with parameter value(s): ("a": 4, "b": 2)
      fn1[a, b]()
  ```

  Only string and numerical values are printed out by default now; other values are shown as `...`. Use `--elaboration-error-verbose` to show all parameter values.

* Elaboration error messages related to the prelude are omitted by default. The prelude is the set of APIs exported from `std/builtin/_startup.mojo`. These APIs persist in all call expansion paths but are rarely the source of reported errors. These APIs are now omitted by default to de-clutter elaboration errors. Use `--elaboration-error-include-prelude` to include the prelude.

  * By default (without prelude):

    ```console
    test.mojo:43:4: error: function instantiation failed
    fn main():
       ^
    test.mojo:45:12: note: call expansion failed
        my_func()
    ...
    ```

  * With prelude:

    ```console
    oss/modular/mojo/stdlib/std/builtin/_startup.mojo:119:4: error: function instantiation failed
    fn __mojo_main_prototype(
       ^
    oss/modular/mojo/stdlib/std/builtin/_startup.mojo:119:4: note: call expansion failed with parameter value(s): (...)
    oss/modular/mojo/stdlib/std/builtin/_startup.mojo:42:4: note: function instantiation failed
    fn __wrap_and_execute_main[
       ^
    oss/modular/mojo/stdlib/std/builtin/_startup.mojo:68:14: note: call expansion failed
        main_func()
        ^
    test.mojo:43:4: note: function instantiation failed
    fn main():
       ^
    test.mojo:45:12: note: call expansion failed
        my_func()
    ...
    ```

* An `--elaboration-error-limit` option has been added to the `mojo run` and `mojo build` commands. This option sets a limit on the number of elaboration errors that get printed. The default value is 20. To change the limit, use `--elaboration-error-limit=<limit>`, where a *limit* of `0` means unlimited.
* A `--help-hidden` option has been added to all Mojo CLI commands to show hidden options.

* `mojo debug` now rejects unknown options between `debug` and the target.

* The Mojo language server will now report more coherent code actions.

* The `mojo` CLI now has an `--experimental-fixit` flag that automatically applies FixIt hints emitted by the parser. This feature is highly experimental, and users should ensure they back up their files (or check them into source control) before using it.

### ❌ Removed {#25-7-removed}

* `mojo test` has been deprecated and removed as described in the [deprecation proposal](https://forum.modular.com/t/proposal-deprecating-mojo-test/2371). For more information, see the [Testing](#25-7-testing) section.

* The `LayoutTensorBuild` type has been removed. Use `LayoutTensor` with parameters directly instead.

* The following traits have been removed: `LessThanComparable`, `GreaterThanComparable`, `LessThanOrEqualComparable`, `GreaterThanOrEqualComparable`. It is extremely rare that a type would implement only one of these, so one can just use `Comparable` instead.

* All telemetry-related code has been removed from the Mojo compiler. This should eliminate the source of some hangs and misbehavior on poor internet connections.

### 🛠️ Fixed {#25-7-fixed}

* [Issue #5111](https://github.com/modular/modular/issues/5111): The `math.cos()` and `math.sin()` functions can now be evaluated at compile time.

* Fixed the `IntTuple.value(i)` method returning incorrect values when elements are stored as nested single-element tuples. Previously, calling `Layout.row_major(M, N).stride.value(i)` would return negative offset values (e.g., -65536, -65537) instead of the actual stride values. This affected any code that accessed layout stride or shape values using the `value()` method.

* Fixed the `LayoutTensor.shape[idx]()` method returning incorrect values for nested layouts.
  The bug occurred when accessing shape dimensions of tensors with nested layouts like `((32, 2), (32, 4))`, where the method would return garbage values instead of the correct product (e.g., 64).

* Fixed `LayoutTensor` element-wise arithmetic operations (`+`, `-`, `*`, `/`) between tensors with different memory layouts. Previously, operations like `a.transpose() - b` would produce incorrect results when the operands had different layouts, because the same layout index was incorrectly used for both operands. This now correctly computes separate indices for each tensor based on its layout.

* Fixed the `arange()` function in `layout._fillers` to properly handle nested layout structures. Previously, the function would fail when filling tensors with nested layouts like `Layout(IntTuple(IntTuple(16, 8), IntTuple(32, 2)), ...)` because it attempted to extract shape values from nested tuples incorrectly.

* [Issue #5479](https://github.com/modular/modular/issues/5479): Mojo no longer crashes when compiling a standalone `__del__` function without struct context.

* [Issue #5500](https://github.com/modular/modular/issues/5500): Added comprehensive documentation to `gpu/host/info.mojo` explaining GPU target configuration and LLVM data layout strings. The documentation now includes detailed explanations of all MLIR target components, vendor-specific patterns for NVIDIA/AMD/Apple GPUs, step-by-step guides for adding new GPU architectures, and practical methods for obtaining data layout strings.

* [Issue #5492](https://github.com/modular/modular/issues/5492): Fixed `FileHandle` "rw" mode unexpectedly truncating file contents.
  Opening a file with `open(path, "rw")` now correctly preserves existing file content and allows both reading and writing, similar to Python's "r+" mode. Previously, "rw" mode would immediately truncate the file, making it impossible to read existing content and causing potential data loss.

* [Issue #3849](https://github.com/modular/mojo/issues/3849): Added support for append mode ("a") when opening files. The `open()` function now accepts "a" as a valid mode, which opens a file for appending. Content written to a file opened in append mode is added to the end of the file without truncating existing content. If the file doesn't exist, it will be created.

* [Issue #3208](https://github.com/modular/mojo/issues/3208): Fixed `FileHandle` raising an "unable to remove existing file" error when opening a FIFO (named pipe) in write mode. Opening special files like FIFOs, devices, and sockets with `open(path, "w")` now works correctly. Previously, write mode would attempt to remove the existing file before opening it, which failed for special files that should not be removed.

* [Issue #5142](https://github.com/modular/modular/issues/5142): The `sys.intrinsics.compressed_store` function now includes a `debug_assert` to catch null pointer usage, providing a clear error message instead of crashing with a segmentation fault.

* The `sys.intrinsics.strided_load()`, `sys.intrinsics.strided_store()`, `sys.intrinsics.masked_load()`, and `sys.intrinsics.masked_store()` functions now include a `debug_assert()` to catch null pointer usage, providing a clear error message instead of crashing with a segmentation fault.

* The `logger` package now prints its levels in color.

* Throwing `deinit` methods now understand that `self` is deinitialized in error paths, avoiding redundant calls to implicit destructors and improving linear type support.
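The file-mode fixes above can be sketched together in a minimal example (`log.txt` is a hypothetical path used only for illustration):

```mojo
fn main() raises:
    # "a" appends without truncating; the file is created if it doesn't exist.
    with open("log.txt", "a") as f:
        f.write("new entry\n")

    # "rw" now behaves like Python's "r+": read and write, no truncation.
    with open("log.txt", "rw") as f:
        print(f.read())
```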
### Special thanks {#25-7-special-thanks}

Special thanks to our community contributors: Brian Grenier ([@bgreni](https://github.com/bgreni)), c-pozzi ([@c-pozzi](https://github.com/c-pozzi)), Christoph Schlumpf ([@christoph-schlumpf](https://github.com/christoph-schlumpf)), cudawarped ([@cudawarped](https://github.com/cudawarped)), David Faden (revfad.com) ([@fadend](https://github.com/fadend)), Ethan Wu ([@YichengDWu](https://github.com/YichengDWu)), Hardik Gupta ([@hardikkgupta](https://github.com/hardikkgupta)), j\_rutzmoser ([@Rutzmoser](https://github.com/Rutzmoser)), Johnny Lin ([@johnny19436](https://github.com/johnny19436)), Jose ([@josetorrs](https://github.com/josetorrs)), josiahls ([@josiahls](https://github.com/josiahls)), Luis Chamberlain ([@mcgrof](https://github.com/mcgrof)), Manuel Saelices ([@msaelices](https://github.com/msaelices)), Marius S ([@winding-lines](https://github.com/winding-lines)), martinvuyk ([@martinvuyk](https://github.com/martinvuyk)), MaxMeyberg ([@MaxMeyberg](https://github.com/MaxMeyberg)), Monal ([@Monal-Patel](https://github.com/Monal-Patel)), skrript ([@skrript](https://github.com/skrript)), soraros ([@soraros](https://github.com/soraros)), and Thomas Mader ([@ThomasMader](https://github.com/ThomasMader)).

## v0.25.6 (2025-09-22)

:::caution Version scheme change!

This release is technically a version *downgrade* because we've added a `0.` at the beginning: `0.25.6`. This is necessary because we started publishing `mojo` packages on pypi.org, and it's important that we don't publish a package greater than 1.0 yet. Shortly after the 0.25.6 release, we "yanked" all Conda packages that are greater than 1.0 so they won't be installed unless you explicitly specify the version.
So you'll still get the latest version like this:

```sh
pixi add mojo
```

And if you want an older version, just specify the version like this:

```sh
pixi add "mojo==25.5"
```

However, if you're installing `mojo` as a Python package with `pip` or `uv`, the oldest version available is `0.25.6`.

:::

### ✨ Highlights {#25-6-highlights}

* You can now **`pip install mojo`**! Although we still love the environment reliability of [Pixi](/pixi), installing Mojo with `pip` or `uv` further enhances our interoperability with the Python ecosystem, making it easier to [extend your Python code with Mojo](/mojo/manual/python/mojo-from-python). For more information, see the [Mojo install guide](/mojo/manual/install).

* We've released a new version of the Mojo extension for Visual Studio Code that now works with **both** the stable and nightly Mojo releases. You can install the Mojo extension from either the [Visual Studio Code Marketplace](https://marketplace.visualstudio.com/items?itemName=modular-mojotools.vscode-mojo) or the [Open VSX Registry](https://open-vsx.org/extension/modular-mojotools/vscode-mojo).

  The new extension replaces the old stable version, so if you have the stable version installed, you can simply update it to receive the new version. If you have the nightly version of the extension installed, you should uninstall it and install the regular (non-nightly) version. See [Add the VS Code extension](/mojo/manual/install#add-the-vs-code-extension) for more information.

* The new [Mojo vision](/mojo/vision) doc explains our motivations and design decisions for the Mojo language.

* The new [Mojo roadmap](/mojo/roadmap) provides a high-level roadmap for the language across multiple phases.

* Mojo now has support for default trait methods, allowing traits to provide reusable behavior without requiring every conforming struct to re-implement it. Default methods are automatically inherited by conforming structs unless explicitly overridden.
  See [Default method implementations](/mojo/manual/traits#default-method-implementations) in the Mojo Manual for more information.

* Added support for many consumer GPUs, including initial support for Apple silicon GPUs. See [GPU support](#25-6-gpu-support) for details.

* The way copying is modeled in Mojo has been overhauled. The `Copyable` trait has been updated to represent a type that can be *explicitly* copied (using a `copy()` method), and a new `ImplicitlyCopyable` "marker" trait can be used to *opt in* to making a type implicitly copyable as well. **This swaps the default behavior from being implicitly copyable to being only explicitly copyable.** Several standard library traits, types, and functions now require explicit `Copyable` instead of `ImplicitlyCopyable`. See [Standard library changes](#25-6-standard-library-changes) for more information.

* Uncaught exceptions or segmentation faults in Mojo programs can now generate stack traces. This is currently only for CPU-based code. To generate a fully symbolicated stack trace, set the `MOJO_ENABLE_STACK_TRACE_ON_ERROR` environment variable, use `mojo build` with debug info enabled (e.g. `-debug-level=line-tables`), and then run the resulting binary.

* Major standard library improvements include:

  * Making the `SIMD` type conform to `Comparable` and `EqualityComparable`, which means that you can use `SIMD` values as `Dict` keys, among other things. For details, see [SIMD and related types](#25-6-simd-and-related-types).
  * A new `Some[Trait]` utility to make it easier to declare an argument that conforms to a trait. For details, see [Other standard library changes](#25-6-other-standard-library-changes).
  * Several enhancements to how iterators work in Mojo, including a new `Iterable` trait. For more information, see [Collections and iterators changes](#25-6-collections-and-iterators-changes).
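As a small sketch of the `SIMD` comparability highlight above (illustrative only; the key type and values are arbitrary):

```mojo
fn main() raises:
    # SIMD now conforms to Comparable and EqualityComparable,
    # so SIMD values can serve as Dict keys.
    var d = Dict[SIMD[DType.int32, 4], String]()
    var key = SIMD[DType.int32, 4](1, 2, 3, 4)
    d[key] = "vector one"
    print(d[key])
```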
### Language enhancements {#25-6-language-enhancements}

* Mojo now allows the use of keywords in function names (after `def` and `fn`) and in attribute references after a `.`. This notably allows the use of the `match()` method in regex libraries even though Mojo takes this as a hard keyword. Uses in other locations can still use backticks:

  ```mojo
  struct MatchExample:
      fn match(self):  # This is ok now.
          pass

  fn test_match(a: MatchExample):
      a.match()    # This is ok now.
      a.`match`()  # This is still valid.
  ```

* When generating error messages for complex types involving parameter calls, the Mojo compiler now prints function parameter values correctly, eliminating a large class of `T != T` errors that happen with GPU layouts.

### Language changes {#25-6-language-changes}

* Methods on structs may now declare their `self` argument with a `deinit` argument convention. This argument convention is used for methods like `__del__()` and `__moveinit__()` to indicate that they tear down the corresponding value without needing its destructor to be run again. Beyond these two methods, this convention can be used to declare "named" destructors, which are methods that consume and destroy the value without themselves running the value's destructor. For example, the standard [`VariadicPack`](/mojo/std/builtin/variadics/VariadicPack/) type has these methods:

  ```mojo
  struct VariadicPack[...]:
      # implicit destructor
      fn __del__(deinit self): ...

      # move constructor
      fn __moveinit__(out self, deinit existing: Self): ...

      # custom explicit destructor that destroys "self" by transferring all of
      # the stored elements.
      fn consume_elements[
          elt_handler: fn (idx: Int, var elt: element_type) capturing
      ](deinit self): ...
  ```

  This argument convention is a fairly narrow power-user feature that is important to clarify the destruction model and make linear types fit into the model better.
  (A linear type is just a type that has no `__del__()` method, but provides a destructor that the user must call explicitly. Linear types are a proposed feature that hasn't been fully implemented yet.)

* The `__del__()` and `__moveinit__()` methods should now take their `self` and `existing` arguments as `deinit` instead of either `owned` or `var`.

* The `__disable_del` keyword and statement has been removed; use `deinit` methods instead.

* The Mojo compiler now warns about use of the deprecated `owned` keyword; please move to `var` or `deinit` as the warning indicates.

* The previously deprecated `@value` decorator has been removed.

* Accesses to associated aliases and methods within a trait now require qualified references (prepended with `Self.`), making it consistent with how accesses to member aliases and methods in a struct require `self.`.

* The Mojo compiler now raises an error on implicit materialization of a non-`ImplicitlyCopyable` object; please either mark the type as `ImplicitlyCopyable` or use `materialize[value: T]()` to explicitly materialize the parameter into a dynamic value. This usually happens when you generate a compile-time value using the `alias` keyword, then assign it to a runtime variable:

  ```mojo
  alias lst = [1, 2, 3, 4]  # Create a compile-time list value
  var dyn_list = lst  # Implicitly materializes the compile-time value to a
                      # dynamically-allocated runtime value
  ```

  The alias is a compile-time temporary value; to assign it to a runtime variable, Mojo must allocate memory and copy the value. Since this can result in unexpected memory allocations when materializing a memory value like a list, Mojo only allows implicit materialization of memory values if the type is `ImplicitlyCopyable`, which is a signal that the type should be inexpensive to copy.
  You can use the `materialize()` function to explicitly materialize a value:

  ```mojo
  var dyn_list = materialize[lst]()
  ```

### Standard library changes {#25-6-standard-library-changes}

#### Copyability changes

* The way copying is modeled in Mojo has been overhauled.

  Previously, Mojo had two traits for modeling copyability:

  * `Copyable` denoted a type that could be copied implicitly
  * `ExplicitlyCopyable` denoted a type that could only be copied with an explicit call to a `.copy()` method.

  The vast majority of types defaulted to implementing `Copyable` (and therefore were implicitly copyable), and `ExplicitlyCopyable` was partially phased in but had significant usage limitations.

  Now, the new `Copyable` trait instead represents a type that can be *explicitly* copied (using `.copy()`), and a new `ImplicitlyCopyable` "marker" trait can be used to *opt in* to making a type implicitly copyable as well. This swaps the default behavior from being implicitly copyable to being only explicitly copyable.

  The new `ImplicitlyCopyable` trait inherits from `Copyable` and requires no additional methods. `ImplicitlyCopyable` is known specially to the compiler. (`ImplicitlyCopyable` types may also be copied explicitly using `.copy()`.)

  This makes it possible for non-implicitly-copyable types to be used with all standard library functionality, resolving a long-standing issue with Mojo effectively forcing implicit copyability upon all types. This will enable Mojo programs to be more efficient and readable, with fewer performance and correctness issues caused by accidental implicit copies.
With this change, types that conform to `Copyable` are no longer implicitly copyable: ```mojo @fieldwise_init struct Person(Copyable): var name: String fn main(): var p = Person("Connor") var p2 = p # ERROR: not implicitly copyable var p3 = p.copy() # OK: may be copied explicitly ``` To enable a type to be implicitly copyable, declare a conformance to the `ImplicitlyCopyable` marker trait: ```mojo @fieldwise_init struct Point(ImplicitlyCopyable): var x: Float32 var y: Float32 fn main(): var p = Point(5, 10) var p2 = p # OK: may be implicitly copied var p3 = p.copy() # OK: may be explicitly copied ``` An additional nuance is that `ImplicitlyCopyable` may only be synthesized for types whose fields are all themselves `ImplicitlyCopyable` (and not merely `Copyable`). If you need to make a type with any non-`ImplicitlyCopyable` fields support implicit copying, you can declare the conformance to `ImplicitlyCopyable`, but write the `__copyinit__()` definition manually: ```mojo struct Container(ImplicitlyCopyable): var x: SomeCopyableType var y: SomeImplicitlyCopyableType fn __copyinit__(out self, existing: Self): self.x = existing.x.copy() # Copy field explicitly self.y = existing.y ``` For more information on copyability, see the section on [copy constructors](/mojo/manual/lifecycle/life#copy-constructor) in the Mojo manual. 
* The following standard library types and functions now require only explicit `Copyable` for their element and argument types, enabling their use with types that are not implicitly copyable: [`List`](/mojo/std/collections/list/List), [`Span`](/mojo/std/memory/span/Span/), [`InlineArray`](/mojo/std/collections/inline_array/InlineArray), [`Optional`](/mojo/std/collections/optional/Optional), [`Variant`](/mojo/std/utils/variant/Variant), [`Tuple`](/mojo/std/builtin/tuple/Tuple), [`Dict`](/mojo/std/collections/dict/Dict), [`Set`](/mojo/std/collections/set/Set), [`Counter`](/mojo/std/collections/counter/Counter), [`LinkedList`](/mojo/std/collections/linked_list/LinkedList), [`Deque`](/mojo/std/collections/deque/Deque), and [`reversed()`](/mojo/std/builtin/reversed/reversed). Additionally, the following traits now require explicit `Copyable` instead of `ImplicitlyCopyable`: `KeyElement`, `IntervalElement`, `ConvertibleFromPython`.
* The following Mojo standard library types are no longer implicitly copyable: [`List`](/mojo/std/collections/list/List), [`Dict`](/mojo/std/collections/dict/Dict), [`DictEntry`](/mojo/std/collections/dict/DictEntry), [`OwnedKwargsDict`](/mojo/std/collections/dict/OwnedKwargsDict), [`Set`](/mojo/std/collections/set/Set), [`LinkedList`](/mojo/std/collections/linked_list/LinkedList), [`Node`](/mojo/std/collections/linked_list/Node), [`Counter`](/mojo/std/collections/counter/Counter/), [`CountTuple`](/mojo/std/collections/counter/CountTuple/), [`BitSet`](/mojo/std/collections/bitset/BitSet/), [`UnsafeMaybeUninitialized`](/mojo/std/memory/maybe_uninitialized/UnsafeMaybeUninitialized/), `DLHandle`, [`BenchConfig`](/mojo/std/benchmark/bencher/BenchConfig), [`BenchmarkInfo`](/mojo/std/benchmark/bencher/BenchmarkInfo), [`Report`](/mojo/std/benchmark/benchmark/Report), and [`PythonTypeBuilder`](/mojo/std/python/bindings/PythonTypeBuilder/).
To create a copy of one of these types, call the `.copy()` method explicitly:

```mojo
var l = List[Int](1, 2, 3)

# ERROR: Implicit copying of `List` is no longer supported:
# var l2 = l

# Instead, perform an explicit copy:
var l2 = l.copy()
```

Alternatively, to transfer ownership, [use the `^` transfer sigil](/mojo/manual/values/ownership#transfer-arguments-var-and-):

```mojo
var l = List[Int](1, 2, 3)
var l2 = l^  # `l` is no longer accessible.
```

* Because `List` and `Dict` are so widely used, this release stages in this change by making implicit copies of these types a warning instead of an error. This will become a hard error in the next release of Mojo.
* User types that define a custom `.copy()` method must be updated to move that logic to `__copyinit__()`. The `.copy()` method is now provided by a default trait implementation on `Copyable` that should not be overridden:

  ```mojo
  trait Copyable:
      fn __copyinit__(out self, existing: Self, /):
          ...

      fn copy(self) -> Self:
          return Self.__copyinit__(self)
  ```

#### SIMD and related types {#25-6-simd-and-related-types}

* The comparison operators (e.g. `__eq__()` and `__le__()`) of the `SIMD` type now return a single `Bool` instead of a boolean `SIMD` mask. Moreover, `SIMD` now has explicit element-wise comparisons that return boolean masks (for example, `eq()` and `le()`).
  * This allows `SIMD` to conform to the `EqualityComparable` trait, enabling the use of `SIMD` vectors in sets, as keys to dictionaries, generic search algorithms, etc. Moreover, `Scalar` now conforms to the `Comparable` trait, i.e. `SIMD` conforms to `Comparable` when the size is 1.
  * As a consequence, `SIMD.__bool__()` no longer needs to be restricted to scalars, and instead performs an `any` reduction on the elements of vectors.
* Non-scalar `SIMD` constructors no longer allow implicit splatting of `Bool` values.
This could lead to subtle bugs that cannot be caught at compile time, for example:

```mojo
fn foo[w: Int](v: SIMD[_, w]) -> SIMD[DType.bool, w]:
    return v == 42  # this silently reduced to a single bool, and then splat
```

Similar to `InlineArray`, an explicit constructor with the `fill` keyword-only argument can be used to express the same logic more safely:

```mojo
fn foo[w: Int](v: SIMD[_, w]) -> SIMD[DType.bool, w]:
    return SIMD[DType.bool, w](fill=(v == 42))  # highlights the splat logic

fn bar(v: Scalar[_]) -> Scalar[DType.bool]:
    # still works, since implicit splatting to a scalar is never ambiguous
    return v == 42
```

* Several types that wrap MLIR types have been changed to further encapsulate their behavior, hiding this low-level detail from non-advanced users.
  * Types that can be constructed from raw MLIR values now require the use of an `mlir_value` keyword-only argument initializer. Affected types include `SIMD` and `UInt`.
  * Types with raw MLIR type fields have had their `value` fields renamed to `_mlir_value`. Affected types include `Bool` and `DType`.
* The `SIMD.from_bits()` factory method is now a constructor; use `SIMD(from_bits=...)` instead.
* `DType.index` is deprecated in favor of the new `DType.int`. Moreover, a new `DType.uint` is added, modeling unsigned integers with widths that match the machine word length.
* The `index()` free function now returns an `Int`, instead of a raw MLIR `__mlir_type.index` value.

#### Path and file system changes

* Added the [`parts()`](/mojo/std/pathlib/path/Path/#parts) method to the `Path` type. For example, instead of writing:

  ```mojo
  var path = Path("path/to/file")
  var parts = path.path.split(DIR_SEPARATOR)
  ```

  you can now write:

  ```mojo
  var path = Path("path/to/file")
  var parts = path.parts()
  ```

* Added the [`name()`](/mojo/std/pathlib/path/Path/#name) method to the `Path` type, which returns the name of the file or directory.
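Returning to the SIMD comparison changes above, here is a minimal sketch of the two comparison forms. The lane values are arbitrary, and the exact reduction semantics of `==` are assumed here rather than spelled out in the notes:

```mojo
fn main():
    var a = SIMD[DType.int32, 4](1, 2, 3, 4)
    var b = SIMD[DType.int32, 4](1, 0, 3, 0)

    # `==` now yields a single Bool rather than a SIMD mask
    # (assumed here to be True only when every lane matches).
    var all_equal: Bool = a == b

    # The explicit element-wise form returns a per-lane boolean mask.
    var mask: SIMD[DType.bool, 4] = a.eq(b)

    print(all_equal)
    print(mask)
```

Because `==` returns a plain `Bool`, `SIMD` values can now participate directly in generic code constrained on `EqualityComparable`, while the `eq()`/`le()` family preserves the old mask-producing behavior for vectorized code.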
* Added `os.path.realpath()` to resolve symbolic links to an absolute path and remove relative path components (`.`, `..`, etc.). Behaves the same as the Python equivalent function. #### Collections and iterators changes {#25-6-collections-and-iterators-changes} * Added an [`iter`](/mojo/std/iter) module which includes the new `Iterable` trait (and the existing `Iterator` trait). The module also provides the [`enumerate()`](/mojo/std/iter/enumerate), [`iter()`](/mojo/std/iter/iter), [`map()`](/mojo/std/iter/map), [`next()`](/mojo/std/iter/next), and [`zip()`](/mojo/std/iter/zip) functions. The new generic `zip()` function replaces the `IntTuple`-specific `zip()` function provided in previous releases. * `Iterable`'s `origin` parameter is now named `iterable_origin` and its `mut` param is now named `iterator_mut` to avoid naming collisions. * [`Iterator`](/mojo/std/iter/Iterator) now has a defaulted `bounds()` trait method. This returns the lower and upper bounds of the remaining length of the iterator. This can be used for preallocation when building or extending collections from iterators. * Added [`take_items()`](/mojo/std/collections/dict/Dict/#take_items) draining iterator to `Dict`. * [`Span`](/mojo/std/memory/span/Span/) is now `Representable` if its elements implement the `Representable` trait. * Add `repr()` support for [`List`](/mojo/std/collections/list/List), [`Deque`](/mojo/std/collections/deque/Deque), [`Dict`](/mojo/std/collections/dict/Dict), [`LinkedList`](/mojo/std/collections/linked_list/LinkedList), [`Optional`](/mojo/std/collections/optional/Optional), and [`Set`](/mojo/std/collections/set/Set). * [`InlineArray`](/mojo/std/collections/inline_array/InlineArray) now automatically detects whether its element types are trivially destructible to not invoke the destructors in its `__del__()` function. This improves performance for trivially destructible types (such as `Int` and friends). 
* Similar to the above, `List` now automatically applies optimizations based on whether its element types are trivial (copyable, destructible, etc.). There is no longer a `hint_trivial_type` parameter, as this is done automatically now.
* `String.splitlines()` now returns a `List[StringSlice]` instead of a `List[String]`. This avoids unnecessary intermediate allocations.
* The `StringSlice.from_utf8()` factory method is now a constructor; use `StringSlice(from_utf8=...)` instead.
* `Span` now implements a generic `.count()` method which can be passed a function that returns a boolean `SIMD` vector. The method counts how many elements the function evaluates to `True`, processing the span in a vectorized manner. This works for any `Span[Scalar[D]]`, e.g. `Span[Byte]`.
* [`Optional`](/mojo/std/collections/optional/Optional) and [`OptionalReg`](/mojo/std/collections/optional/OptionalReg) can now be composed with `Bool` in expressions, both at comptime and runtime:

  ```mojo
  alias value = Optional[Int](42)

  @parameter
  if CompilationTarget.is_macos() and value:
      print("is macos and value is:", value.value())
  ```

* [`sort()`](/mojo/std/builtin/sort/sort), [`LinkedList.pop()`](/mojo/std/collections/linked_list/LinkedList/#pop), [`LinkedList.maybe_pop()`](/mojo/std/collections/linked_list/LinkedList/#maybe_pop), and [`Dict.popitem()`](/mojo/std/collections/dict/Dict#popitem) no longer copy elements, improving performance.
* The `IndexList` and `DimList` types may no longer be implicitly constructed from tuple values. Most existing call sites already used explicit initializer calls (`IndexList(...)`), so implicit conversions have been removed to ensure uniformity and consistency.

#### Pointer changes

* Removed the `alignment` parameter and `static_alignment_cast()` method from [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer). This parameter was only used by the `alloc()` static method.
* Added an `alignment` keyword argument to `UnsafePointer.alloc()`.
Use this in place of the `alignment` parameter on the struct.

* Removed the `alignment` parameter from `Span`, similar to `UnsafePointer` above.
* The `UnsafePointer.init_pointee_explicit_copy()` method has been removed. Please use [`UnsafePointer.init_pointee_copy()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#init_pointee_copy) instead.
* `UnsafePointer.move_pointee_into()` has been deprecated. Please use [`UnsafePointer.init_pointee_move_from()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#init_pointee_move_from). `src.move_pointee_into(dst)` used a reversed argument order from the intuitive `LHS = RHS` semantics for assignment, effectively `RHS -> LHS`. The new function fixes this readability issue (`dst.init_pointee_move_from(src)`), and additionally follows the `init_pointee_*()` naming convention of the other existing methods for initializing a pointee memory location.

#### Atomic operations

* Added [`os.atomic.fence()`](/mojo/std/os/atomic/fence) for creating atomic memory fences.

  ```mojo
  from os.atomic import Atomic, Consistency, fence

  fn decrease_ref_count(ref_count: Atomic[DType.uint64]):
      if ref_count.fetch_sub[ordering = Consistency.MONOTONIC](1) == 1:
          fence[Consistency.ACQUIRE]()
      # ...
  ```

* Added a memory ordering parameter to [`Atomic.load()`](/mojo/std/os/atomic/Atomic).
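Putting the pointer changes above together, here is a small sketch of the new `alloc()` keyword argument and the renamed move method. The element type and sizes are arbitrary illustration choices:

```mojo
from memory import UnsafePointer

fn main():
    # `alignment` is now a keyword argument to `alloc()` rather than a
    # parameter on the `UnsafePointer` type itself.
    var src = UnsafePointer[String].alloc(1, alignment=64)
    src.init_pointee_copy(String("hello"))

    var dst = UnsafePointer[String].alloc(1)
    # Destination on the left, source on the right, matching `LHS = RHS`:
    # this replaces the deprecated `src.move_pointee_into(dst)`.
    dst.init_pointee_move_from(src)

    print(dst[])
    dst.destroy_pointee()
    src.free()
    dst.free()
```

The argument-order flip is the whole point of the rename: the initialized destination now reads first, like an assignment.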
#### System changes

* Added `sys.info.platform_map()` for specifying values that can differ depending on the platform:

  ```mojo
  from sys.info import platform_map

  alias EDEADLK = platform_map["EDEADLK", linux = 35, macos = 11]()
  ```

* Renamed the following functions in `sys.info` from `flatcase` to `snake_case` names:

  | Old name          | New name            |
  | ----------------- | ------------------- |
  | `alignof()`       | `align_of()`        |
  | `bitwidthof()`    | `bit_width_of()`    |
  | `simdbitwidth()`  | `simd_bit_width()`  |
  | `simdbytewidth()` | `simd_byte_width()` |
  | `simdwidthof()`   | `simd_width_of()`   |
  | `sizeof()`        | `size_of()`         |

  The old names are deprecated, and will be removed in the next release.

#### GPU Support {#25-6-gpu-support}

* Added initial support for programming Apple Silicon GPUs in Mojo. However, MAX graphs are not yet enabled on Apple Silicon GPUs, and many hardware features remain to be enabled.
* Added support for the AMD RX 6900 XT consumer-grade GPU.
* Added support for AMD RDNA3.5 consumer-grade GPUs in the `gfx1150`, `gfx1151`, and `gfx1152` architectures. Representative configurations have been added for AMD Radeon 860M, 880M, and 8060S GPUs.
* Added support for NVIDIA GTX 1080 Ti consumer-grade GPUs.
* Added support for NVIDIA Tesla P100 datacenter GPUs.
* Updated `layout_tensor()` copy-related functions to support 2D and 3D threadblock dimensions.

#### Other standard library changes {#25-6-other-standard-library-changes}

* A new `Some` utility is introduced to reduce the syntactic load of declaring function arguments of a type that implements a given trait or trait composition. For example, instead of writing:

  ```mojo
  fn foo[T: Intable, //](x: T) -> Int:
      return x.__int__()
  ```

  one can now write:

  ```mojo
  fn foo(x: Some[Intable]) -> Int:
      return x.__int__()
  ```

* The [`compile.reflection.get_type_name()`](/mojo/std/compile/reflection/get_type_name) utility now has limited capability to print parametric types, e.g.
`SIMD[DType.float32, 4]` instead of just `SIMD`. If a parameter is not printable, a placeholder is printed instead. A new `qualified_builtins` flag also allows users to control the verbosity for the most common (but not all) builtin types.

### Kernels changes {#25-6-kernels-changes}

* A fast matmul for SM100 is available in Mojo. Please check it out in [`matmul_sm100.mojo`](https://github.com/modular/modular/commits/main/max/kernels/src/linalg/matmul_sm100.mojo).
* Moved the [`comm`](https://github.com/modular/modular/tree/main/max/kernels/src/comm) module from the standard library (`gpu.comm`) to the MAX AI kernels library. Any imports that used `gpu.comm` should be updated to use `comm` instead.

### Tooling changes {#25-6-tooling-changes}

* `mojo test` now ignores folders with a leading `.` in the name. This excludes hidden folders on Unix systems.
* `mojo doc --validate-doc-strings` now emits a warning when an `fn` function is declared to raise an error (`raises`) and it has no [`Raises` docstring](https://github.com/modular/modular/blob/main/mojo/stdlib/docs/docstring-style-guide.md#errors). However, because Mojo automatically treats all `def` functions as [raising functions](/mojo/manual/functions#raising-and-non-raising-functions), we do not enforce `Raises` docs for `def` functions (to avoid noisy false positives).
* Nightly `mojo` Python wheels are now available. To install everything needed for Mojo development in a Python virtual environment, you can use:

  ```sh
  pip install --pre mojo \
    --index-url https://dl.modular.com/public/nightly/python/simple/
  ```

  For more information, see the [Mojo install guide](/mojo/manual/install).

* In preparation for a future Mojo 1.0, the `mojo` and `mojo-compiler` packages have a `0.` prefixed to the version.

### ❌ Removed {#25-6-removed}

* The Mojo MLIR C bindings have been removed. This was a private package that was used for early experimentation.
### 🛠️ Fixed {#25-6-fixed}

* [#4695](https://github.com/modular/modular/issues/4695) - `Dict.__getitem__()` always returns immutable references.
* [#4705](https://github.com/modular/modular/issues/4705) - Wrong mutability inferred for `__getitem__()` if `[]` operator is used and `__setitem__()` is present.
* [#5190](https://github.com/modular/modular/issues/5190) - Mojo compiler crashes for a struct with two constructors taking different keyword-only arguments.
* [#5139](https://github.com/modular/modular/issues/5139) - Crash on malformed initializer.
* [#5183](https://github.com/modular/modular/issues/5183) - `Log1p()` not working on GPUs.
* [#5105](https://github.com/modular/modular/issues/5105) - Outdated `CLAUDE.md` docs.
* [#5239](https://github.com/modular/modular/issues/5239) - Contextual type not detected inside an inline if-else.
* [#5305](https://github.com/modular/modular/issues/5305) - Parser segfaults on `LayoutTensor[layout]` with no `layout` in scope.
* [#5260](https://github.com/modular/modular/issues/5260) - Undefined reference to `clock_gettime_nsec_np` when building with `-O0`.
* [#5307](https://github.com/modular/modular/issues/5307) - Bad error message when getting GPU info for an unsupported GPU.
* Error messages involving types using implicit parameters from auto-parameterized types now include context information to solve a class of incorrect "T != T" error messages common in kernel code.
* Parameter inference failures now refer to parameters by their user-provided name, rather than complaining about a mysterious "parameter #4".
### Special thanks {#25-6-special-thanks}

Special thanks to our community contributors: [@AceMouse](https://github.com/AceMouse), [@Alex-Mann](https://github.com/Alex-Mann), [@christoph-schlumpf](https://github.com/christoph-schlumpf), [@cudawarped](https://github.com/cudawarped), [@cyrillzadra](https://github.com/cyrillzadra), [@dl-alexandre](https://github.com/dl-alexandre), [@farnoy](https://github.com/farnoy), [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse), [@gryznar](https://github.com/gryznar), [@josiahls](https://github.com/josiahls), [@kyoto7250](https://github.com/kyoto7250), [@martinvuyk](https://github.com/martinvuyk), [@mmicu](https://github.com/mmicu), [@msaelices](https://github.com/msaelices), [@mzaks](https://github.com/mzaks), [@rd4com](https://github.com/rd4com), [@Rtosshy](https://github.com/Rtosshy), [@SasankYadati](https://github.com/SasankYadati), [@simonyjung](https://github.com/simonyjung), [@soraros](https://github.com/soraros), and [@ThomasMader](https://github.com/ThomasMader).

## v25.5 (2025-08-05)

### ✨ Highlights {#25-5-highlights}

* Mojo is now available independently as the `mojo` Conda package. It includes the Mojo compiler, standard library, and the `layout` package (which is heavily used in GPU programming). It also includes the Mojo developer tools: LSP, debugger, formatter, and so on. To use Python-to-Mojo interoperability in v25.5, you must install the `modular` package. This will move to the `mojo` package in a future release. For more details, see the [install guide](/mojo/manual/install).
* Parametric aliases are now supported: aliases can be specified with an optional parameter list (just like functions). Parametric aliases are considered first-class parameter values, too. For more details, see [Parametric aliases](/mojo/manual/parameters/#parametric-aliases) in the Mojo Manual.
* Mojo API documentation now generates cross-references for parameter, argument, and return value types.
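As an illustration of the parametric alias feature mentioned above, an alias can now take a parameter list and forward it to the aliased type. The alias name here is hypothetical:

```mojo
# A parametric alias takes an optional parameter list, like a function.
alias Vec4[dtype: DType] = SIMD[dtype, 4]

fn main():
    # The alias's parameter is bound at the use site.
    var v = Vec4[DType.float32](1, 2, 3, 4)
    print(v)
```

This is a sketch of the syntax described in the notes; see the linked manual section for the full rules on where parametric aliases may appear.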
### Language enhancements {#25-5-language-enhancements} * `@parameter for` now works on a broader range of collection types, enabling things like `@parameter for i in [1, 2, 3]: ...`. * `StringLiteral` now automatically materializes to a `String` when used at runtime: ```mojo alias param = "foo" # type = StringLiteral var runtime_value = "bar" # type = String var runtime_value2 = param # type = String ``` This enables all the behavior users expect without having to convert or annotate types, for example: ```mojo var string = "hello" string += " world" var if_result = "foo" if True else "bar" ``` Initializing a `String` from a `StringLiteral` initially points to static constant memory, and does not perform any allocation until the first mutation. * The compiler now detects attempts to materialize references to compile-time interpreter stack memory into runtime code. This includes related types that reference memory, like slices, spans, and pointers. The compiler cannot currently track the lifetime of internal stack objects when materialized to runtime, which could cause memory leaks. Consider this example: ```mojo fn test_comptime_materialize(): # This is ok! Forms a comptime pointer to a comptime "stack" value of # String type. alias comptime_ptr = String("foo" + "bar").unsafe_ptr() # This is ok too, dereferences the pointer at comptime, loading the byte. alias byte = comptime_ptr[] # This materializes a Byte from comptime to runtime. var rt_byte = byte # Error: cannot materialize to runtime value, the type contains an origin # referring to a compile-time value var bad_usage = comptime_ptr ``` Previously the compiler would materialize the memory representation of the `String` value but not know it needs to be destroyed. It now detects the problem. If you run into this, rework the code to materialize the full object (e.g. 
the String) to runtime explicitly:

```mojo
alias comptime_string = String("foo" + "bar")
var runtime_string = comptime_string
```

### Language changes {#25-5-language-changes}

* The `@value` decorator has been formally deprecated with a warning; it will be removed in the next release of Mojo. Please move to the [`@fieldwise_init`](/mojo/manual/decorators/fieldwise-init/) decorator and synthesized `Copyable` and `Movable` trait conformances.
* Implicit trait conformance is removed. All conformances must be explicitly declared.
* The `owned` argument convention is being renamed to `var`. This reflects that `var` is already used consistently for a "named, scoped, owned value", which is exactly what the `owned` convention provides. In this release, both `var` and `owned` are allowed in an argument list, but `owned` will be removed in a subsequent release, so please move your code over.
* Function overloading is now fully supported for two function signatures with the same list of argument types, as long as one position is keyword-only in at least one signature and the argument names at that position differ. Previously, an edge case prevented this support when the return types were different. For example, these two functions can now co-exist:

  ```mojo
  fn get(self, idx: Int) -> Int
  fn get(self, *, idx2: Int) -> Float32
  ```

### Standard library changes {#25-5-standard-library-changes}

* Indexing into a [`String`](/mojo/std/collections/string/string/String/) now returns a [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice), avoiding an allocation. `String.split()` now returns a `List[StringSlice]`.
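A small sketch of the allocation-free indexing behavior just described (the string contents are arbitrary):

```mojo
def main():
    var s = String("hello world")

    # Indexing now yields a StringSlice that borrows from `s`,
    # rather than allocating a new String.
    var first = s[0]

    # split() likewise returns a List[StringSlice] over `s`.
    var words = s.split(" ")

    print(first, len(words))
```

Because the slices borrow from the original `String`, code that needs an independent, owned copy should construct a `String` from the slice explicitly.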
* Added support for a wider range of consumer-grade AMD hardware, including:
  * AMD Radeon RX 7xxx GPUs
  * AMD Radeon RX 9xxx GPUs
* Compile-time checks for AMD RDNA3+ GPUs are now provided by the following functions (which can be imported from `sys.info`):
  * `_is_amd_rdna3()`
  * `_is_amd_rdna4()`
  * `_is_amd_rdna()`
* Added WMMA matrix-multiplication instructions for RDNA3+ GPUs to help support running AI models on those GPUs.
* Added support for the NVIDIA GeForce RTX 3090.
* [`memory.UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer/) is now implicitly included in all Mojo files. Moreover, [`OpaquePointer`](/mojo/std/memory/unsafe_pointer/#opaquepointer) (the equivalent of a `void*` in C) has been moved into the `memory` module, and is also implicitly included.
* Python interop changes:
  * Mojo methods can now take `py_self: UnsafePointer[Self]` instead of the raw `py_self: PythonObject`, eliminating the downcasting boilerplate required in the common case.
  * Mojo functions can now natively accept keyword arguments from Python using `OwnedKwargsDict[PythonObject]` as the last argument. This enables direct calling from Python with keyword arguments without requiring wrapper functions.

    ```mojo
    from collections import OwnedKwargsDict

    # Callable from Python as `foo(10, y=20)`
    fn foo(x: PythonObject, kwargs: OwnedKwargsDict[PythonObject]):
        y = kwargs["y"]
    ```

  * The [`PythonTypeBuilder`](/mojo/std/python/bindings/PythonTypeBuilder/) utility now allows:
    * Registering bindings for Python static methods, i.e. methods that don't require an instance of the class.
    * Registering initializers that take arguments. Types no longer need to be `Defaultable` to be exposed and created from Python.
  * The `PythonConvertible` trait has been renamed to [`ConvertibleToPython`](/mojo/std/python/conversions/ConvertibleToPython/).
This is now consistent with the [`ConvertibleFromPython`](/mojo/std/python/python_object/ConvertibleFromPython) trait, modeling Mojo types that can be converted either to or from a `PythonObject`. For more information, see [Calling Mojo from Python](/mojo/manual/python/mojo-from-python) in the Mojo Manual.

* Added the [`Iterator`](/mojo/std/iter/Iterator/) trait for modeling types that produce a sequence of values. A type can implement `Iterator` by providing `__next__()` and `__has_next__()` methods. This naming and behavior is based on the Python [`Iterator`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterator) type annotation, diverging slightly due to constraints present in Mojo today. Any type that implements `Iterator` can be used within `for` and `@parameter for` looping syntax. `Iterator` does not currently have a variant for supporting iteration over borrowed `ref` values.
* The [`Dict`](/mojo/std/collections/dict/Dict/) type now has an `H` parameter which allows users to provide a custom `Hasher` type.
  * `default_hasher` (AHasher) and `default_comp_time_hasher` (Fnv1a) are now provided.
  * The `H` parameter of `Dict` defaults to `default_hasher`.
* The [`Hashable`](/mojo/std/hashlib/hash/Hashable) trait has been updated to use a new data flow strategy.
  * Users are now required to implement the method `fn __hash__[H: Hasher](self, mut hasher: H):` (see the `Hashable` API documentation for further details).
* `InlineArray` can now be constructed with a size of 0. This makes it easier to use `InlineArray` in situations where the number of elements is generic and could also be 0.
* `List.append(Span)` has been renamed to `List.extend(Span)`. It is important for readability and consistency that `append()` always grows the length of the list by exactly 1. `extend()` in both Python and Rust is the variant of this operation that takes an arbitrary-length number of additional elements (possibly 0) to add to the list.
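As a sketch of the `Iterator` trait in use, the following hypothetical struct yields a countdown of integers. The struct name, its other trait conformances, and the `Element` associated alias are illustrative assumptions; consult the `Iterator` API docs for the exact conformance requirements:

```mojo
@fieldwise_init
struct Countdown(Copyable, Movable, Iterator):
    alias Element = Int
    var remaining: Int

    fn __has_next__(self) -> Bool:
        return self.remaining > 0

    fn __next__(mut self) -> Int:
        self.remaining -= 1
        return self.remaining + 1

fn main():
    # Any type implementing Iterator works directly with `for` syntax.
    for value in Countdown(3):
        print(value)
```

The `for` loop calls `__has_next__()` before each step and `__next__()` to produce the value, mirroring Python's iterator protocol without relying on exceptions.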
* A new [`io`](/mojo/std/io/) module is available in the library. Some core input/output APIs previously in the `builtin` module have been moved to `io`. Currently all of the APIs in the `io` module are imported automatically. The following APIs were moved to `io`: * File-related APIs such as `open()`, `FileHandle` and `FileDescriptor`. * The `Writer` and `Writable` traits. * `input()` and `print()` functions. * `StringLiteral.strip()` family of functions now return a `StaticString`. ### Tooling changes {#25-5-tooling-changes} * Added support for GCC-style debug flags `-g0`, `-g1`, and `-g2` to match common compiler conventions: * `-g0`: No debug information (alias for `--debug-level=none`). * `-g1`: Line table debug information (alias for `--debug-level=line-tables`). * `-g2`: Full debug information (alias for `--debug-level=full`). * Added progress reporting support to the Mojo language server. This will emit progress notifications in your editor when the server is currently parsing a document. ### ❌ Removed {#25-5-removed} * Various functions from the `sys.info` package have been moved to the [`sys.info.CompilationTarget`](/mojo/std/sys/info/CompilationTarget) struct: * `is_x86()` * `has_sse4()` * `has_avx()` * `has_avx2()` * `has_avx512f()` * `has_fma()` * `has_vnni()` * `has_neon()` * `has_neon_int8_dotprod()` * `has_neon_int8_matmul()` * `UnsafePointer.address_of()` has been removed. Use `UnsafePointer(to=...)` constructor instead. Similarly, `Pointer.address_of()` has been removed. * `DType.tensor_float32` has been removed due to lack of support for it in the library and the compiler. ### 🛠️ Fixed {#25-5-fixed} * [#4121](https://github.com/modular/modular/issues/4121) - better error message for `.value()` on empty `Optional`. * [#4566](https://github.com/modular/modular/issues/4566) - Hang when assigning loop variable inside `@parameter for`. * [#4820](https://github.com/modular/modular/issues/4820) - `math.exp2` picks the wrong implementation for `float64`. 
* [#4836](https://github.com/modular/modular/issues/4836) - Else path in `@parameter for` broken. * [#4499](https://github.com/modular/modular/issues/4499) - Traits with `ref self` cause issues when used as parameter. * [#4911](https://github.com/modular/modular/issues/4911) - `InlineArray` now calls the move constructor for its elements when moved. * [#3927](https://github.com/modular/modular/issues/3927) - `InlineArray` now can be constructed with a size of 0. * [#4954](https://github.com/modular/modular/issues/4954) - `InlineArray` now does not call the copy constructor when being moved. * [#5066](https://github.com/modular/modular/issues/5066) - Correctly fill 64-bit values on AMD in `enqueue_fill`. * [#4982](https://github.com/modular/modular/issues/4982) - Add `toggle_all` to `BitSet`. * [#5086](https://github.com/modular/modular/issues/5086) - Add `set_all` to `BitSet`. * [#5057](https://github.com/modular/modular/issues/5057) - Span Performance Regression. * [#5051](https://github.com/modular/modular/issues/5051) - Incorrect `.modular` Directory Location on Linux. * [#5021](https://github.com/modular/modular/issues/5021) - LSP Crashes in VSCode when a local package exists. * [#5016](https://github.com/modular/modular/issues/5016) - Conditional Conformance Trait Alias Bug. 
### Special thanks {#25-5-special-thanks}

Special thanks to our community contributors: [@zsiegel92](https://github.com/zsiegel92), [@yeison](https://github.com/yeison), [@soraros](https://github.com/soraros), [@samufi](https://github.com/samufi), [@mzaks](https://github.com/mzaks), [@mmicu](https://github.com/mmicu), [@martinvuyk](https://github.com/martinvuyk), [@hardikkgupta](https://github.com/hardikkgupta), [@gustawdaniel](https://github.com/gustawdaniel), [@cyrillzadra](https://github.com/cyrillzadra), [@cnhz95](https://github.com/cnhz95), [@christoph-schlumpf](https://github.com/christoph-schlumpf), [@bgreni](https://github.com/bgreni), [@benz0li](https://github.com/benz0li), [@LeeLee26](https://github.com/LeeLee26), [@Caslyn](https://github.com/Caslyn), [@Amila-Rukshan](https://github.com/Amila-Rukshan), [@Amet13](https://github.com/Amet13), and [@AceMouse](https://github.com/AceMouse).

## v25.4 (2025-06-18)

### ✨ Highlights {#25-4-highlights}

* Mojo now supports AMD GPUs, expanding hardware compatibility beyond NVIDIA to include AMD's GPU ecosystem. This enables Mojo applications to leverage AMD's RDNA and CDNA architectures for high-performance computing workloads, providing developers with greater flexibility in choosing hardware platforms for AI and compute-intensive applications.
* Primitives for working with NVIDIA Blackwell GPUs have been added, providing low-level access to the latest GPU architecture features. These primitives enable developers to take advantage of Blackwell's enhanced compute capabilities, improved memory bandwidth, and advanced AI acceleration features, including support for newer tensor operations and optimized memory management patterns.
* The Python-Mojo bindings are available as a preview release! You can now call Mojo functions from existing Python codebases. The primary use case is to speed up hot spots in slow Python code by rewriting those portions in Mojo for better performance.
* Mojo collection types received many enhancements:
  * [`List`](/mojo/std/collections/list/List/), [`Set`](/mojo/std/collections/set/Set/), and [`Dict`](/mojo/std/collections/dict/Dict/) literals have been reimplemented to provide Python-equivalent features and syntax, including simple literals like `[1, 2, 3]` and `{k1: v1, k2: v2}`.
  * List comprehensions like `[a*b for a in range(10) if isprime(a) for b in range(20)]` as well as dictionary and set comprehensions are now supported.
  * Iterating over a collection with a `for` loop no longer requires using the `[]` dereference operator.

  See [Language enhancements](#25-4-language-enhancements) and [Standard library changes](#25-4-standard-library-changes) for more details.
* The entire MAX Kernel library is now fully open sourced! For more information, see the [MAX AI kernels library reference](/mojo/lib#max-ai-kernels-library) and the [MAX AI kernels source](https://github.com/modular/modular/tree/main/max/kernels).
* Mojo is now available on [Godbolt.org](https://godbolt.org), also known as the "Compiler Explorer".

### Language enhancements {#25-4-language-enhancements}

* `var` declarations in functions now support more flexible "patterns", allowing multiple values to be declared at once, for example `var a, b = 4, 5` and `var a, b : Int, Float64`.
* Mojo now supports the use of Python-style type patterns when declaring variables on first assignment without the `var` keyword. For example, `x = 4; y: UInt8 = 5` declares both `x` and `y`, where `x` is inferred to the default type of `Int` whereas `y` gets the explicit type `UInt8`. Declaring variables without `var` gives you a function-scoped name, whereas `var` makes things scoped to the statement they are in (lexical scoping), such as the body of an `if` statement.
* Mojo now supports `ref` patterns that bind a stored LValue into a named declaration, extending the argument convention into local function scope.
  This can be useful when you want to do something with a reference, but don't want the conceptual overhead of a [`Pointer`](/mojo/std/memory/pointer/Pointer/). These are equivalent:

  ```mojo
  fn use_pointer(your_list: List[Int]):
      var p = Pointer(to=your_list[i])  # Form a safe pointer
      ...
      use(p[])  # dereference it

  fn use_ref(your_list: List[Int]):
      ref r = your_list[i]  # Bind element reference to 'r'
      ...
      use(r)  # use it
  ```

  References are bound in their initializer and cannot be mutated afterward: uses and mutations of the reference are interpreted as uses and mutations of the value it refers to.
* The Mojo compiler will now synthesize `__moveinit__()`, `__copyinit__()`, and `copy()` methods for structs that conform to [`Movable`](/mojo/std/builtin/value/Movable/), [`Copyable`](/mojo/std/builtin/value/Copyable/), and `ExplicitlyCopyable` (respectively) but that do not implement the methods explicitly.
* A new [`@fieldwise_init`](/mojo/manual/decorators/fieldwise-init) decorator can be attached to structs to synthesize a field-wise initializer—an `__init__()` method that takes the same arguments as the fields in the struct. This gives access to this helpful capability without having to opt into the rest of the methods that `@value` synthesizes. This decorator allows an optional `@fieldwise_init("implicit")` form for single-element structs, which marks the initializer as [`@implicit`](/mojo/manual/decorators/implicit).
* `try` and `raise` now work at compile time.
* "Initializer lists" are now supported for creating struct instances with an inferred type based on context, for example:

  ```mojo
  fn foo(x: SomeComplicatedType): ...

  # Example with normal initializer.
  foo(SomeComplicatedType(1, kwarg=42))

  # Example with initializer list.
  foo({1, kwarg=42})
  ```

* List literals have been redesigned to work better.
  They produce homogeneous sequences by invoking the `T(<elements>, __list_literal__: ())` constructor of a type `T` that is inferred by context, or otherwise defaulting to the standard library [`List`](/mojo/std/collections/list/List/) type. The `ListLiteral` type has been removed from the standard library.
* Dictionary and set literals now work and default to creating instances of the [`Dict`](/mojo/std/collections/dict/Dict/) and [`Set`](/mojo/std/collections/set/Set/) types in the collections library.

### Language changes {#25-4-language-changes}

* Implicit trait conformance is deprecated. Each instance of implicit conformance results in a warning, but compilation still goes through. Soon it will be upgraded to an error. Any code currently relying on implicit conformance should either declare conformances explicitly or, if appropriate, replace empty, non-load-bearing traits with trait compositions.
* Mojo no longer allows the use of `out` or `mut` as an argument name. Previously you could use `fn x(out: Int)`, but this causes ambiguity with function types. Please use names like `output` instead.
* `def` arguments are no longer implicitly mutable. If you would like to have a locally mutable argument, declare it `owned` explicitly.
* Global (file-scope) variables are deprecated. Global variables in Mojo are only partially implemented and are known to cause cryptic errors. Now the Mojo compiler issues a warning on global variable usage.

### Standard library changes {#25-4-standard-library-changes}

* GPU programming enhancements and changes:
  * Mojo now supports AMD GPUs, expanding hardware compatibility beyond NVIDIA to include AMD's GPU ecosystem. This enables Mojo applications to leverage AMD's RDNA and CDNA architectures for high-performance computing workloads, providing developers with greater flexibility in choosing hardware platforms for AI and compute-intensive applications.
  * Primitives for working with NVIDIA Blackwell GPUs have been added, providing low-level access to the latest GPU architecture features. These primitives enable developers to take advantage of Blackwell's enhanced compute capabilities, improved memory bandwidth, and advanced AI acceleration features, including support for newer tensor operations and optimized memory management patterns. See the [`gpu.tcgen05`](/mojo/std/gpu/compute/arch/tcgen05/) module API reference documentation for more information.
  * Added support for a wider range of consumer-grade hardware, including:
    * NVIDIA RTX 2060 GPUs
    * NVIDIA RTX 4090 GPUs
  * Fixed the `sum()` and `prefix_sum()` implementations in the [`gpu.block`](/mojo/std/gpu/primitives/block/) and [`gpu.warp`](/mojo/std/gpu/primitives/warp/) modules. Previously, the implementations were incorrect and would either return wrong results or hang the kernel due to a deadlock. [PR 4508](https://github.com/modular/modular/pull/4508) and [PR 4553](https://github.com/modular/modular/pull/4553) by [Kirill Bobyrev](https://github.com/kirillbobyrev) mitigate the found issues and add tests to ensure correctness going forward.
* Collection type enhancements and changes:
  * The [`Dict`](/mojo/std/collections/dict/Dict/) type is now part of the prelude, so there is no need to import it anymore.
  * The [`List`](/mojo/std/collections/list/List/), [`Span`](/mojo/std/memory/span/Span/), [`Dict`](/mojo/std/collections/dict/Dict/), [`Set`](/mojo/std/collections/set/Set/), [`VariadicPack`](/mojo/std/builtin/variadics/VariadicPack/), and [`Deque`](/mojo/std/collections/deque/Deque/) iterators now return references to elements directly, instead of returning [`Pointer`](/mojo/std/memory/pointer/Pointer/).
    This means that you should no longer use the `[]` dereference operator with the loop index variable:

    ```mojo
    var states: List[String] = ["California", "Hawaii", "Oregon"]

    # Old:
    for state in states:
        print(state[])

    # New:
    for state in states:  # state is an immutable reference
        print(state)
    ```

    By default the reference is immutable. You can use the `ref` keyword to bind the index variable as a mutable reference:

    ```mojo
    for ref state in states:  # state is a mutable reference
        state += "!"  # Update the existing list element
    ```

  * [`List`](/mojo/std/collections/list/List/), [`InlineArray`](/mojo/std/collections/inline_array/InlineArray/), [`Deque`](/mojo/std/collections/deque/Deque/), [`LinkedList`](/mojo/std/collections/linked_list/LinkedList/), and [`SIMD`](/mojo/std/builtin/simd/SIMD/) types all support construction via list literal syntax:

    ```mojo
    var list: List[Int] = [1, 2, 3]
    var vec: SIMD[DType.uint8, 8] = [1, 2, 3, 4, 5, 6, 7, 8]
    var deque: Deque[Float64] = [1, 2.5]
    var llist: LinkedList[Int] = [1, 2, 3]
    var arr: InlineArray[String, 3] = ["hi", "hello", "hey"]
    ```

  * [`Dict`](/mojo/std/collections/dict/Dict/) and [`Set`](/mojo/std/collections/set/Set/) support construction via dict literal and set literal syntax, respectively:

    ```mojo
    var dict1 = {String("foo"): 1, String("bar"): 2}  # Dict[String, Int]
    var dict2 = {1: 4, 2: 7, 3: 18}  # Dict[Int, Int]
    var set = {1, 2, 3}  # Set[Int]
    ```

  * Python-style list, dictionary, and set comprehensions are now supported.
    For example:

    ```mojo
    # Basic list comprehension using a List[String]
    var upper_strs = [str.upper() for str in strs]  # List[String]

    # Nested list comprehension with conditional expression
    var nums = [a * b for a in range(1, 5) if a % 2 == 0 for b in [-1, 1]]  # List[Int]

    # Dictionary comprehension
    var squares_dict = {num: num * num for num in range(10)}  # Dict[Int, Int]

    # Set comprehension
    var unique_remainders = {num % 4 for num in range(10)}  # Set[Int]
    ```

  * The [`BitSet`](/mojo/std/collections/bitset/BitSet) data structure was added to the [`collections`](/mojo/std/collections/) package. This is a fixed `BitSet` that simplifies working with a set of bits and performing bit operations.
  * [`VariadicList`](/mojo/std/builtin/variadics/VariadicList), [`VariadicListMem`](/mojo/std/builtin/variadics/VariadicListMem), and [`VariadicPack`](/mojo/std/builtin/variadics/VariadicPack) moved to the new [`variadics`](/mojo/std/builtin/variadics/) module.
  * The `CollectionElement` trait has been removed. You can replace any use of it with the [`Copyable`](/mojo/std/builtin/value/Copyable) and [`Movable`](/mojo/std/builtin/value/Movable) traits, or the `Copyable & Movable` [trait composition](/mojo/manual/traits#trait-compositions).
* Python-Mojo interoperability enhancements and changes:
  * Python objects are now constructible with list, set, and dict literal syntax, for example: `var list: PythonObject = [1, "foo", 2.0]` will produce a Python list containing other Python objects and `var d: PythonObject = {}` will construct an empty dictionary.
  * `Python.unsafe_get_python_exception()` and `Python.throw_python_exception_if_error_state()` have been removed in favor of `Python().cpython().unsafe_get_error()` and `Python().cpython().get_error()`.
  * Since virtually any operation on a [`PythonObject`](/mojo/std/python/python_object/PythonObject) can raise, the `PythonObject` struct no longer implements the [`Indexer`](/mojo/std/builtin/int/Indexer/) and [`Intable`](/mojo/std/builtin/int/Intable/) traits. Instead, it now conforms to [`IntableRaising`](/mojo/std/builtin/int/IntableRaising), and users should convert explicitly to built-in types and handle exceptions as needed. In particular, the [`PythonObject.__int__()`](/mojo/std/python/python_object/PythonObject#__int__) method now returns a Python `int` instead of a Mojo `Int`, so users must explicitly convert to a Mojo `Int` if they need one (and must handle the exception if the conversion fails, for example due to overflow).
  * [`PythonObject`](/mojo/std/python/python_object/PythonObject) no longer implements the following traits:
    * [`Stringable`](/mojo/std/builtin/str/Stringable/). Instead, the [`PythonObject.__str__()`](/mojo/std/python/python_object/PythonObject#__str__) method now returns a Python `str` object and can raise. The new [`Python.str()`](/mojo/std/python/python/Python#str) static method can also be used to convert an arbitrary `PythonObject` to a Python `str` object.
    * [`KeyElement`](/mojo/std/collections/dict/#keyelement). Since Python objects may not be hashable—and even if they are, they could theoretically raise in the [`__hash__()`](/mojo/std/python/python_object/PythonObject#__hash__) method—`PythonObject` cannot conform to [`Hashable`](/mojo/std/hashlib/hash/Hashable/). This has no effect on accessing Python `dict` objects with `PythonObject` keys, since [`__getitem__()`](/mojo/std/python/python_object/PythonObject#__getitem__) and [`__setitem__()`](/mojo/std/python/python_object/PythonObject#__setitem__) should behave correctly and raise as needed.
      Two overloads of the [`Python.dict()`](/mojo/std/python/python/Python#dict) factory function have been added to allow constructing dictionaries from a list of key-value tuples and from keyword arguments.

    * [`EqualityComparable`](/mojo/std/builtin/comparable/#equalitycomparable). The [`PythonObject.__eq__()`](/mojo/std/python/python_object/PythonObject#__eq__) and [`PythonObject.__ne__()`](/mojo/std/python/python_object/PythonObject#__ne__) methods need to return other `PythonObject` values to support rich comparisons. Code that previously compared `PythonObject` values should be wrapped in [`Bool()`](/mojo/std/builtin/bool/Bool/#__init__) to perform the fallible conversion explicitly: `if Bool(obj1 == obj2): ...`.
    * [`Floatable`](/mojo/std/builtin/floatable/Floatable/). An explicit, raising constructor was added to [`SIMD`](/mojo/std/builtin/simd/SIMD/#__init__) to allow constructing `Float64` values from `PythonObject` values that implement `__float__()`.
  * A new [`def_function()`](/mojo/std/python/bindings/PythonModuleBuilder#def_function) API was added to [`PythonModuleBuilder`](/mojo/std/python/bindings/PythonModuleBuilder) to allow declaring Python bindings for arbitrary functions that take and return `PythonObject`s. Similarly, a new [`def_method()`](/mojo/std/python/bindings/PythonTypeBuilder#def_method) API is added to [`PythonTypeBuilder`](/mojo/std/python/bindings/PythonTypeBuilder) to allow declaring Python bindings for methods that take and return `PythonObject`s.
  * The [`ConvertibleFromPython`](/mojo/std/python/python_object/ConvertibleFromPython) trait is now public. This trait is implemented by Mojo types that can be constructed by converting from a `PythonObject`. This is the reverse operation of the [`PythonConvertible`](/mojo/std/python/python_object/PythonConvertible) trait.
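For comparison, the `Python.dict()` overloads noted above parallel Python's own built-in `dict` constructor, which likewise accepts an iterable of key-value tuples or keyword arguments. A plain-Python sketch (this is standard Python, not the Mojo API):

```python
# Construct equal dictionaries the two ways the Python.dict()
# overloads mirror: from key-value tuples and from keyword arguments.
from_pairs = dict([("foo", 1), ("bar", 2)])   # from key-value tuples
from_kwargs = dict(foo=1, bar=2)              # from keyword arguments

print(from_pairs == from_kwargs)  # True
```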
  * [`Bool`](/mojo/std/builtin/bool/Bool/), [`Int`](/mojo/std/builtin/int/Int/), and [`String`](/mojo/std/collections/string/string/String/) now implement `ConvertibleFromPython`.
  * [`PythonObject(alloc=)`](/mojo/std/python/python_object/PythonObject#__init__) is a new constructor that can be used to directly store Mojo values in Python objects. This initializer will fail if the type of the provided Mojo value has not previously had a corresponding Python `type` object globally registered using [`PythonModuleBuilder.add_type()`](/mojo/std/python/bindings/PythonModuleBuilder#add_type).
  * [`PythonObject`](/mojo/std/python/python_object/PythonObject) has new methods for downcasting to a pointer to a contained Mojo value, for use in Python/Mojo interop.

    ```mojo
    struct Person:
        var name: String

    fn greet(obj: PythonObject) raises:
        var person = obj.downcast_value_ptr[Person]()
        print("Hello ", person[].name, "from Mojo🔥!")
    ```

    * [`PythonObject.downcast_value_ptr[T]()`](/mojo/std/python/python_object/PythonObject#downcast_value_ptr) checks if the object is a wrapped instance of the Mojo type `T`, and if so, returns an `UnsafePointer[T]`. Otherwise, an exception is raised.
    * [`PythonObject.unchecked_downcast_value_ptr[T]()`](/mojo/std/python/python_object/PythonObject#unchecked_downcast_value_ptr) unconditionally returns an `UnsafePointer[T]` without any runtime type checking. This is useful when using Python/Mojo interop to optimize an inner loop and minimizing overhead is desirable.

    Also added an equivalent [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer#__init__) initializer for downcasting from a `PythonObject`.
  * The `TypedPythonObject` type has been removed. Use `PythonObject` instead.
  * The `Python.is_type(x, y)` static method has been removed. Use the expression `x is y` instead.
* [`os.abort(messages)`](/mojo/std/os/os/abort) no longer supports a variadic number of [`Writable`](/mojo/std/format/Writable/) messages.
  While this API was high-level and convenient, it generated a lot of IR for simple and common cases, such as when we have a single `StringLiteral` message. Callers must now create the `String` on their side before calling `os.abort(message)`, which avoids generating bloated IR.
* The [`atof()`](/mojo/std/collections/string/string/atof/) function has been entirely rewritten, as it produced incorrect results for very low and very high exponents. It now works correctly for strings with fewer than 19 digits left of the `e`. For example, `1.1385616158185648648648648648616186186e-3` won't work and will raise an error. Anything that does not produce an error is now guaranteed to be correct. While the current implementation is not the fastest, it's based on the paper [Number Parsing at a Gigabyte per Second](https://arxiv.org/abs/2101.11408) by Daniel Lemire, so with a bit of effort to pinpoint the slow parts, we can reach state-of-the-art performance in the future.
* The [`math.isclose()`](/mojo/std/math/math/isclose/) function now supports both symmetric (Python-style) and asymmetric (NumPy-style) comparison modes via a new `symmetrical` parameter, which defaults to the newly added symmetric mode. The function now only supports floating-point types, removing previous pseudo-support for integer and boolean types. Support added in [PR 4608](https://github.com/modular/modular/pull/4608) by [@soraros](https://github.com/soraros).
* The [`compile`](/mojo/std/compile/) module now provides the [`get_type_name()`](/mojo/std/reflection/type_info/get_type_name/) function to get the fully qualified name of a type. For example, `compile.get_type_name[Int]()` returns `"std.builtin.int.Int"`.

### Tooling changes {#25-4-tooling-changes}

* Added support for emitting LLVM Intermediate Representation (.ll) using `--emit=llvm`.
  * Example usage: `mojo build --emit=llvm YourModule.mojo`
* Removed support for the command line option `--emit-llvm` in favor of `--emit=llvm`.
* Added support for emitting assembly code (.s) using `--emit=asm`.
  * Example usage: `mojo build --emit=asm YourModule.mojo`
* Added associated alias support for documentation generated via [`mojo doc`](/mojo/cli/doc).
* Added struct and trait conformance list sorting support to [`mojo format`](/mojo/cli/format).

### ❌ Removed {#25-4-removed}

* `VariadicPack.each()` and `VariadicPack.each_idx()` methods have been removed. Use the [`@parameter for`](/mojo/manual/decorators/parameter#parametric-for-statement) language construct to achieve this now. The `write_buffered()` and `write_args()` functions have also been removed. To improve compile speed and reduce register pressure on GPU, you should now unroll the variadic pack at each call site.

  Unbuffered:

  ```mojo
  fn write[*Ts: Writable](mut self, *args: *Ts):
      var string = String()

      @parameter
      for i in range(args.__len__()):
          args[i].write_to(string)
  ```

  Buffered:

  ```mojo
  from utils.write import _WriteBufferStack

  fn write[*Ts: Writable](mut self, *args: *Ts):
      var string = String()
      var buffer = _WriteBufferStack(string)

      @parameter
      for i in range(args.__len__()):
          args[i].write_to(buffer)
      buffer.flush()
  ```

### 🛠️ Fixed {#25-4-fixed}

* [#1649](https://github.com/modular/modular/issues/1649) - Trailing comma is not supported in assignments.
* [#3415](https://github.com/modular/modular/issues/3415) - Type annotation fails on implicit variable declarations.
* [#4352](https://github.com/modular/modular/issues/4352) - `math.sqrt` produces incorrect results for large inputs.
* [#4518](https://github.com/modular/modular/issues/4518) - Try Except Causes False Positive "Uninitialized Value".
* [#4677](https://github.com/modular/modular/issues/4677) - `UIntN` Comparison Yields Incorrect Result When Function Parameter Is Involved (`UInt8`–`UInt64`).
* [#4684](https://github.com/modular/modular/issues/4684) - Failure inferring type of initializer list from field of struct.
* [#4688](https://github.com/modular/modular/issues/4688) - Incorrect result for unsigned `gt` and `le` comparisons.
* [#4694](https://github.com/modular/modular/issues/4694) - Compiler error handling `x or y` expressions with PythonObject.
* [#4719](https://github.com/modular/modular/issues/4719) - `Dict.setdefault` should not be marked with `raises`.

### Special thanks

Special thanks to our community contributors: [@astrobdr](https://github.com/astrobdr), [@bgreni](https://github.com/bgreni), [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse), [@godardt](https://github.com/godardt), [@hardikkgupta](https://github.com/hardikkgupta), [@Hundo1018](https://github.com/Hundo1018), [@kirillbobyrev](https://github.com/kirillbobyrev), [@martinvuyk](https://github.com/martinvuyk), [@msaelices](https://github.com/msaelices), [@mzaks](https://github.com/mzaks), [@OwenJRJones](https://github.com/OwenJRJones), [@shogo314](https://github.com/shogo314), @sibarras, [@simveit](https://github.com/simveit), [@soraros](https://github.com/soraros), [@sstadick](https://github.com/sstadick)

## v25.3 (2025-05-06)

### ✨ Highlights

* Parts of the Mojo standard library continue to be progressively open sourced! Packages that are open sourced now include:
  * [`algorithm`](/mojo/std/algorithm/)
  * [`benchmark`](/mojo/std/benchmark/)
  * [`buffer`](/mojo/std/buffer/)
  * [`compile`](/mojo/std/compile/)
  * [`complex`](/mojo/std/complex/)
  * [`gpu`](/mojo/std/gpu/)
  * [`logger`](/mojo/std/logger/)
  * [`runtime`](/mojo/std/runtime/)
  * [`subprocess`](/mojo/std/subprocess/)

  For more information, see the [Standard library reference](/mojo/lib#standard-library) and the [Standard library source](https://github.com/modular/modular/tree/main/mojo/stdlib).
* Parts of the MAX AI kernels library continue to be progressively open sourced!
  Packages that are open sourced now include:
  * [`layout`](/mojo/kernels/layout/)
  * [`linalg`](/mojo/kernels/linalg/)

  For more information, see the [MAX AI kernels library reference](/mojo/lib#max-ai-kernels-library) and the [MAX AI kernels source](https://github.com/modular/modular/tree/main/max/kernels).
* Trait compositions are now supported via the `&` syntax. A trait composition combines two traits into one logical trait whose constraint set is the union of the constraint sets of the two original traits. For more information, see [Trait compositions](/mojo/manual/traits/#trait-compositions) in the Mojo Manual.
* String types in Mojo got several significant improvements. See [Standard library changes](#25-3-standard-library-changes) for details.

### Language changes {#25-3-language-changes}

* Mojo can now use [user-declared `__merge_with__()` dunder methods](https://github.com/modular/modular/blob/main/mojo/proposals/custom-type-merging.md) to merge values when using different types in ternary operations. This has been adopted to allow pointers to work naturally with the ternary operator, for example `var x = one_pointer if cond else other_pointer`.
* Auto-parameterization now extends to struct metatypes. For example, this declaration `fn foo[M: type_of(StringLiteral[_])]` will auto-parameterize on the unbound parameter of `StringLiteral`.
* The Mojo compiler now warns about stores to values that are never used, e.g.: `x = foo(); x = bar()` will warn about the first assignment to `x` because it is overwritten. You can generally address this by deleting dead code, or by assigning to `_` instead: `_ = foo(); x = bar()`. You may also encounter this in variable declarations, e.g. `var x = 0; ...; x = foo()`. In this case, change the variable to being declared as uninitialized, e.g. `var x: Int`. You may also silence this warning entirely for a variable by renaming it to start with an underscore, e.g. `_x`.
* The Mojo compiler now warns about obsolete use of `mut self` in initializers; please switch to `fn __init__(out self)` instead.
* `def` functions now require type annotations on arguments, and treat a missing return type as returning `None`. Previously these defaulted to the `object` type, which led to a variety of problems. Support for `object` has been removed until we have time to investigate a proper replacement.

### Standard library changes {#25-3-standard-library-changes}

String types in Mojo got several significant improvements:

* The [`String`](/mojo/std/collections/string/string/String/) type no longer copies data from [`StringLiteral`](/mojo/std/builtin/string_literal/StringLiteral/) and [`StaticString`](/mojo/std/collections/string/string_slice/#aliases) since they are known-static-constant values. This allows us to make construction from these values implicit, which improves ergonomics and performance together. It also implements the "small string optimization", which avoids heap allocation for common short strings. On a 64-bit system, `String` can hold up to 23 bytes inline. Its copy constructor is now O(1), performing the string data copy lazily on mutation.
* The types [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice/) and [`StaticString`](/mojo/std/collections/string/string_slice/#aliases) are now part of the prelude; there is no need to import them anymore. These are useful for code that just needs a "view" of string data, not to own and mutate it.
* The [`StringLiteral`](/mojo/std/builtin/string_literal/StringLiteral/) type has been moved to a more reliable "dependent type" design where the value of the string is carried in a parameter instead of a stored member. This defines away a category of compiler crashes when working with `StringLiteral` that involved attempting to manipulate a `StringLiteral` at run time.
  As a consequence of this change, many APIs should switch to using [`StaticString`](/mojo/std/collections/string/string_slice/#aliases) instead of `StringLiteral`. For more information on this "dependent type" design for literals, see the proposal, [Fixing Simple Literals in Mojo](https://github.com/modular/modular/blob/main/mojo/proposals/fixing-simple-literals.md).
* `String` supports a new `String(unsafe_uninit_length=x)` constructor and `str.resize(unsafe_uninit_length=x)` for clients that want to allocate space that they intend to fill in with custom unsafe initialization patterns. The `String(ptr=x, length=y)` constructor has been removed.
* `String` supports working with legacy C APIs that assume null termination, but the details have changed: `String` is now no longer implicitly null-terminated, which means that it is incorrect to assume that `str.unsafe_ptr()` will return a null-terminated string. For that, use the `str.unsafe_cstr_ptr()` method, which now requires the string to be mutable in order to make null termination lazy on demand. This improves performance for strings that are not passed to legacy APIs.
* The [`List`](/mojo/std/collections/list/List) type has been improved similarly to `String` to reduce inconsistency and enable power-user features, including adding `List(unsafe_uninit_length=x)` and `list.resize(unsafe_uninit_size=n)` methods that avoid initializing memory that the caller plans to overwrite.
* [`Set`](/mojo/std/collections/set/Set/) now conforms to the [`Copyable`](/mojo/std/builtin/value/Copyable/) trait so you can store sets in other types of collections (for example, as values in a `Dict`).
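As a rough illustration of the small string optimization mentioned in the `String` changes above, here is a hypothetical Python sketch (the 23-byte inline capacity comes from this changelog; the class and allocation counter are invented for illustration and are not Mojo's actual implementation):

```python
HEAP_ALLOCATIONS = 0  # simulated heap-allocation counter

class SmallString:
    """Hypothetical sketch of small-string-optimized storage:
    strings of up to 23 bytes are stored inline; longer ones
    go to a (simulated) heap buffer."""

    INLINE_CAPACITY = 23  # inline byte capacity cited in the changelog

    def __init__(self, data: bytes):
        global HEAP_ALLOCATIONS
        if len(data) <= self.INLINE_CAPACITY:
            self.inline = data   # stored inline, no allocation
            self.heap = None
        else:
            HEAP_ALLOCATIONS += 1  # simulate a heap allocation
            self.inline = None
            self.heap = bytearray(data)

    def is_inline(self) -> bool:
        return self.heap is None

short = SmallString(b"hello")               # fits inline
long = SmallString(b"x" * 40)               # spills to the heap
print(short.is_inline(), long.is_inline())  # True False
```

The point of the optimization is visible in the counter: short strings never touch the heap at all.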
* The following traits have been removed in favor of trait composition: `EqualityComparableCollectionElement`, `RepresentableCollectionElement`, `TestableCollectionElement`, `Testable`, `StringableIdentifiable`, `StringableCollectionElement`, `IntervalPayload`, `WritableCollectionElement`, `ComparableCollectionElement`, `BoolableCollectionElement`, `EqualityComparableWritableCollectionElement`, `EqualityComparableWritableCollectionElementNew`, `CollectionElementNew`, `WritableCollectionElementNew`. For example, you can replace `EqualityComparableCollectionElement` with `EqualityComparable & CollectionElement`. `StringableCollectionElement` was already deprecated and scheduled to be removed; it can be replaced with `Writable & CollectionElement`.
* The [`PythonObject`](/mojo/std/python/python_object/PythonObject) type is being reworked in preparation for some improvements to Mojo-Python interoperability:
  * Since virtually any operation on a `PythonObject` can raise, the `PythonObject` struct no longer implements the following traits: `ImplicitlyBoolable`, `ImplicitlyIntable`.
  * `PythonObject` is no longer implicitly constructible from tuple or list literals. For example, `var x : PythonObject = [1, 2, "foo"]` is no longer accepted. Instead, please use the new `Python.list()` and `Python.tuple()` factory methods. For example:

    ```mojo
    var x = Python.list(1, 2, "foo")
    ```

    (The `list()` and `tuple()` factory methods were originally added on `PythonObject`, but have been moved to the `Python` struct.) We hope to re-enable literal syntax in the future as the standard library matures.
  * `PythonObject.from_borrowed_ptr()` has been removed in favor of a constructor with a keyword-only `from_borrowed_ptr` argument.
  * The deprecated `PythonObject.to_float64()` method has been removed. Use the `Float64()` constructor instead.
* [`Span`](/mojo/std/memory/span/Span) now has a `swap_elements()` method which takes two indices and swaps them within the span.
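`swap_elements()` performs the familiar two-index, in-place swap. A plain-Python sketch of the semantics (the helper name and the bounds-check behavior are assumptions for illustration, not the Mojo API):

```python
def swap_elements(buf: list, i: int, j: int) -> None:
    """Swap the elements at indices i and j in place, mirroring
    the semantics of Span.swap_elements()."""
    if not (0 <= i < len(buf) and 0 <= j < len(buf)):
        raise IndexError("swap_elements index out of range")
    buf[i], buf[j] = buf[j], buf[i]

states = ["California", "Hawaii", "Oregon"]
swap_elements(states, 0, 2)
print(states)  # ['Oregon', 'Hawaii', 'California']
```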
* [`Pointer`](/mojo/std/memory/pointer/Pointer/) now has a `get_immutable()` method to return a new `Pointer` with the same underlying data but with an `ImmutableOrigin`.
* You can now forward a [`VariadicPack`](/mojo/std/builtin/variadics/VariadicPack/) where all values are `Writable` to a writer using `WritableVariadicPack`:

  ```mojo
  from utils.write import WritableVariadicPack

  fn print_message[*Ts: Writable](*messages: *Ts):
      print("message:", WritableVariadicPack(messages), "[end]")

  x = 42
  print_message("'x = ", x, "'")
  ```

  ```text
  message: 'x = 42' [end]
  ```

  In this example the variadic pack is buffered to the stack in the `print` call along with the extra arguments, before doing a single syscall to write to stdout.
* [`debug_assert()`](/mojo/std/builtin/debug_assert/debug_assert/) in AMD GPU kernels now behaves the same as on NVIDIA, printing the thread information and variadic args passed after the condition:

  ```mojo
  from gpu.host import DeviceContext

  fn kernel():
      var x = 1
      debug_assert(x == 2, "x should be 2 but is: ", x)

  def main():
      with DeviceContext() as ctx:
          ctx.enqueue_function[kernel](grid_dim=2, block_dim=2)
  ```

  Running `mojo run -D ASSERT=all [filename]` will output:

  ```text
  At /tmp/test.mojo:5:17: block: [0,0,0] thread: [0,0,0] Assert Error: x should be 2 but is: 1
  At /tmp/test.mojo:5:17: block: [0,0,0] thread: [1,0,0] Assert Error: x should be 2 but is: 1
  At /tmp/test.mojo:5:17: block: [1,0,0] thread: [0,0,0] Assert Error: x should be 2 but is: 1
  At /tmp/test.mojo:5:17: block: [1,0,0] thread: [1,0,0] Assert Error: x should be 2 but is: 1
  ```

* The [`constrained[cond, string]()`](/mojo/std/builtin/constrained/constrained/) function now accepts multiple strings that are printed concatenated on failure, so you can use:

  ```mojo
  constrained[cond, "hello: ", String(n), ": world"]()
  ```

  This is more compile-time efficient and somewhat more ergonomic than using string concatenation.
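The benefit of the multi-argument form is that the message parts are only assembled when the constraint actually fails. A loose Python analogue (the `constrained` function below is a hypothetical runtime sketch, not Mojo's compile-time API):

```python
def constrained(cond: bool, *message_parts) -> None:
    """Loose Python analogue of constrained[cond, parts...]():
    message parts are only joined if the condition fails."""
    if not cond:
        raise AssertionError("".join(str(p) for p in message_parts))

n = 42
constrained(True, "hello: ", n, ": world")  # passes silently
try:
    constrained(False, "hello: ", n, ": world")
except AssertionError as e:
    print(e)  # hello: 42: world
```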
* [`pathlib.Path.write_text()`](/mojo/std/pathlib/path/Path/#write_text) now accepts a `Writable` argument instead of a `Stringable` argument. This makes the function more efficient by removing a String allocation.
* Added [`pathlib.Path.write_bytes()`](/mojo/std/pathlib/path/Path/#write_bytes) which enables writing raw bytes to a file.
* Added [`os.path.split_extension()`](/mojo/std/os/path/path/split_extension) to split a path into its root and extension.
* Added [`os.path.is_absolute()`](/mojo/std/os/path/path/is_absolute) to check if a given path is absolute or not.
* You can now specify the consistency model used in atomic operations; the default is sequential consistency. The consistency models are defined in the [`Consistency`](/mojo/std/os/atomic/Consistency/) struct.
* Added [`Variant.is_type_supported()`](/mojo/std/utils/variant/Variant/#is_type_supported) method. ([PR #4057](https://github.com/modular/modular/pull/4057)) Example:

  ```mojo
  def takes_variant(mut arg: Variant):
      if arg.is_type_supported[Float64]():
          arg = Float64(1.5)

  def main():
      var x = Variant[Int, Float64](1)
      takes_variant(x)
      if x.isa[Float64]():
          print(x[Float64])  # 1.5
  ```

* The `type` parameter of `SIMD` has been renamed to `dtype`.
* The `is_power_of_two(x)` function in the `bit` package is now a method on `Int`, `UInt`, and `SIMD`.
* The `Pointer.address_of(...)` and `UnsafePointer.address_of(...)` functions have been deprecated. Please use the [`Pointer(to=...)`](/mojo/std/memory/pointer/Pointer#__init__) and [`UnsafePointer(to=...)`](/mojo/std/memory/unsafe_pointer/UnsafePointer#__init__) constructors instead. Conceptually, this is saying "please initialize a `Pointer` (a reference, if you will) to *some other address in memory*." In the future, these `address_of()` functions will be removed.

### Tooling changes {#25-3-tooling-changes}

* Fixed SIMD boolean display in debugger: SIMD boolean values now display correctly with proper bit extraction.
* Improved language server performance: the language server now avoids
  parsing more than it needs to, improving performance across the board.

* The Mojo compiler can now interpret all arithmetic operations from the
  `index` dialect that are used in methods of the `Int` and `UInt` types. This
  allows users to finally compute constants at compile time:

  ```mojo
  alias a: Int = 1000000000
  alias b: Int = (5 * a) // 2
  ```

  Previously, the compiler would throw the error "cannot fold operation".

* Added a new `--emit-llvm` option to the `mojo build` command, which allows
  users to emit LLVM IR. When `--emit-llvm` is specified, the build process
  will compile Mojo code to LLVM IR, save the IR to a `.ll` file (using the
  same name as the input file), and print the IR to stdout for immediate
  inspection.

### Other changes

* The syntax for adding attributes to an `__mlir_op` is now limited to
  inherent attributes (those defined by the op definition). Most users will
  not need to attach other kinds of attributes, and this helps guard against
  typos and Mojo code getting outdated when the dialect changes.

### ❌ Removed {#25-3-removed}

* The `SIMD.roundeven()` method has been removed from the standard library.
  This functionality is now handled by the
  [`round()`](/mojo/std/builtin/math/round) function.

* Error messages about the obsolete `borrowed` and `inout` keywords, as well
  as the obsolete `-> Int as name` syntax, have been removed.

* The `object` type has been removed.

* `utils.numerics.ulp` has been removed. Use the
  [`ulp()`](/mojo/std/math/math/ulp) function from the `math` package instead.

* Several free functions that were deprecated in the 25.2 release have now
  been removed. This includes:
  * The `str` free function. Use the `String` constructor instead.
  * The `int` free function. Use the `Int` constructor instead.
  * The `bool` free function. Use the `Bool` constructor instead.
  * The `float` free function. Use the `Float64` constructor instead.
* Removed the deprecated
  [`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext/) methods
  `copy_sync()` and `memset_sync()`.

* The `unroll()` utility has been removed. Use the
  [`@parameter for` construct](/mojo/manual/decorators/parameter#parametric-for-statement)
  instead.

  ```mojo
  from utils.loop import unroll

  # Before
  @always_inline
  @parameter
  fn foo[i: Int]():
      body_logic[i]()

  unroll[foo, iteration_range]()

  # After
  @parameter
  for i in range(iteration_range):
      body_logic[i]()
  ```

* The `InlinedString` type has been removed. Use `String` instead, which now
  supports the Small String Optimization (SSO).

* The `AsBytes` trait has been removed.

### 🛠️ Fixed {#25-3-fixed}

* [#3510](https://github.com/modular/modular/issues/3510) - `PythonObject` doesn't handle large `UInt64` correctly.
* [#3847](https://github.com/modular/modular/issues/3847) - Count leading zeros can't be used on `SIMD` at compile time.
* [#4198](https://github.com/modular/modular/issues/4198) - Apple M4 is not properly detected with `sys.is_apple_silicon()`.
* [#3662](https://github.com/modular/modular/issues/3662) - Code using `llvm.assume` cannot run at compile time.
* [#4273](https://github.com/modular/modular/issues/4273) - `count_leading_zeros` doesn't work for vectors with size > 1 at compile time.
* [#4320](https://github.com/modular/modular/issues/4320) - Intermittent miscompilation with bytecode imported traits.
* [#4281](https://github.com/modular/modular/issues/4281) - MAX does not support RTX 5000-series GPUs.
* [#4163](https://github.com/modular/modular/issues/4163) - Corner case in initializers.
* [#4360](https://github.com/modular/modular/issues/4360) - Fix constructor emission for parameterized types conforming to a trait composition.
* [#4362](https://github.com/modular/modular/issues/4362) - Function call with `IntLiteral` incorrectly eliminated despite side-effects.
* [#4431](https://github.com/modular/modular/issues/4431) - \[BUG] `Python.evaluate` doesn't handle null termination correctly.
* [#4492](https://github.com/modular/modular/issues/4492) - Fix `StringSlice.replace` seg fault.

### Special thanks

Special thanks to our community contributors:
[@auris](https://github.com/auris), [@bgreni](https://github.com/bgreni),
[@christianbator](https://github.com/christianbator),
[@KamilGucik](https://github.com/KamilGucik),
[@kasmith11](https://github.com/kasmith11),
[@martinvuyk](https://github.com/martinvuyk),
[@ratulb](https://github.com/ratulb), [@rd4com](https://github.com/rd4com),
[@sora](https://github.com/sora),
[@thatstoasty](https://github.com/thatstoasty), and
[@winding-lines](https://github.com/winding-lines).

## v25.2 (2025-03-25)

### ✨ Highlights

* Check out the new [GPU basics](/mojo/manual/gpu/basics) section of the
  [Mojo Manual](/mojo/manual) and the
  [Get started with GPU programming with Mojo and the MAX Driver](/mojo/manual/gpu/intro-tutorial)
  tutorial for a guide to getting started with GPU programming in Mojo!

* Some APIs in the [`gpu`](/mojo/std/gpu/) package were enhanced to simplify
  working with GPUs:

  * If you're executing a GPU kernel only once, you can now skip compiling it
    first before enqueueing it, and pass it directly to
    [`DeviceContext.enqueue_function()`](/mojo/std/gpu/host/device_context/DeviceContext#enqueue_function).

  * The three separate methods on `DeviceContext` for asynchronously copying
    buffers between host and GPU memory have been combined into a single
    overloaded
    [`enqueue_copy()`](/mojo/std/gpu/host/device_context/DeviceContext/#enqueue_copy)
    method, and the three separate methods for synchronous copies have been
    combined into an overloaded
    [`copy_sync()`](/mojo/std/gpu/host/device_context/DeviceContext/#copy_sync)
    method.

  * The `gpu.shuffle` module has been renamed to
    [`gpu.warp`](/mojo/std/gpu/primitives/warp/) to better reflect its
    purpose.
* The [`gpu`](/mojo/std/gpu) package API documentation has been expanded, and
  API documentation for the [`layout`](/mojo/kernels/layout) package is
  underway, beginning with core types, functions, and traits. See the
  [Standard library changes](#25-2-standard-library-changes) section of the
  changelog for more information.

* The legacy `borrowed`/`inout` keywords and `-> T as foo` syntax are no
  longer supported and now generate a compiler error. Please move to
  `read`/`mut`/`out` argument syntax instead. See
  [Argument conventions](/mojo/manual/values/ownership#argument-conventions)
  in the Mojo Manual for more information.

* The standard library has many changes related to strings. Notably, the
  `Char` type has been renamed to
  [`Codepoint`](/mojo/std/collections/string/codepoint/Codepoint), to better
  capture its intended purpose of storing a single Unicode codepoint. Related
  method and type names have been updated as well. See
  [Standard library changes](#25-2-standard-library-changes) for more details.

* Support has been added for 128- and 256-bit signed and unsigned integers.
  This includes the [`DType`](/mojo/std/builtin/dtype/DType) aliases
  `DType.int128`, `DType.uint128`, `DType.int256`, and `DType.uint256`, as
  well as [`SIMD`](/mojo/std/builtin/simd/SIMD) support for 128- and 256-bit
  signed and unsigned element types. Note that this exposes capabilities (and
  limitations) of LLVM, which may not always provide high performance for
  these types and may have missing operations like divide, remainder, etc.
  See [Standard library changes](#25-2-standard-library-changes) for more
  details.
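For a sense of scale, the bounds of the new 128-bit types can be computed with ordinary arithmetic. Python is used here purely for illustration, since its integers are arbitrary precision:

```python
# Exact value ranges of two's-complement int128 and of uint128.
int128_min = -(2**127)
int128_max = 2**127 - 1
uint128_max = 2**128 - 1

print(int128_max)   # 170141183460469231731687303715884105727
print(uint128_max)  # 340282366920938463463374607431768211455
```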
### Language changes {#25-2-language-changes}

* References to aliases in struct types with unbound (or partially bound)
  parameter sets are now allowed, as long as the referenced alias doesn't
  depend on any unbound parameters:

  ```mojo
  struct StructWithParam[a: Int, b: Int]:
      alias a1 = 42
      alias a2 = a + 1

  fn test():
      _ = StructWithParam.a1     # ok
      _ = StructWithParam[1].a2  # ok
      _ = StructWithParam.a2     # error: 'a' is unbound
  ```

* The Mojo compiler now warns about `@parameter for` with a large loop
  unrolling factor (>1024 by default), which can lead to long compilation
  times and large generated code size. Set `--loop-unrolling-warn-threshold`
  to change the default to a different threshold, or to `0` to disable the
  warning.

* The Mojo compile-time interpreter can now handle many more LLVM intrinsics,
  including ones that return floating-point values. This allows functions
  like [`round()`](/mojo/std/builtin/math/round) to be constant folded when
  used in a compile-time context.

* The Mojo compiler now has only one compile-time interpreter. It had two
  previously: one to handle a few cases that were important for dependent
  types in the parser (but which also had many limitations), and the primary
  one that ran at "instantiation" time, which is fully general. This was
  confusing and caused a wide range of bugs. We've now removed the
  special-case parse-time interpreter, replacing it with a more general
  solution for dependent types. This change should be invisible to most
  users, but it should resolve a number of long-standing bugs and
  significantly simplifies the compiler implementation, allowing us to move
  faster.

### Standard library changes {#25-2-standard-library-changes}

* [`Optional`](/mojo/std/collections/optional/Optional),
  [`Span`](/mojo/std/memory/span/Span), and
  [`InlineArray`](/mojo/std/collections/inline_array/InlineArray) have been
  added to the prelude. You no longer need to explicitly import these types
  to use them in your program.
* GPU programming changes:

  * You can now skip compiling a GPU kernel first before enqueueing it, and
    pass it directly to
    [`DeviceContext.enqueue_function()`](/mojo/std/gpu/host/device_context/DeviceContext#enqueue_function):

    ```mojo
    from gpu.host import DeviceContext

    fn func():
        print("Hello from GPU")

    with DeviceContext() as ctx:
        ctx.enqueue_function[func](grid_dim=1, block_dim=1)
    ```

    However, if you're reusing the same function and parameters multiple
    times, this incurs some overhead of around 50-500 nanoseconds per
    enqueue. So you can still compile the function first with
    [`DeviceContext.compile_function()`](/mojo/std/gpu/host/device_context/DeviceContext#compile_function)
    and pass it to `DeviceContext.enqueue_function()` like this:

    ```mojo
    with DeviceContext() as ctx:
        var compiled_func = ctx.compile_function[func]()
        # Multiple kernel launches with the same function/parameters
        ctx.enqueue_function(compiled_func, grid_dim=1, block_dim=1)
        ctx.enqueue_function(compiled_func, grid_dim=1, block_dim=1)
    ```

  * The following methods on
    [`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext):

    * `enqueue_copy_to_device()`
    * `enqueue_copy_from_device()`
    * `enqueue_copy_device_to_device()`

    have been combined into a single overloaded
    [`enqueue_copy()`](/mojo/std/gpu/host/device_context/DeviceContext/#enqueue_copy)
    method. Additionally, the methods:

    * `copy_to_device_sync()`
    * `copy_from_device_sync()`
    * `copy_device_to_device_sync()`

    have been combined into an overloaded
    [`copy_sync()`](/mojo/std/gpu/host/device_context/DeviceContext/#copy_sync)
    method.

  * The `gpu.shuffle` module has been renamed to
    [`gpu.warp`](/mojo/std/gpu/primitives/warp/) to better reflect its
    purpose. For example:

    ```mojo
    import gpu.warp as warp

    var val0 = warp.shuffle_down(x, offset)
    var val1 = warp.broadcast(x)
    ```

* Support has been added for 128- and 256-bit signed and unsigned integers.
  * The following aliases have been added to the
    [`DType`](/mojo/std/builtin/dtype/DType) struct: `DType.int128`,
    `DType.uint128`, `DType.int256`, and `DType.uint256`.

  * The [`SIMD`](/mojo/std/builtin/simd/SIMD) type now supports 128- and
    256-bit signed and unsigned element types. Note that this exposes
    capabilities (and limitations) of LLVM, which may not always provide high
    performance for these types and may have missing operations like divide,
    remainder, etc.

  * The following [`Scalar`](/mojo/std/builtin/simd/#aliases) aliases for
    1-element `SIMD` values have been added: `Int128`, `UInt128`, `Int256`,
    and `UInt256`.

* [`String`](/mojo/std/collections/string) and friends:

  * The `Char` type has been renamed to
    [`Codepoint`](/mojo/std/collections/string/codepoint/Codepoint), to
    better capture its intended purpose of storing a single Unicode
    codepoint. Related method and type names have been updated as well,
    including:

    * `StringSlice.chars()` and `String.chars()` to
      [`StringSlice.codepoints()`](/mojo/std/collections/string/string_slice/StringSlice/#codepoints)
      and
      [`String.codepoints()`](/mojo/std/collections/string/string/String/#codepoints),
      respectively
    * `StringSlice.char_slices()` and `String.char_slices()` to
      [`StringSlice.codepoint_slices()`](/mojo/std/collections/string/string_slice/StringSlice/#codepoint_slices)
      and
      [`String.codepoint_slices()`](/mojo/std/collections/string/string/String/#codepoint_slices),
      respectively
    * `CharsIter` to
      [`CodepointsIter`](/mojo/std/collections/string/string_slice/CodepointsIter)
    * `Char.unsafe_decode_utf8_char()` to
      [`Codepoint.unsafe_decode_utf8_codepoint()`](/mojo/std/collections/string/codepoint/Codepoint/#unsafe_decode_utf8_codepoint)

  * Made the iterator type returned by the string `codepoint_slices()`
    methods public as
    [`CodepointSliceIter`](/mojo/std/collections/string/string_slice/CodepointSliceIter/).
  * [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice)
    now supports several additional methods moved from
    [`String`](/mojo/std/collections/string/string/String). The existing
    `String` methods have been updated to instead call the corresponding new
    `StringSlice` methods:

    * [`center()`](/mojo/std/collections/string/string_slice/StringSlice/#center)
    * [`is_ascii_digit()`](/mojo/std/collections/string/string_slice/StringSlice/#is_ascii_digit)
    * [`is_ascii_printable()`](/mojo/std/collections/string/string_slice/StringSlice/#is_ascii_printable)
    * [`islower()`](/mojo/std/collections/string/string_slice/StringSlice/#islower)
    * [`isupper()`](/mojo/std/collections/string/string_slice/StringSlice/#isupper)
    * [`ljust()`](/mojo/std/collections/string/string_slice/StringSlice/#ljust)
    * [`lower()`](/mojo/std/collections/string/string_slice/StringSlice/#lower)
    * [`rjust()`](/mojo/std/collections/string/string_slice/StringSlice/#rjust)
    * [`split()`](/mojo/std/collections/string/string_slice/StringSlice/#split)
    * [`upper()`](/mojo/std/collections/string/string_slice/StringSlice/#upper)

  * Added a
    [`StringSlice.is_codepoint_boundary()`](/mojo/std/collections/string/string_slice/StringSlice/#is_codepoint_boundary)
    method for querying if a given byte index is a boundary between encoded
    UTF-8 codepoints.

  * [`StringSlice.__getitem__(Slice)`](/mojo/std/collections/string/string_slice/StringSlice/#__getitem__)
    now raises an error if the provided slice start and end positions do not
    fall on a valid codepoint boundary. This prevents construction of
    malformed `StringSlice` values, which could lead to memory unsafety or
    undefined behavior.

    For example, given a string containing multi-byte encoded data, like:

    ```mojo
    str_slice = "Hi👋!"
    ```

    whose in-memory and decoded data looks like:

    ```text
    String:                Hi👋!
    Codepoint characters:  H    i    👋                   !
    Codepoints:            72   105  128075               33
    Bytes:                 72   105  240  159  145  139   33
    Index:                 0    1    2    3    4    5     6
    ```

    attempting to slice bytes `[3-5)` with `str_slice[3:5]` would previously
    erroneously produce a malformed `StringSlice` as output that did not
    correctly decode to anything:

    ```text
    String:                invalid
    Codepoint characters:  invalid
    Codepoints:            invalid
    Bytes:                 159  145
    Index:                 0    1
    ```

    The same statement will now raise an error informing the user that their
    indices are invalid.

  * The `StringLiteral.get[value]()` method, which produces a string literal
    from a compile-time value of
    [`Stringable`](/mojo/std/builtin/str/Stringable) type, has been changed
    to a function named
    [`get_string_literal[value]()`](/mojo/std/builtin/string_literal/get_string_literal).

* Collections:

  * A new [`IntervalTree`](/mojo/std/collections/interval/IntervalTree) data
    structure has been added to the standard library. This is a tree data
    structure that allows for efficient range queries.

  * Added an iterator to
    [`LinkedList`](/mojo/std/collections/linked_list/LinkedList)
    ([PR #4005](https://github.com/modular/modular/pull/4005)):

    * [`LinkedList.__iter__()`](/mojo/std/collections/linked_list/LinkedList/#__iter__)
      to create a forward iterator.
    * [`LinkedList.__reversed__()`](/mojo/std/collections/linked_list/LinkedList/#__reversed__)
      for a backward iterator.

    ```mojo
    var ll = LinkedList[Int](1, 2, 3)
    for element in ll:
        print(element[])
    ```

  * `List.bytecount()` has been renamed to
    [`List.byte_length()`](/mojo/std/collections/list/List/#byte_length) for
    consistency with the string-like APIs.

* The [`InlineArray(unsafe_uninitialized=True)`](/mojo/std/collections/inline_array/InlineArray/#__init__)
  constructor is now spelled `InlineArray(uninitialized=True)`.

* The design of the [`IntLiteral`](/mojo/std/builtin/int_literal/IntLiteral)
  and [`FloatLiteral`](/mojo/std/builtin/float_literal/FloatLiteral) types
  has been changed to maintain their compile-time-only value as a parameter
  instead of a stored field.
  This correctly models that infinite-precision literals are not
  representable at runtime, and eliminates a number of bugs hit in corner
  cases. This is made possible by enhanced dependent type support in the
  compiler.

* The `Buffer` struct has been removed in favor of
  [`Span`](/mojo/std/memory/span/Span) and `NDBuffer`.

* The [`round()`](/mojo/std/builtin/math/round) function is now fixed to
  perform "round half to even" (also known as "bankers' rounding") instead of
  "round half away from zero".

* The [`UnsafePointer.alloc()`](/mojo/std/memory/unsafe_pointer/UnsafePointer/#alloc)
  method has changed to produce pointers with an empty `Origin` parameter,
  instead of with `MutableAnyOrigin`. This mitigates an issue with the
  any-origin parameter extending the lifetime of unrelated local variables
  for this common method.

* Several more packages are now documented:

  * [`compile`](/mojo/std/compile) package
  * [`gpu`](/mojo/std/gpu) package
  * [`layout`](/mojo/kernels/layout) package documentation is underway,
    beginning with core types, functions, and traits

* Added a new
  [`sys.is_compile_time()`](/mojo/std/sys/compile/is_compile_time) function.
  This enables you to query whether code is being executed at compile time or
  not. For example:

  ```mojo
  from sys import is_compile_time

  fn check_compile_time() -> String:
      if is_compile_time():
          return "compile time"
      else:
          return "runtime"

  def main():
      alias var0 = check_compile_time()
      var var1 = check_compile_time()
      print("var0 is evaluated at ", var0, " , while var1 is evaluated at ", var1)
  ```

  will print `var0 is evaluated at compile time, while var1 is evaluated at
  runtime`.

### Tooling changes {#25-2-tooling-changes}

* Mojo API documentation generation is now able to display function and
  struct parameter references inside nested parametric types using names
  instead of indices.
  For example, instead of

  ```mojo
  sort[type: CollectionElement, //, cmp_fn: fn($1|0, $1|0) capturing -> Bool](span: Span[type, origin])
  ```

  it now displays

  ```mojo
  sort[type: CollectionElement, //, cmp_fn: fn(type, type) capturing -> Bool](span: Span[type, origin])
  ```

### ❌ Removed

* Use of legacy argument conventions like `inout` and the use of `as` in
  named results now produces an error message instead of a warning.

* Direct access to `List.size` has been removed. Use the public API instead.
  Examples:

  Extending a `List`:

  ```mojo
  base_data = List[Byte](1, 2, 3)

  data_list = List[Byte](4, 5, 6)
  ext_data_list = base_data.copy()
  ext_data_list.extend(data_list)  # [1, 2, 3, 4, 5, 6]

  data_span = Span(List[Byte](4, 5, 6))
  ext_data_span = base_data.copy()
  ext_data_span.extend(data_span)  # [1, 2, 3, 4, 5, 6]

  data_vec = SIMD[DType.uint8, 4](4, 5, 6, 7)
  ext_data_vec_full = base_data.copy()
  ext_data_vec_full.extend(data_vec)  # [1, 2, 3, 4, 5, 6, 7]

  ext_data_vec_partial = base_data.copy()
  ext_data_vec_partial.extend(data_vec, count=3)  # [1, 2, 3, 4, 5, 6]
  ```

  Slicing and extending a list efficiently:

  ```mojo
  base_data = List[Byte](1, 2, 3, 4, 5, 6)
  n4_n5 = Span(base_data)[3:5]
  extra_data = Span(List[Byte](8, 10))
  end_result = List[Byte](capacity=len(n4_n5) + len(extra_data))
  end_result.extend(n4_n5)
  end_result.extend(extra_data)  # [4, 5, 8, 10]
  ```

* `InlinedFixedVector` and `InlineList` have been removed. Instead, use
  [`InlineArray`](/mojo/std/collections/inline_array/InlineArray) when the
  upper bound is known at compile time. If the upper bound is not known until
  runtime, use [`List`](/mojo/std/collections/list/List) with the `capacity`
  constructor to minimize allocations.
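The slice-and-extend example above has a close Python analogue, shown here only to illustrate the no-copy idea: a `memoryview` gives a zero-copy view of a slice, which `bytearray.extend` can consume directly.

```python
base_data = bytearray([1, 2, 3, 4, 5, 6])

# Take a no-copy view of indices 3..5 (values 4 and 5), like Span(base_data)[3:5].
n4_n5 = memoryview(base_data)[3:5]
extra_data = bytes([8, 10])

end_result = bytearray()
end_result.extend(n4_n5)       # copies only at the final destination
end_result.extend(extra_data)
print(list(end_result))        # [4, 5, 8, 10]
```

The design point in both languages is the same: slicing produces a view, and the only copy happens when the elements land in the destination buffer.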
### 🛠️ Fixed

* [#3976](https://github.com/modular/modular/issues/3976) - The `variance`
  argument in
  [`random.randn_float64()`](/mojo/std/random/random/randn_float64) and
  [`random.randn()`](/mojo/std/random/random/randn) has been renamed to
  `standard_deviation` so that values are drawn from the correct
  distribution.

### Special thanks

Special thanks to our community contributors:
[@bgreni](https://github.com/bgreni), [@fnands](https://github.com/fnands),
[@illiasheshyn](https://github.com/illiasheshyn),
[@izo0x90](https://github.com/izo0x90),
[@lydiandy](https://github.com/lydiandy),
[@martinvuyk](https://github.com/martinvuyk),
[@msaelices](https://github.com/msaelices),
[@owenhilyard](https://github.com/owenhilyard),
[@rd4com](https://github.com/rd4com), and
[@yinonburgansky](https://github.com/yinonburgansky)

## v25.1 (2025-02-13)

### ✨ Highlights

* The legacy `borrowed`/`inout` keywords and `-> T as foo` syntax are
  deprecated and now generate a compiler warning. Please move to
  `read`/`mut`/`out` argument syntax instead. See
  [Argument conventions](/mojo/manual/values/ownership#argument-conventions)
  in the Mojo Manual for more information.

* The `bool()`, `float()`, `int()`, and `str()` functions are deprecated and
  generate compiler warnings. Please use the `Bool()`, `Float64()`, `Int()`,
  and `String()` constructors instead. See
  [Standard library changes](#25-1-standard-library-changes) for more
  details.

* The standard library has many changes related to strings. The new
  [`Char`](/mojo/std/collections/string/codepoint/Codepoint) struct
  represents a single Unicode character, and includes several methods for
  categorizing character types.
  When iterating over the characters of a `String` with a `for` loop, you
  should now use the
  [`String.chars()`](/mojo/std/collections/string/string/String#chars) method
  to provide an iterator of `Char` values, or the
  [`String.char_slices()`](/mojo/std/collections/string/string/String#char_slices)
  method to provide an iterator of
  [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice/)
  instances for each character.

  `StringRef` has been removed in favor of
  [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice/),
  and various functionality has moved from `String` and `StringLiteral` to
  the more general `StringSlice` type.

  See [Standard library changes](#25-1-standard-library-changes) for more
  details.

* You can now use [`SIMD`](/mojo/std/builtin/simd/SIMD) constructors to cast
  existing `SIMD` values (including `Scalar` values) to a different type,
  though you can still use the
  [`SIMD.cast()`](/mojo/std/builtin/simd/SIMD#cast) method to infer the size
  of the new vector. See
  [Standard library changes](#25-1-standard-library-changes) for more
  details.

### Language changes {#25-1-language-changes}

* The legacy `borrowed`/`inout` keywords and `-> T as foo` syntax now
  generate a warning. Please move to `read`/`mut`/`out` argument syntax
  instead. See
  [Argument conventions](/mojo/manual/values/ownership#argument-conventions)
  in the Mojo Manual for more information.

* Initializers are now treated as static methods that return an instance of
  `Self`. This means the `out` argument of an initializer is now treated the
  same as any other function result or `out` argument. This is generally
  invisible, except that patterns like `instance.__init__()` and
  `x.__copyinit__(y)` no longer work. Simply replace them with
  `instance = T()` and `x = y`, respectively.

* The [`@value`](/mojo/manual/decorators/value) decorator now additionally
  derives an implementation of the `ExplicitlyCopyable` trait.
  This will ease the transition to explicit copyability requirements by
  default in the Mojo collection types.

* Indexing into a homogeneous tuple now produces the consistent element type
  without needing a rebind:

  ```mojo
  var x = (1, 2, 3, 3, 4)
  var y: Int = x[idx]  # Just works!
  ```

* You can now overload positional arguments with a keyword-only argument, and
  keyword-only arguments with different names:

  ```mojo
  struct OverloadedKwArgs:
      var val: Int

      fn __init__(out self, single: Int):
          self.val = single

      fn __init__(out self, *, double: Int):
          self.val = double * 2

      fn __init__(out self, *, triple: Int):
          self.val = triple * 3

  fn main():
      OverloadedKwArgs(1)         # val=1
      OverloadedKwArgs(double=1)  # val=2
      OverloadedKwArgs(triple=2)  # val=6
  ```

  This also works with indexing operations:

  ```mojo
  struct OverloadedKwArgs:
      var vals: List[Int]

      fn __init__(out self):
          self.vals = List[Int](0, 1, 2)

      fn __getitem__(self, idx: Int) -> Int:
          return self.vals[idx]

      fn __getitem__(self, *, idx2: Int) -> Int:
          return self.vals[idx2 * 2]

      fn __setitem__(mut self, idx: Int, val: Int):
          self.vals[idx] = val

      fn __setitem__(mut self, val: Int, *, idx2: Int):
          self.vals[idx2 * 2] = val

  fn main():
      var x = OverloadedKwArgs()
      print(x[1])       # 1
      print(x[idx2=1])  # 2

      x[1] = 42
      x[idx2=1] = 84

      print(x[1])       # 42
      print(x[idx2=1])  # 84
  ```

* The `__disable_del x` operation has been tightened up to treat all fields
  of `x` as consumed by the point of the deletion, so it should be used after
  all the subfields are transferred or otherwise consumed (for example, at
  the end of the function), not before uses of the fields.

### GPU programming {#25-1-gpu-programming}

* The new [`gpu` package](/mojo/std/gpu/) provides low-level programming
  constructs for working with GPUs. The Mojo `gpu` APIs allow you to manually
  manage interaction between the CPU host and GPU device, manage memory
  between devices, synchronize threads, and more.
  Currently the best way to use these APIs is from inside a
  [MAX custom operation](/max/develop/custom-ops). The following code example
  shows a GPU kernel written in Mojo:

  ```mojo
  from max.tensor import ManagedTensorSlice
  from gpu import thread_idx, block_dim, block_idx

  fn gpu_add_kernel(out: ManagedTensorSlice, x: ManagedTensorSlice[out.type, out.rank]):
      tid_x = thread_idx.x + block_dim.x * block_idx.x
      tid_y = thread_idx.y + block_dim.y * block_idx.y
      if tid_x < x.dim_size(0) and tid_y < x.dim_size(1):
          out[tid_x, tid_y] = x[tid_x, tid_y] + 1
  ```

  The example above includes only the actual kernel code that's run on the
  GPU, not the code to define a custom operation or launch the kernel. For
  more complete examples, see
  [`vector_addition.mojo`](https://github.com/modular/modular/blob/main/max/examples/custom_ops/kernels/vector_addition.mojo)
  and
  [`top_k.mojo`](https://github.com/modular/modular/blob/main/max/examples/custom_ops/kernels/top_k.mojo).

* The [`layout` package](/mojo/kernels/layout/) includes APIs for working
  with *layouts*, which describe the organization of a tensor (for example,
  row-major or column-major layout), and the
  [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) type,
  which represents a tensor with a specified layout. The `layout` package can
  be used to build efficient tensor operations that run on a GPU.

  We'll continue adding code examples and documentation for the `gpu` and
  `layout` packages in future releases.

### Standard library changes {#25-1-standard-library-changes}

* The builtin functions for converting values to different types have been
  deprecated in favor of actual constructors:

  | **Before** | **After**   |
  | ---------- | ----------- |
  | `bool()`   | `Bool()`    |
  | `float()`  | `Float64()` |
  | `int()`    | `Int()`     |
  | `str()`    | `String()`  |

  These functions were a workaround before Mojo had a way to distinguish
  between implicit and explicit constructors.
  For this release you'll get a deprecation warning, and in the next release
  they'll become compiler errors. You can quickly update your code by doing a
  `Match Case` and `Match Whole Word` search and replace for `int(` to
  `Int(`, etc.

* [`String`](/mojo/std/collections/string/string) and friends:

  * Added [`Char`](/mojo/std/collections/string/codepoint/Codepoint) for
    representing and storing single Unicode characters.

    * `Char` implements `CollectionElement`,
      [`EqualityComparable`](/mojo/std/builtin/comparable/#equalitycomparable),
      [`Intable`](/mojo/std/builtin/int/Intable/), and
      [`Stringable`](/mojo/std/builtin/str/Stringable/).
    * `Char` provides methods for categorizing character types, including:
      [`Char.is_ascii()`](/mojo/std/collections/string/codepoint/Codepoint/#is_ascii),
      [`Char.is_ascii_digit()`](/mojo/std/collections/string/codepoint/Codepoint/#is_ascii_digit),
      [`Char.is_ascii_upper()`](/mojo/std/collections/string/codepoint/Codepoint/#is_ascii_upper),
      [`Char.is_ascii_lower()`](/mojo/std/collections/string/codepoint/Codepoint/#is_ascii_lower),
      [`Char.is_ascii_printable()`](/mojo/std/collections/string/codepoint/Codepoint/#is_ascii_printable),
      [`Char.is_posix_space()`](/mojo/std/collections/string/codepoint/Codepoint/#is_posix_space), and
      [`Char.is_python_space()`](/mojo/std/collections/string/codepoint/Codepoint/#is_python_space).
    * Added a `String()` constructor from `Char`.
    * `Char` can be converted to `UInt32` via
      [`Char.to_u32()`](/mojo/std/collections/string/codepoint/Codepoint/#to_u32).

  * [`chr()`](/mojo/std/collections/string/string/chr/) will now abort if
    given a codepoint value that is not a valid `Char`.

  * `StringRef` has been removed in favor of
    [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice/).
    The two types are ABI compatible, and for the exact same behavior one can
    use `StaticString`, which is an alias for
    `StringSlice[StaticConstantOrigin]`.
  * Various functionality has moved from `String` and `StringLiteral` to the
    more general `StringSlice` type.

  * Added the
    [`StringSlice.from_utf8()`](/mojo/std/collections/string/string_slice/StringSlice/#from_utf8)
    factory method, for validated construction of a `StringSlice` from a
    buffer containing UTF-8 encoded data. This method will raise if the
    buffer contents are not valid UTF-8.

  * Added
    [`StringSlice.chars()`](/mojo/std/collections/string/string_slice/StringSlice/#chars),
    which returns an iterator over `Char`s. This is a compliant UTF-8 decoder
    that returns each Unicode codepoint encoded in the string.

  * Added
    [`StringSlice.__getitem__(Slice)`](/mojo/std/collections/string/string_slice/StringSlice/#__getitem__),
    which returns a substring. Only step sizes of 1 are supported.

  * Several standard library functions have been changed to take
    `StringSlice` instead of `String`. This generalizes them to be used for
    any appropriately encoded string in memory, without requiring that the
    string be heap allocated. This includes:
    [`ascii()`](/mojo/std/collections/string/string/ascii),
    [`atol()`](/mojo/std/collections/string/string/atol),
    [`atof()`](/mojo/std/collections/string/string/atof),
    [`b16decode()`](/mojo/std/base64/base64/b16decode),
    [`b16encode()`](/mojo/std/base64/base64/b16encode),
    [`b64decode()`](/mojo/std/base64/base64/b64decode),
    [`b64encode()`](/mojo/std/base64/base64/b64encode), and
    [`ord()`](/mojo/std/collections/string/string/ord).

  * Added new
    [`String.chars()`](/mojo/std/collections/string/string/String/#chars) and
    [`String.char_slices()`](/mojo/std/collections/string/string/String/#char_slices)
    iterator methods, and deprecated the existing `String.__iter__()` method.
    Different use cases may prefer iterating over the `Char`s encoded in a
    string, or iterating over subslices containing single characters.
    Neither iteration semantics is an obvious default, so the existing
    `__iter__()` method has been deprecated in favor of writing explicit
    iteration methods for the time being.

    Code of the form:

    ```mojo
    var s: String = ...
    for c in s:
        # ...
    ```

    can be migrated to using the `.char_slices()` method:

    ```mojo
    var s: String = ...
    for c in s.char_slices():
        # ...
    ```

  * Added the
    [`StringSlice.char_length()`](/mojo/std/collections/string/string_slice/StringSlice/#char_length)
    method, to pair with the existing
    [`StringSlice.byte_length()`](/mojo/std/collections/string/string_slice/StringSlice/#byte_length)
    method.

  * The
    [`String.__len__()`](/mojo/std/collections/string/string/String/#__len__)
    and
    [`StringSlice.__len__()`](/mojo/std/collections/string/string_slice/StringSlice/#__len__)
    methods now return the length of the string in bytes.

    Previously, these methods were documented to note that they would
    eventually return a length in Unicode codepoints. They have been changed
    to guarantee a length in bytes, since the length in bytes is how they are
    most often used today (for example, as bounds to low-level memory
    manipulation logic). Additionally, length in codepoints is a more
    specialized notion of string length that is rarely the correct metric.

    Users that know they need the length in codepoints can use the
    `str.char_length()` method, or `len(str.chars())`.

  * `StringSlice` now implements
    [`Representable`](/mojo/std/builtin/repr/Representable/), and that
    implementation is now used by `String.__repr__()` and
    `StringLiteral.__repr__()`.

  * `StringSlice` now implements
    [`EqualityComparable`](/mojo/std/builtin/comparable/#equalitycomparable).
    Up until now, `StringSlice` has implemented a more general `__eq__()` and
    `__ne__()` comparison with `StringSlice` types that had arbitrary other
    origins. However, to satisfy `EqualityComparable`, `StringSlice` now also
    has narrower comparison methods that support comparing only with another
    `StringSlice` with the exact same origin.
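The byte-length vs. codepoint-length distinction, and the codepoint-validity checking mentioned above, are both easy to see in Python; this is an illustration of the concepts, not the Mojo APIs themselves:

```python
s = "Hi👋!"

# Python's len() counts codepoints; encoding exposes the byte count.
print(len(s))                  # 4 codepoints
print(len(s.encode("utf-8")))  # 7 bytes: 👋 (U+1F44B) encodes to 4 bytes

# Python's chr() enforces codepoint validity, much as the Mojo chr() now does.
print(chr(128075))  # 👋
try:
    chr(0x110000)   # one past the maximum Unicode codepoint
except ValueError:
    print("invalid codepoint rejected")
```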
* The `String.write()` static method has moved to a `String()` constructor, and is now buffered. Instead of doing:

  ```mojo
  var msg = "my message " + String(x) + " " + String(y) + " " + String(z)
  ```

  which reallocates the `String` repeatedly, you should do:

  ```mojo
  var msg = String("my message", x, y, z, sep=" ")
  ```

  which is cleaner, and buffers to the stack so the `String` is allocated only once.

* You can now pass any [`Writer`](/mojo/std/format/Writer/) to `write_buffered()`:

  ```mojo
  from utils.write import write_buffered

  var string = String("existing string")
  write_buffered(string, 42, 42.4, True, sep=" ")
  ```

  This writes to a buffer on the stack before reallocating the `String`.

* Collections:

  * A new [`LinkedList`](/mojo/std/collections/linked_list/LinkedList/) type has been added to the standard library.
  * Added [`Optional.copied()`](/mojo/std/collections/optional/Optional#copied) for constructing an owned `Optional[T]` from an `Optional[Pointer[T]]` by copying the pointee value.
  * Added [`Dict.get_ptr()`](/mojo/std/collections/dict/Dict#get_ptr) which returns an `Optional[Pointer[V]]`. If the given key is present in the dictionary, the optional will hold a pointer to the value. Otherwise, an empty optional is returned.
  * Added new [`List.extend()`](/mojo/std/collections/list/List#extend) overloads taking [`SIMD`](/mojo/std/builtin/simd/SIMD) and [`Span`](/mojo/std/memory/span/Span/). These enable growing a `List[Scalar[..]]` by copying the elements of a `SIMD` vector or `Span[Scalar[..]]`, simplifying the writing of some optimized SIMD-aware functionality.
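  The new `List.extend()` overloads can be sketched as follows; this fragment is illustrative and not from the release notes:

  ```mojo
  var bytes = List[UInt8]()
  # Grow the list by copying the four lanes of a SIMD vector.
  bytes.extend(SIMD[DType.uint8, 4](1, 2, 3, 4))
  ```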
* [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer/) changes:

  * `UnsafePointer`'s `bitcast()` method has now been split into [`bitcast()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#bitcast) for changing the type, [`origin_cast()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#origin_cast) for changing mutability, [`static_alignment_cast()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#static_alignment_cast) for changing alignment, and [`address_space_cast()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#address_space_cast) for changing the address space.
  * `UnsafePointer` is now parameterized on mutability. Previously, `UnsafePointer` could only represent mutable pointers. The new `mut` parameter can be used to restrict an `UnsafePointer` to a specific mutability: `UnsafePointer[T, mut=False]` represents a pointer to an immutable `T` value. This is analogous to a `const *` pointer in C++.
  * [`UnsafePointer.address_of()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#address_of) will now infer the origin and mutability of the resulting pointer from the argument. For example:

    ```mojo
    var local = 10
    # Constructs a mutable pointer, because `local` is a mutable memory location
    var ptr = UnsafePointer.address_of(local)
    ```

    To force the construction of an immutable pointer to an otherwise mutable memory location, use a cast:

    ```mojo
    var local = 10
    # Cast the mutable pointer to be immutable.
    var ptr = UnsafePointer.address_of(local).origin_cast[mut=False]()
    ```

  * The `unsafe_ptr()` method on several standard library collection types has been updated to use parametric mutability: it will return an `UnsafePointer` whose mutability is inherited from the mutability of the `ref self` of the receiver at the call site.
    For example, `ptr1` will be immutable, while `ptr2` will be mutable:

    ```mojo
    fn take_lists(read list1: List[Int], mut list2: List[Int]):
        # Immutable pointer, since receiver is immutable `read` reference
        var ptr1 = list1.unsafe_ptr()

        # Mutable pointer, since receiver is mutable `mut` reference
        var ptr2 = list2.unsafe_ptr()
    ```

* New and updated traits:

  * The `ExplicitlyCopyable` trait has changed to require a `fn copy(self) -> Self` method. Previously, an initializer with the signature `fn __init__(out self, *, other: Self)` had been required by `ExplicitlyCopyable`. This improves the "greppability" and at-a-glance readability when a programmer is looking for places in their code that may be performing copies.
  * The `IntLike` trait has been removed and its functionality incorporated into the [`Indexer`](/mojo/std/builtin/int/Indexer/) trait. This enables `SIMD` scalar integer types and `UInt` to be used for indexing into all of the collection types, as well as optimizing away normalization checks for `UInt` indexing.
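  With `IntLike` folded into `Indexer`, a `SIMD` scalar integer such as `UInt8` can index a collection directly. A minimal illustrative sketch (not from the release notes):

  ```mojo
  var l = List[Int](10, 20, 30)
  var i = UInt8(1)
  # `UInt8` conforms to `Indexer`, so it can be used as a collection index.
  print(l[i])
  ```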
* The `ImplicitlyIntable` trait has been added, allowing types to be implicitly converted to an `Int` by implementing the `__as_int__()` method:

  ```mojo
  @value
  struct Foo(ImplicitlyIntable):
      var i: Int

      fn __as_int__(self) -> Int:
          return self.i
  ```

* You can now cast `SIMD` types using constructors:

  ```mojo
  var val = Int8(42)
  var cast = Int32(val)
  ```

  It also works when passing a scalar type to a larger vector size:

  ```mojo
  var vector = SIMD[DType.int64, 4](cast)  # [42, 42, 42, 42]
  ```

  For values other than scalars, the size of the `SIMD` vector needs to be equal:

  ```mojo
  var float_vector = SIMD[DType.float64, 4](vector)
  ```

  [`SIMD.cast()`](/mojo/std/builtin/simd/SIMD#cast) still exists to infer the size of the new vector:

  ```mojo
  var inferred_size = float_vector.cast[DType.uint64]()  # [42, 42, 42, 42]
  ```

* Added [`SIMD.from_bytes()`](/mojo/std/builtin/simd/SIMD/#from_bytes) and [`SIMD.as_bytes()`](/mojo/std/builtin/simd/SIMD/#as_bytes) to convert a list of bytes to a list of scalars and vice versa, accepting the endianness as an argument. These are similar to the Python `int.from_bytes()` and `int.to_bytes()` functions.
* You can now use [`max()`](/mojo/std/builtin/math/max) and [`min()`](/mojo/std/builtin/math/min) with a variadic number of arguments.
* `bit_ceil()` has been renamed to [`next_power_of_two()`](/mojo/std/bit/bit/next_power_of_two), and `bit_floor()` to [`prev_power_of_two()`](/mojo/std/bit/bit/prev_power_of_two). This is to improve readability and clarity in their use.
* Added a new boolean `validate` parameter to [`b64decode()`](/mojo/std/base64/base64/b64decode).
* The [`b64encode()`](/mojo/std/base64/base64/b64encode) overload that previously took a `List` has been changed to take a [`Span`](/mojo/std/memory/span/Span/).
* Removed the `@implicit` decorator from some standard library initializer methods that perform allocation. This reduces places where Mojo code could implicitly allocate where the user may not be aware.
  Removed `@implicit` from:

  * `String.__init__(out self, StringSlice)`
  * `List.__init__(out self, owned *values: T)`
  * `List.__init__(out self, span: Span[T])`

* Added more aliases in [`sys.ffi`](/mojo/std/sys/ffi/) to round out the usual needs for FFI bindings.

### Tooling changes {#25-1-tooling-changes}

* `mblack` (aka [`mojo format`](/mojo/cli/format)) no longer formats non-Mojo files. This prevents unexpected formatting of Python files.
* Full struct signature information is now exposed in the documentation generator, and in the symbol outline and hover markdown via the Mojo Language Server.
* The [`env_get_dtype()`](/mojo/std/sys/param_env/env_get_dtype) function has been added to the [`sys.param_env`](/mojo/std/sys/param_env/) module. This allows you to get the value of a `DType` from the param environment.

### ❌ Removed

* `StringRef` has been removed. Use [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice/) instead.

  * Changed [`sys.argv()`](/mojo/std/sys/arg/argv) to return a list of `StringSlice`.
  * Added an explicit [`Path()`](/mojo/std/pathlib/path/Path/#__init__) constructor from `StringSlice`.

* The `Tuple.get[i, T]()` method has been removed. Please use `tup[i]` or `rebind[T](tup[i])` as needed instead.
* `StringableCollectionElement` is deprecated. Use `WritableCollectionElement` instead, which still allows you to construct a `String`, but can avoid intermediate allocations.
* The `IntLike` trait has been removed and its functionality incorporated into the [`Indexer`](/mojo/std/builtin/int/Indexer/) trait.
* The `Type{field1: 42, field2: 17}` syntax for directly initializing register-passable types has been removed. This was legacy syntax - to upgrade your code, add the [`@value`](/mojo/manual/decorators/value) decorator to your struct to get a fieldwise initializer and use `Type(field1=42, field2=17)` instead.

### 🛠️ Fixed

* The Mojo Kernel for Jupyter Notebooks is working again on nightly releases.
* The command `mojo debug --vscode` now sets the current working directory properly.
* [Issue #3796](https://github.com/modular/modular/issues/3796) - Compiler crash handling `for`-`else` statement.
* [Issue #3540](https://github.com/modular/modular/issues/3540) - Using a named output slot breaks trait conformance.
* [Issue #3617](https://github.com/modular/modular/issues/3617) - Can't generate the constructors for a type wrapping `!lit.ref`.
* The Mojo Language Server no longer crashes on empty `__init__.mojo` files. [Issue #3826](https://github.com/modular/modular/issues/3826).
* [Issue #3935](https://github.com/modular/modular/issues/3935) - Confusing OOM error when using `Tuple.get()` incorrectly.
* [Issue #3955](https://github.com/modular/modular/issues/3955) - Unexpected copy behavior with `def` arguments in loops.
* [Issue #3960](https://github.com/modular/modular/issues/3960) - Infinite `for` loop.

## v24.6 (2024-12-17)

### ✨ Highlights

Here's a brief summary of some of the major changes in this release, with more detailed information in the following sections:

* The `inout` and `borrowed` argument conventions have been renamed to `mut` and `read`, respectively. A new `out` convention has been added for the `self` argument in constructors and for named results. See [Language changes](#24-6-language-changes) for details.
* `Lifetime` and related types in the standard library have been renamed to [`Origin`](/mojo/std/builtin/type_aliases/Origin) to better clarify that parameters of this type indicate where a reference is derived from, not the more complicated notion of where a variable is initialized and destroyed. As a consequence, the `__lifetime_of()` operator is now named `__origin_of()`. There are also a number of other origin-related improvements in this release, including being able to specify a union of origins by listing multiple values in the `__origin_of()` operator or inside the `ref` origin specifier (`ref [a, b]`).
  For details, see [Language changes](#24-6-language-changes). For background information and rationale on the name change, see [the proposal](https://github.com/modular/modular/issues/3623). For more information on origins, see [Lifetimes, origins and references](/mojo/manual/values/lifetimes) in the Mojo Manual.

* Implicit conversions are now opt-in using the [`@implicit`](/mojo/manual/decorators/implicit) decorator. See [Language changes](#24-6-language-changes) for details.
* The standard library has added several new types, including [`Deque`](/mojo/std/collections/deque/Deque) (a double-ended queue) and [`OwnedPointer`](/mojo/std/memory/owned_pointer/OwnedPointer) (a safe, single-owner, non-nullable smart pointer). See [Standard library changes](#24-6-standard-library-changes) for details.
* The VS Code extension now supports setting data breakpoints and function breakpoints, and the Mojo LLDB debugger supports symbol breakpoints, such as `b main` or `b my_module::main`.
* We've made a number of improvements to how information is displayed in error messages, LSP, and generated API documentation. For details, see [Tooling changes](#24-6-tooling-changes).
* And we've added a number of new docs, including a brand new [Mojo tutorial](/mojo/manual/get-started), new pages on [operators and expressions](/mojo/manual/operators), [error handling](/mojo/manual/errors), and [pointers](/mojo/manual/pointers/), and many smaller additions and improvements.

### Language changes {#24-6-language-changes}

* Argument convention changes:

  * The `inout` and `borrowed` argument conventions have been renamed to `mut` (for "mutate") and `read`, respectively. These verbs reflect what the callee can do to the argument value passed in by the caller, without requiring the programmer to know about advanced features like references. For information on Mojo's argument conventions, see [Argument conventions](/mojo/manual/values/ownership/#argument-conventions) in the Mojo Manual.
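  The renamed conventions can be sketched with a small hypothetical function: `read` lets the callee read the argument without mutating it, while `mut` lets it mutate the caller's value in place:

  ```mojo
  fn accumulate(read values: List[Int], mut total: Int):
      for v in values:
          total += v[]
  ```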
  * The argument convention for the `self` argument in the `__init__()`, `__copyinit__()`, and `__moveinit__()` methods has been changed from `inout` to `out`, reflecting that a constructor method initializes its `self` value without reading from it. This also enables spelling the type of an initializer correctly, which was not supported before:

    ```mojo
    struct Foo:
        fn __init__(out self): pass

    fn test():
        # This works now
        var fnPtr: fn(out x: Foo) -> None = Foo.__init__
        var someFoo: Foo
        fnPtr(someFoo)  # initializes someFoo.
    ```

    The previous `fn __init__(inout self)` syntax is still supported in this release of Mojo, but will be removed in the future. Please migrate to the new syntax.

  * Similarly, the spelling of named results has switched to use `out` syntax instead of `-> T as name`. Functions may have at most one named result or return type specified with the usual `->` syntax. `out` arguments may occur anywhere in the argument list, but are typically last (except for `__init__` methods, where they are typically first).

    ```mojo
    # This function has type "fn() -> String"
    fn example(out result: String):
        result = "foo"
    ```

    The parser still accepts the old syntax as a synonym for this, but that will eventually be deprecated and removed. This was [discussed extensively in a public proposal](https://github.com/modular/modular/issues/3623). For more information, see [Named results](/nightly/mojo/manual/functions#named-results) in the Mojo Manual.

* Single argument constructors now require the [`@implicit`](/mojo/manual/decorators/implicit) decorator to allow for implicit conversions.
  Previously you could define an `__init__` that takes a single argument:

  ```mojo
  struct Foo:
      var value: Int

      fn __init__(out self, value: Int):
          self.value = value
  ```

  And this would allow you to pass an `Int` in the position of a `Foo`:

  ```mojo
  fn func(foo: Foo):
      print("implicitly converted Int to Foo:", foo.value)

  fn main():
      func(Int(42))
  ```

  This can result in complicated errors that are difficult to debug. By default this implicit behavior is now turned off, so you have to explicitly construct `Foo`:

  ```mojo
  fn main():
      func(Foo(42))
  ```

  You can still opt into implicit conversions by adding the `@implicit` decorator. For example, to enable implicit conversions from `Int` to `Foo`:

  ```mojo
  struct Foo:
      var value: Int

      @implicit
      fn __init__(out self, value: Int):
          self.value = value
  ```

  For more information see [Constructors and implicit conversion](/mojo/manual/lifecycle/life#constructors-and-implicit-conversion) in the Mojo Manual.

* Origin-related changes:

  * The `AnyLifetime` type (useful for declaring origin types as parameters) has been renamed to [`Origin`](/mojo/std/builtin/type_aliases/Origin) and the `__lifetime_of()` operator renamed to `__origin_of()`.
  * `Origin` is now a complete wrapper around the MLIR origin type.
  * The `Origin.type` alias has been renamed to `_mlir_origin`. In parameter lists, you can now write just `Origin[..]`, instead of `Origin[..].type`.
  * `ImmutableOrigin` and `MutableOrigin` are now, respectively, just aliases for `Origin[False]` and `Origin[True]`.
  * `Origin` struct values are now supported in the origin specifier of a `ref [..]` argument.
  * Added `Origin.cast_from` for casting the mutability of an origin value.
  * `ref` arguments and results now allow for providing a memory value directly in the origin specifier, rather than requiring the use of `__origin_of()`. It is still fine to use `__origin_of()` explicitly though, and this is required when specifying origins for parameters (e.g. to the `Pointer` type).
    For example, this is now valid without `__origin_of()`:

    ```mojo
    fn return_ref(a: String) -> ref [a] String:
        return a
    ```

  * Various improvements to origin handling and syntax have landed, including support for the ternary operator and allowing multiple arguments in a `ref` specifier (which are implicitly unions). This enables clean expression of simple algorithms:

    ```mojo
    fn my_min[T: Comparable](ref a: T, ref b: T) -> ref [a, b] T:
        return a if a < b else b
    ```

    It is also nice that `my_min` automatically and implicitly propagates the mutability of its arguments, so an expression like `my_min(str1, str2) += "foo"` is valid.

  * `ref` function arguments without an origin clause are now treated as `ref [_]`, which is more syntactically convenient and consistent:

    ```mojo
    fn takes_and_return_ref(ref a: String) -> ref [a] String:
        return a
    ```

  * The `__type_of(x)` and `__origin_of(x)` operators are much more general now: they allow arbitrary expressions inside of them, allow referring to dynamic values in parameter contexts, and even allow referring to raising functions in non-raising contexts. These operations never evaluate their expression, so any side effects that occur in the expression are never evaluated at runtime, eliminating concerns about `__type_of(expensive())` being a problem.

  * The destructor insertion logic in Mojo is now aware that types that take a `MutableAnyOrigin` or `ImmutableAnyOrigin` as part of their signature could potentially access any live value that destructor insertion is tracking, eliminating a significant usability issue with unsafe APIs like [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer). Consider a typical example working with strings before this change:

    ```mojo
    var str = String(...)
    var ptr = str.unsafe_ptr()
    some_low_level_api(ptr)
    _ = str^  # OLD HACK: Explicitly keep string alive until here!
    ```

    The `_ = str^` pattern was formerly required because the Mojo compiler has no idea what "ptr" might reference.
    As a consequence, it had no idea that `some_low_level_api()` might access `str` and therefore thought it was ok to destroy the `String` before the call - this is why the explicit lifetime extension was required. Mojo now knows that [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) may access the `MutableAnyOrigin` origin, and now assumes that any API that uses that origin could use live values. In this case, it assumes that `some_low_level_api()` might access `str` and, because it might be using it, it cannot destroy `str` until after the call. The consequence of this is that the old hack is no longer needed for these cases!

  * Function types now accept an origin set parameter. This parameter represents the origins of values captured by a parameter closure. The compiler automatically tags parameter closures with the right set of origins. This enables lifetimes and parameter closures to correctly compose.

    ```mojo
    fn call_it[f: fn() capturing [_] -> None]():
        f()

    fn test():
        var msg = String("hello world")

        @parameter
        fn say_hi():
            print(msg)

        call_it[say_hi]()
        # no longer need to write `_ = msg^`!!
    ```

    Note that this only works for higher-order functions which have explicitly added `[_]` as the capture origins. By default, the compiler still assumes a `capturing` closure does not reference any origins. This will soon change.

  * Infer-only parameters may now be explicitly bound with keywords, enabling some important patterns in the standard library:

    ```mojo
    struct StringSlice[is_mutable: Bool, //, origin: Origin[is_mutable]]: ...

    alias ImmStringSlice = StringSlice[is_mutable=False]

    # This auto-parameterizes on the origin, but constrains it to being an
    # immutable slice instead of a potentially mutable one.
    fn take_imm_slice(a: ImmStringSlice): ...
    ```

* The flag for turning on asserts has changed, e.g. to enable all checks:

  ```bash
  mojo -D ASSERT=all main.mojo
  ```

  The levels are:

  * `none`: all assertions off
  * `warn`: print assertion errors, e.g.
    for multithreaded tests (previously `-D ASSERT_WARNING`)
  * `safe`: the default mode for standard CPU safety assertions
  * `all`: turn on all assertions (previously `-D MOJO_ENABLE_ASSERTIONS`)

  You can now also pass `Stringable` args to format a message, which will have no runtime penalty or IR bloat cost when assertions are off. Previously you had to write:

  ```mojo
  x = -1
  debug_assert(
      x > 0,
      String.format_sequence("expected x to be more than 0 but got: ", x)
  )
  ```

  which can't be optimized away by the compiler in release builds. You can now pass multiple args for a formatted message at no runtime cost:

  ```mojo
  debug_assert(x > 0, "expected x to be more than 0 but got: ", x)
  ```

* Automatic parameterization of parameters is now supported. Specifying a parameterized type with unbound parameters causes them to be implicitly added to the function signature as infer-only parameters.

  ```mojo
  fn foo[value: SIMD[DType.int32, _]](): pass

  # Equivalent to
  fn foo[size: Int, //, value: SIMD[DType.int32, size]](): pass
  ```

* Mojo can now interpret simple LLVM intrinsics in parameter expressions, enabling things like `count_leading_zeros` to work at compile time: [Issue #933](https://github.com/modular/modular/issues/933).
* Introduced the `@explicit_destroy` annotation, the `__disable_del` keyword, the `UnknownDestructibility` trait, and the `ImplicitlyDestructible` keyword, for the experimental explicitly destroyed types feature.
* Added associated types; we can now have aliases like `alias T: AnyType`, `alias N: Int`, etc. in a trait, and then specify them in structs that conform to that trait. For more information, see [Associated aliases for generics](/mojo/manual/traits#associated-aliases-for-generics).
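  Associated aliases can be sketched with a hypothetical trait and a conforming struct that binds the alias:

  ```mojo
  trait Shaped:
      alias rank: Int

  struct Matrix(Shaped):
      alias rank = 2
  ```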
### Standard library changes {#24-6-standard-library-changes}

* Introduced a new [`Deque`](/mojo/std/collections/deque/Deque) (double-ended queue) collection type, based on a dynamically resizing circular buffer for efficient O(1) additions and removals at both ends, as well as O(1) direct access to all elements.

  The `Deque` supports the full Python `collections.deque` API, ensuring that all expected deque operations perform as in Python. Enhancements to the standard Python API include `peek()` and `peekleft()` methods for non-destructive access to the last and first elements, and advanced constructor options (`capacity`, `min_capacity`, and `shrink`) for customizing memory allocation and performance. These options allow for optimized memory usage and reduced buffer reallocations, providing flexibility based on application requirements.

* The `Formatter` struct has been replaced with a [`Writer`](/mojo/std/utils/write/Writer) trait to enable buffered IO, increasing print and file writing performance to the same speed as C. It's now more general purpose and can write any `Span[Byte]`. To align with this, the `Formattable` trait is now named [`Writable`](/mojo/std/utils/write/Writable), and the `String.format_sequence()` static method to initialize a new `String` has been renamed to [`String.write()`](/mojo/std/collections/string/string/String/#write).
  Here's an example of using all of the changes:

  ```mojo
  from memory import Span

  @value
  struct NewString(Writer, Writable):
      var s: String

      # Writer requirement to write a Span of Bytes
      fn write_bytes(inout self, bytes: Span[Byte, _]):
          self.s._iadd[False](bytes)

      # Writer requirement to take multiple args
      fn write[*Ts: Writable](inout self, *args: *Ts):
          @parameter
          fn write_arg[T: Writable](arg: T):
              arg.write_to(self)

          args.each[write_arg]()

      # Also make it Writable to allow `print` to write the inner String
      fn write_to[W: Writer](self, inout writer: W):
          writer.write(self.s)

  @value
  struct Point(Writable):
      var x: Int
      var y: Int

      # Pass multiple args to the Writer. The Int and StringLiteral types
      # call `writer.write_bytes` in their own `write_to` implementations.
      fn write_to[W: Writer](self, inout writer: W):
          writer.write("Point(", self.x, ", ", self.y, ")")

      # Enable conversion to a String using `str(point)`
      fn __str__(self) -> String:
          return String.write(self)

  fn main():
      var point = Point(1, 2)
      var new_string = NewString(str(point))
      new_string.write("\n", Point(3, 4))
      print(new_string)
  ```

  ```output
  Point(1, 2)
  Point(3, 4)
  ```

* The `TypeIdentifiable` trait has been removed in favor of the new `get_type_name` utility in the `compile.reflection` module.
* Python interop changes:

  * Introduced `TypedPythonObject` as a light-weight way to annotate [`PythonObject`](/mojo/std/python/python_object/PythonObject) values with static type information. This design will likely evolve and change significantly.
  * Added `TypedPythonObject[Tuple].__getitem__()` for accessing the elements of a Python tuple.
  * Added [`Python.add_object()`](/mojo/std/python/python/Python#add_object), to add a named `PythonObject` value to a Python 'module' object instance.
  * Added [`Python.unsafe_get_python_exception()`](/mojo/std/python/python/Python#unsafe_get_python_exception), as an efficient low-level utility to get the Mojo `Error` equivalent of the current CPython error state.
  * Added [`PythonObject.from_borrowed_ptr()`](/mojo/std/python/python_object/PythonObject#from_borrowed_ptr), to simplify the construction of `PythonObject` values from CPython 'borrowed reference' pointers. The existing `PythonObject.__init__(PyObjectPtr)` should continue to be used for the more common case of constructing a `PythonObject` from a 'strong reference' pointer.
  * Support for multi-dimensional indexing and slicing for `PythonObject` (PR [#3549](https://github.com/modular/modular/pull/3549), PR [#3583](https://github.com/modular/modular/pull/3583)).

    ```mojo
    var np = Python.import_module("numpy")
    var a = np.array(PythonObject([1, 2, 3, 4, 5, 6])).reshape(2, 3)
    print((a[0, 1]))     # 2
    print((a[1][::-1]))  # [6 5 4]
    ```

    Note that the syntax `a[1, ::-1]` is currently not supported.
  * Added [`PythonObject.__contains__()`](/mojo/std/python/python_object/PythonObject#__contains__). ([PR #3101](https://github.com/modular/modular/pull/3101)) Example usage:

    ```mojo
    x = PythonObject([1, 2, 3])
    if 1 in x:
        print("1 in x")
    ```

* Pointer related changes:

  * The [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) type now has an `origin` parameter that can be used when the `UnsafePointer` points to a value with a known origin. This origin is propagated through the `ptr[]` indirection operation. This parameter and other `UnsafePointer` parameters (other than the type) are now keyword-only.
  * You can now index into `UnsafePointer` using `SIMD` scalar integral types:

    ```mojo
    p = UnsafePointer[Int].alloc(1)
    i = UInt8(1)
    p[i] = 42
    print(p[i])
    ```

  * Added a new [`OwnedPointer`](/mojo/std/memory/owned_pointer/OwnedPointer) type as a safe, single-owner, non-nullable smart pointer with similar semantics to Rust's [`Box<>`](https://doc.rust-lang.org/std/boxed/struct.Box.html) and C++'s [`std::unique_ptr`](https://en.cppreference.com/w/cpp/memory/unique_ptr).
    ([PR #3524](https://github.com/modular/modular/pull/3524))
  * `Arc` has been renamed to [`ArcPointer`](/mojo/std/memory/arc_pointer/ArcPointer), for consistency with `OwnedPointer`.
  * [`ArcPointer`](/mojo/std/memory/arc_pointer/ArcPointer) now implements [`Identifiable`](/mojo/std/builtin/identifiable/Identifiable), and can be compared for pointer equivalence using `a is b`.
  * The `Reference` type has been renamed to [`Pointer`](/mojo/std/memory/pointer/Pointer): a memory-safe complement to `UnsafePointer`. This change is motivated by the fact that `Pointer` is assignable and requires an explicit dereference with `ptr[]`. Renaming to `Pointer` clarifies that "references" means `ref` arguments and results, and gives us a model that is more similar to what the C++ community would expect. For an overview of Mojo's pointer types, see the new [Intro to pointers](/mojo/manual/pointers/) page in the Mojo Manual.
  * A new [`as_noalias_ptr()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#as_noalias_ptr) method has been added to `UnsafePointer`. This method specifies to the compiler that the resultant pointer is a distinct identifiable object that does not alias any other memory in the local scope.

* Added the [`Floatable`](/mojo/std/builtin/floatable/Floatable) and [`FloatableRaising`](/mojo/std/builtin/floatable/FloatableRaising) traits to denote types that can be converted to a `Float64` value using the builtin `float` function. Made `SIMD` and `FloatLiteral` conform to the `Floatable` trait. ([PR #3163](https://github.com/modular/modular/pull/3163))

  ```mojo
  fn foo[F: Floatable](v: F): ...

  var f = float(Int32(45))
  ```

* The [`rebind()`](/mojo/std/builtin/rebind/rebind) standard library function now works with memory-only types in addition to `@register_passable("trivial")` ones, without requiring a copy. For more information, see [The `rebind()` builtin](/mojo/manual/parameters/#the-rebind-builtin) in the Mojo Manual.
* Introduced the [`random.shuffle()`](/mojo/std/random/random/shuffle) function for randomizing the elements of a `List`. ([PR #3327](https://github.com/modular/modular/pull/3327)) Example:

  ```mojo
  from random import shuffle

  var l = List[Int](1, 2, 3, 4, 5)
  shuffle(l)
  ```

* The [`Dict.__getitem__()`](/mojo/std/collections/dict/Dict#__getitem__) method now returns a reference instead of a copy of the value (or raises). This improves the performance of common code that uses `Dict` by allowing borrows from the `Dict` elements.
* [`Slice.step`](/mojo/std/builtin/builtin_slice/Slice#fields) is now an `Optional[Int]`, matching the optionality of `slice.step` in Python. ([PR #3160](https://github.com/modular/modular/pull/3160))
* There is now a [`Byte`](/mojo/std/builtin/simd/#aliases) alias to better express intent when working with a pack of bits. ([PR #3670](https://github.com/modular/modular/pull/3670))
* Expanded [`os.path`](/mojo/std/os/path/path/) with new functions:

  * `os.path.expandvars()`: Expands environment variables in a path. ([PR #3735](https://github.com/modular/modular/pull/3735))
  * `os.path.splitroot()`: Splits a path into drive, root, and tail. ([PR #3780](https://github.com/modular/modular/pull/3780))

* Added a [`reserve()`](/mojo/std/collections/string/string/String#reserve) method and a new constructor to the `String` struct to allocate additional capacity. ([PR #3755](https://github.com/modular/modular/pull/3755))
* A new [`StringLiteral.get[some_stringable]()`](/mojo/std/builtin/string_literal/StringLiteral#get) method is available. It allows forming a runtime-constant `StringLiteral` from a compile-time-dynamic `Stringable` value.
* [`Span`](/mojo/std/memory/span/Span) has moved from the `utils` module to the `memory` module.
* [`Span`](/mojo/std/memory/span/Span) now implements `__reversed__()`. This means that one can get a reverse iterator over a `Span` using `reversed(my_span)`. Users should currently prefer this method over `my_span[::-1]`.
* A new `AsBytes` trait has been added to enable taking a `Span[Byte]` from any type that implements `as_bytes()`. `String.as_bytes()` and `String.as_bytes_slice()` have been consolidated under `String.as_bytes()` to return a `Span[Byte]`. If you require a copy, you can convert the `Span` to a `List` with `List(my_string.as_bytes())`.
* [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice) now implements `strip()`, `rstrip()`, and `lstrip()`.
* [`StringRef`](/mojo/std/collections/string/string_slice/StringSlice) now implements `split()`, which can be used to split a `StringRef` into a `List[StringRef]` by a delimiter. ([PR #2705](https://github.com/modular/modular/pull/2705))
* [`StringRef`](/mojo/std/collections/string/string_slice/StringSlice) is now representable, so `repr(StringRef("hello"))` will return `StringRef('hello')`.
* More things have been removed from the auto-exported set of entities in the `prelude` module from the Mojo standard library:

  * `UnsafePointer` has been removed. Please explicitly import it via `from memory import UnsafePointer`.
  * `StringRef` has been removed. Please explicitly import it via `from utils import StringRef`.

* Restored implicit copyability of [`Tuple`](/mojo/std/builtin/tuple/Tuple) and `ListLiteral`.
* The [aliases for C foreign function interface (FFI)](/mojo/std/sys/ffi/#aliases) have been renamed: `C_int` -> `c_int`, `C_long` -> `c_long`, and so on.
* `Float32` and `Float64` are now printed and converted to strings with roundtrip guarantee and shortest representation: ```plaintext Value Old New Float64(0.3) 0.29999999999999999 0.3 Float32(0.3) 0.30000001192092896 0.3 Float64(0.0001) 0.0001 0.0001 Float32(0.0001) 9.9999997473787516e-05 0.0001 Float64(-0.00001) -1.0000000000000001e-05 -1e-05 Float32(-0.00001) -9.9999997473787516e-06 -1e-05 Float32(0.00001234) 1.2339999557298142e-05 1.234e-05 Float32(-0.00000123456) -1.2345600453045336e-06 -1.23456e-06 Float64(1.1234567e-320) 1.1235052786429946e-320 1.1235e-320 Float64(1.234 * 10**16) 12340000000000000.0 1.234e+16 ``` * The `StaticIntTuple` data structure in the `utils` package has been renamed to [`IndexList`](/mojo/std/utils/index_/IndexList). The data structure now allows one to specify the index bitwidth of the elements along with whether the underlying indices are signed or unsigned. * Added [`DLHandle.get_symbol()`](/mojo/std/sys/ffi/DLHandle#get_symbol), for getting a pointer to a symbol in a dynamic library. This is more general purpose than the existing methods for getting function pointers. ### Tooling changes {#24-6-tooling-changes} * The VS Code Mojo Debugger now has a `buildArgs` JSON debug configuration setting that can be used in conjunction with `mojoFile` to define the build arguments when compiling the Mojo file. * The VS Code extension now supports a `Configure Build and Run Args` command that helps set the build and run arguments for the `Run Mojo File` and `Debug Mojo File` actions. A corresponding button appears in the `Run and Debug` selector in the top right corner of a Mojo file. * The VS Code extension now has the `mojo.run.focusOnTerminalAfterLaunch` setting, which controls whether to focus on the terminal used by the `Mojo: Run Mojo File` command or on the editor after launch. [Issue #3532](https://github.com/modular/modular/issues/3532).
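The shortest round-trip float formatting shown in the table earlier in this section is the same behavior Python's `float` repr has had since 3.1; for comparison:

```python
# Shortest representation that still round-trips through parsing.
assert repr(0.3) == "0.3"
assert float(repr(0.3)) == 0.3       # parsing the short form round-trips
assert repr(1.234 * 10**16) == "1.234e+16"
```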
* The VS Code extension now has the `mojo.SDK.additionalSDKs` setting, which allows the user to provide a list of MAX SDKs that the extension can use when determining a default SDK to use. The user can select the default SDK to use with the `Mojo: Select the default MAX SDK` command. * The VS Code extension now supports setting [data breakpoints](https://code.visualstudio.com/docs/editor/debugging#_data-breakpoints) as well as [function breakpoints](https://code.visualstudio.com/docs/editor/debugging#_function-breakpoints). * The Mojo LLDB debugger now supports symbol breakpoints, for example, `b main` or `b my_module::main`. * Error messages that include type names no longer include inferred or defaulted parameters when they aren't needed. For example, previously Mojo complained about things like: ```plaintext ... cannot be converted from 'UnsafePointer[UInt, 0, _default_alignment::AnyType](), MutableAnyOrigin]' to 'UnsafePointer[Int, 0, _default_alignment[::AnyType](), MutableAnyOrigin]' ``` it now complains more helpfully that: ```plaintext ... cannot be converted from 'UnsafePointer[UInt]' to 'UnsafePointer[Int]' ``` * Tooling now prints the origins of `ref` arguments and results correctly, and prints `self` instead of `self: Self` in methods. * The Mojo Language Server and generated documentation now print parametric result types correctly, e.g. showing `SIMD[type, simd_width]` instead of `SIMD[$0, $1]`. * Generated API documentation now shows the signatures for structs, and identifies `@register_passable` and `@register_passable("trivial")` types. * The VS Code extension now allows cancelling the installation of its private MAX SDK. * The VS Code extension now opens the Run and Debug tab automatically whenever a debug session starts. * The `mojo debug --vscode` command now supports the `--init-command` and `--stop-on-entry` flags. Execute `mojo debug --help` for more information.
* The Mojo LLDB debugger on VS Code now supports inspecting the raw attributes of variables that are handled as synthetic types, e.g. `List` from Mojo or `std::vector` from C++. * The VS Code extension now allows selecting a default SDK when multiple are available. ### ❌ Removed * The `UnsafePointer.bitcast()` overload for `DType` has been removed. Wrap your `DType` in a `Scalar[my_dtype]` to call the only overload of `bitcast()` now. ### 🛠️ Fixed * Lifetime tracking is now fully field sensitive, which makes the uninitialized variable checker more precise. * [Issue #1310](https://github.com/modular/modular/issues/1310) - Mojo permits the use of any constructor for implicit conversions * [Issue #1632](https://github.com/modular/modular/issues/1632) - Mojo produces weird error when inout function is used in non mutating function * [Issue #3444](https://github.com/modular/modular/issues/3444) - Raising init causing use of uninitialized variable * [Issue #3544](https://github.com/modular/modular/issues/3544) - Known mutable `ref` argument are not optimized as `noalias` by LLVM. * [Issue #3559](https://github.com/modular/modular/issues/3559) - VariadicPack doesn't extend the lifetimes of the values it references. * [Issue #3627](https://github.com/modular/modular/issues/3627) - Compiler overlooked exclusivity violation caused by `ref [MutableAnyOrigin] T` * [Issue #3710](https://github.com/modular/modular/issues/3710) - Mojo frees memory while reference to it is still in use. * [Issue #3805](https://github.com/modular/modular/issues/3805) - Crash When Initializing !llvm.ptr. * [Issue #3816](https://github.com/modular/modular/issues/3816) - Ternary if-operator doesn't propagate origin information. * [Issue #3815](https://github.com/modular/modular/issues/3815) - \[BUG] Mutability not preserved when taking the union of two origins. 
* [Issue #3829](https://github.com/modular/modular/issues/3829) - Poor error message when invoking a function pointer upon an argument of the wrong origin * [Issue #3830](https://github.com/modular/modular/issues/3830) - Failures emitting register RValues to ref arguments. * The VS Code extension now auto-updates its private copy of the MAX SDK. * The variadic initializer for `SIMD` now works in parameter expressions. * The VS Code extension now downloads its private copy of the MAX SDK in a way that prevents `ETXTBSY` errors on Linux. * The VS Code extension now allows invoking the Mojo formatter from SDK installations that contain white spaces in their path. ### Special thanks Special thanks to our community contributors: [@soraros](https://github.com/soraros), [@jjvraw](https://github.com/jjvraw), [@bgreni](https://github.com/bgreni), [@thatstoasty](https://github.com/thatstoasty), [@szbergeron](https://github.com/szbergeron), [@rd4com](https://github.com/rd4com), [@fknfilewalker](https://github.com/fknfilewalker), [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse), [@avitkauskas](https://github.com/avitkauskas), and [@martinvuyk](https://github.com/martinvuyk). ## v24.5 (2024-09-13) ### ✨ Highlights Here's a brief summary of some of the major changes in this release, with more detailed information in the following sections: * Mojo now supports Python 3.12 interoperability. * The set of automatically imported entities (types, aliases, functions) into users' Mojo programs has been dramatically reduced. This can break existing user code as users will need to explicitly import what they're using for cases that were previously imported automatically. * [`print()`](/mojo/std/builtin/io/print) now requires that its arguments conform to the [`Formattable`](/mojo/std/utils/write/Writable) trait. This enables efficient stream-based writing by default, avoiding unnecessary intermediate String heap allocations.
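The stream-based writing that `Formattable` enables can be sketched in Python (an analogy only: the `format_to` name mirrors the Mojo trait method, and `io.StringIO` is this sketch's stand-in for Mojo's `Formatter`):

```python
import io

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def format_to(self, writer):
        # Stream each part into the writer instead of building
        # intermediate strings with concatenation.
        for part in ("(", self.x, ", ", self.y, ")"):
            writer.write(str(part))

    def __str__(self):
        # One buffer, filled via the streaming path.
        buf = io.StringIO()
        self.format_to(buf)
        return buf.getvalue()

assert str(Point(1.0, 2.0)) == "(1.0, 2.0)"
```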
* The new builtin [`input()`](/mojo/std/builtin/io/input) function prints an optional prompt and reads a line from standard input, in the same way as Python. * Mojo now allows implicit definitions of variables within a `fn` in the same way that has been allowed in a `def`. The `var` keyword is still allowed, but is now optional. * Mojo now diagnoses "argument exclusivity" violations due to aliasing references. Mojo requires references (including implicit references due to `borrowed`/`inout` arguments) to be uniquely referenced (non-aliased) if mutable. This is a warning in the 24.5 release, but will be upgraded to an error in subsequent releases. * Mojo now supports "conditional conformances" where some methods on a struct have additional trait requirements that the struct itself doesn't. * `DTypePointer`, `LegacyPointer`, and `Pointer` have been removed. Use [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) instead. Functions that previously took a `DTypePointer` now take an equivalent `UnsafePointer`. For more information on using pointers, see [Unsafe pointers](/mojo/manual/pointers/unsafe-pointers) in the Mojo Manual. * There are many new standard library APIs, with new features for strings, collections, and interacting with the filesystem and environment. Changes are listed in the standard library section. * The VS Code extension now supports a vendored MAX SDK, which is automatically downloaded by the extension and is used for all Mojo features, including the Mojo Language Server, the Mojo debugger, the Mojo formatter, and more. * [`mojo test`](/mojo/cli/test) now uses the Mojo compiler for running unit tests. This will resolve compilation issues that sometimes appeared, and will also improve overall test execution times. ### Language changes * Mojo now allows implicit definitions of variables within a `fn` in the same way that has been allowed in a `def`.
The `var` keyword is still allowed and still denotes the declaration of a new variable with a scope (in both `def` and `fn`). Relaxing this makes `fn` and `def` more similar, but they still differ in other important ways. * Mojo now diagnoses "argument exclusivity" violations due to aliasing references. Mojo requires references (including implicit references due to `borrowed`/`inout` arguments) to be uniquely referenced (non-aliased) if mutable. This is important for code safety, because it allows the compiler (and readers of code) to understand where and when a value is mutated. It is also useful for performance optimization because it allows the compiler to know that accesses through immutable references cannot change behind the scenes. Here is an invalid example: ```mojo fn take_two_strings(a: String, inout b: String): # Mojo knows 'a' and 'b' cannot be the same string. b += a fn invalid_access(): var my_string = String() # warning: passing `my_string` inout is invalid since it is also passed # borrowed. take_two_strings(my_string, my_string) ``` This is similar to [Swift exclusivity checking](https://swift.org/blog/swift-5-exclusivity/) and [Rust's borrowing rules](https://doc.rust-lang.org/beta/book/ch04-02-references-and-borrowing.html), sometimes known as "aliasing xor mutability". That said, the Mojo implementation details are somewhat different because lifetimes are embedded in types. This is a warning in the 24.5 release, but will be upgraded to an error in subsequent releases. :::note Argument exclusivity is not enforced for register-passable types. They are passed by copy, so they don't form aliases. ::: * Mojo now supports "conditional conformances" where some methods on a struct have additional trait requirements that the struct itself doesn't.
This is expressed through an explicitly declared `self` type: ```mojo struct GenericThing[Type: AnyType]: # Works with anything # Sugar for 'fn normal_method[Type: AnyType](self: GenericThing[Type]):' fn normal_method(self): ... # Just redeclare the requirements with more specific types: fn needs_move[Type: Movable](self: GenericThing[Type], owned val: Type): var tmp = val^ # Ok to move 'val' since it is Movable ... fn usage_example(): var a = GenericThing[Int]() a.normal_method() # Ok, Int conforms to AnyType a.needs_move(42) # Ok, Int is movable var b = GenericThing[NonMovable]() b.normal_method() # Ok, NonMovable conforms to AnyType # error: argument type 'NonMovable' does not conform to trait 'Movable' b.needs_move(NonMovable()) ``` Conditional conformance works with dunder methods and other things as well. * As a specific form of "conditional conformances", initializers in a struct may indicate specific parameter bindings to use in the type of their `self` argument. For example: ```mojo @value struct MyStruct[size: Int]: fn __init__(inout self: MyStruct[0]): pass fn __init__(inout self: MyStruct[1], a: Int): pass fn __init__(inout self: MyStruct[2], a: Int, b: Int): pass def test(x: Int): a = MyStruct() # Infers size=0 from 'self' type. b = MyStruct(x) # Infers size=1 from 'self' type. c = MyStruct(x, x) # Infers size=2 from 'self' type. ``` * Mojo now supports named result bindings. Named result bindings are useful for directly emplacing function results into the output slot of a function. This feature provides more flexibility and guarantees around emplacing the result of a function compared to "guaranteed" named return value optimization (NRVO). If a `@register_passable` result is bound to a name, the result value is made accessible as a mutable reference. ```mojo fn efficiently_return_string(b: Bool) -> String as output: if b: output = "emplaced!" 
mutate(output) return return "regular return" ``` If we used a temporary for `output` instead, we would need to move into the result slot, which wouldn't work if the result type was non-movable. In a function with a named result, `return` may be used with no operand to signal an exit from the function, or it can be used normally to specify the return value of the function. The compiler will error if the result is not initialized on all normal exit paths from the function. * `__setitem__()` now works with variadic argument lists such as: ```mojo struct YourType: fn __setitem__(inout self, *indices: Int, val: Int): ... ``` The Mojo compiler now always passes the "new value" being set using the last keyword argument of the `__setitem__()`, e.g. turning `yourType[1, 2] = 3` into `yourType.__setitem__(1, 2, val=3)`. This fixes [Issue \#248](https://github.com/modular/modular/issues/248). * Mojo context managers used in regions of code that may raise no longer need to define a "conditional" exit function in the form of `fn __exit__(self, e: Error) -> Bool`. This function allows the context manager to conditionally intercept and handle the error and allow the function to continue executing. This is useful for some applications, but in many cases the conditional exit would delegate to the unconditional exit function `fn __exit__(self)`. Concretely, this enables defining `with` regions that unconditionally propagate inner errors, allowing code like: ```mojo def might_raise() -> Int: ... def foo() -> Int: with ContextMgr(): return might_raise() # no longer complains about missing return def bar(): var x: Int with ContextMgr(): x = might_raise() print(x) # no longer complains about 'x' being uninitialized ``` * `async` functions now support memory-only results (like `String`, `List`, etc.) and `raises`. 
Accordingly, both [`Coroutine`](/mojo/std/builtin/coroutine/Coroutine) and [`RaisingCoroutine`](/mojo/std/builtin/coroutine/RaisingCoroutine) have been changed to accept `AnyType` instead of `__TypeOfAllTypes`. This means the result types of `async` functions do not need to be `Movable`. ```mojo async fn raise_or_string(c: Bool) raises -> String: if c: raise "whoops!" return "hello world!" ``` Note that `async` functions do not yet support indirect calls, `ref` results, and constructors. * The [`Reference`](/mojo/std/memory/pointer/Pointer) type (and many iterators) now use [infer-only parameters](/mojo/manual/parameters/#infer-only-parameters) to represent the mutability of their lifetime, simplifying the interface. * The environment variable `MOJO_PYTHON` can be pointed to an executable to pin Mojo to a specific Python version: ```sh export MOJO_PYTHON="/usr/bin/python3.11" ``` Or point it to a virtual environment to always have access to those Python modules: ```sh export MOJO_PYTHON="~/venv/bin/python" ``` `MOJO_PYTHON_LIBRARY` still exists for environments with a dynamic `libpython` but no Python executable. * The pointer aliasing semantics of Mojo have changed. Initially, Mojo adopted a C-like set of semantics around pointer aliasing and derivation. However, the C semantics bring a lot of history and baggage that are not needed in Mojo and which complicate compiler optimizations. The language overall provides a stronger set of invariants around pointer aliasing with lifetimes and exclusive mutable references to values, etc. It is now forbidden to convert a non-pointer-typed value derived from a Mojo-allocated pointer, such as an integer address, to a pointer-typed value. "Derived" means there is overlap in the bits of the non-pointer-typed value with the original pointer value. Accordingly, the [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) constructor that took an `address` keyword argument has been removed.
It is still possible to make this conversion in certain cases where it is absolutely necessary, such as interoperating with other languages like Python. In this case, the compiler makes two assumptions: that any pointer derived from a non-pointer-typed value does not alias any Mojo-derived pointer, and that any external function calls have arbitrary memory effects. * `await` on a coroutine now consumes it. This strengthens the invariant that coroutines can be awaited only once. ### Standard library changes * [`builtin`](/mojo/std/builtin/) package: * The set of automatically imported entities (types, aliases, functions) into users' Mojo programs has been dramatically reduced. Before, with the way the `builtin` module was handled, all of the entities in the following modules would be automatically included: `memory`, `sys`, `os`, `utils`, `python`, `bit`, `random`, `math`, `builtin`, `collections` Now, only the explicitly enumerated entities in `prelude/__init__.mojo` are the ones automatically imported into users' Mojo programs. This will break a lot of user code as users will need to explicitly import what they're using for things that were previously included automatically (such as [`Optional`](/mojo/std/collections/optional/Optional), [`Variant`](/mojo/std/utils/variant/Variant), and functions such as [`abort()`](/mojo/std/os/os/abort), [`alignof()`](/mojo/std/sys/info/alignof), [`bitcast()`](/mojo/std/memory/unsafe/bitcast), [`bitwidthof()`](/mojo/std/sys/info/bitwidthof), [`external_call()`](/mojo/std/sys/ffi/external_call), [`simdwidthof()`](/mojo/std/sys/info/simdwidthof), and [`sizeof()`](/mojo/std/sys/info/sizeof)). * Some types from the `builtin` module have been moved to different modules for clarity, which is made possible now that we have a `prelude` module that can re-export symbols from modules other than `builtin`. In particular, the `builtin.string` module has been moved to [`collections.string`](/mojo/std/collections/string/).
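The awaited-only-once invariant above is the same one Python's asyncio enforces at runtime, where a second `await` of the same coroutine raises `RuntimeError`; for comparison:

```python
import asyncio

async def answer():
    return 42

async def main():
    coro = answer()
    assert await coro == 42   # first await consumes the coroutine
    try:
        await coro            # a second await is an error, as in Mojo
    except RuntimeError as err:
        return str(err)

message = asyncio.run(main())
assert "awaited" in message   # "cannot reuse already awaited coroutine"
```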
* Input and output: * Added the builtin [`input()`](/mojo/std/builtin/io/input) function, which behaves the same as Python. ([PR #3392](https://github.com/modular/modular/pull/3392)) ```mojo name = input("Enter your name: ") print("Hello, " + name + "!") ``` If the user enters "Mojo", it prints "Hello, Mojo!" There is a known issue when running the `input()` function with JIT compilation (see issue [#3479](https://github.com/modular/modular/issues/3479)). * [`print()`](/mojo/std/builtin/io/print) now requires that its arguments conform to the [`Formattable`](/mojo/std/utils/write/Writable) trait. This enables efficient stream-based writing by default, avoiding unnecessary intermediate String heap allocations. Previously, `print()` required that types conform to [`Stringable`](/mojo/std/builtin/str/Stringable). This meant that to execute a call like `print(a, b, c)`, at least three separate String heap allocations were made, to hold the formatted values of `a`, `b`, and `c` respectively. The total number of allocations could be much higher if, for example, `a.__str__()` was implemented to concatenate together the fields of `a`, like in the following example: ```mojo struct Point(Stringable): var x: Float64 var y: Float64 fn __str__(self) -> String: # Performs 3 allocations: 1 each for str(..) of each of the fields, # and then the final returned `String` allocation. return "(" + str(self.x) + ", " + str(self.y) + ")" ``` A type like the one above can transition to additionally implementing `Formattable` with the following changes: ```mojo struct Point(Stringable, Formattable): var x: Float64 var y: Float64 fn __str__(self) -> String: return String.format_sequence(self) fn format_to(self, inout writer: Formatter): writer.write("(", self.x, ", ", self.y, ")") ``` In the example above, [`String.format_sequence()`](/mojo/std/collections/string/string/String#format_sequence) is used to construct a `String` from a type that implements `Formattable`.
This pattern of implementing a type's `Stringable` conformance in terms of its `Formattable` implementation minimizes boilerplate and duplicated code, while retaining backwards compatibility with the requirements of the commonly used `str()` function. :::note The error shown when passing a type that does not implement `Formattable` to `print()` is currently not entirely descriptive of the underlying cause: ```shell error: invalid call to 'print': callee with non-empty variadic pack argument expects 0 positional operands, but 1 was specified print(point) ~~~~~^~~~~~~ ``` If you see the above error, ensure that all argument types implement `Formattable`. ::: * [`debug_assert()`](/mojo/std/builtin/debug_assert/debug_assert) now also requires that its `message` argument conform to `Formattable`. * Added [`TemporaryDirectory`](/mojo/std/tempfile/tempfile/TemporaryDirectory) in module `tempfile`. ([PR 2743](https://github.com/modular/modular/pull/2743)) * Added [`NamedTemporaryFile`](/mojo/std/tempfile/tempfile/NamedTemporaryFile) in module `tempfile`. ([PR 2762](https://github.com/modular/modular/pull/2762)) * [`String`](/mojo/std/collections/string/string) and friends: * The `builtin.string` module has been moved to [`collections.string`](/mojo/std/collections/string/). * Added the [`String.format()`](/mojo/std/collections/string/string/String#format) method. ([PR #2771](https://github.com/modular/modular/pull/2771)) Supports automatic and manual indexing of `*args`. Examples: ```mojo print( String("{1} Welcome to {0} {1}").format("mojo", "🔥") ) # 🔥 Welcome to mojo 🔥 ``` ```mojo print(String("{} {} {}").format(True, 1.125, 2)) # True 1.125 2 ``` * [`String.format()`](/mojo/std/collections/string/string/String#format) now supports conversion flags `!s` and `!r`, allowing for `str()` and `repr()` conversions within format strings.
([PR \#3279](https://github.com/modular/modular/pull/3279)) Example: ```mojo String("{} {!r}").format("Mojo", "Mojo") # "Mojo 'Mojo'" String("{0!s} {0!r}").format("Mojo") # "Mojo 'Mojo'" ``` * The `String` class now has [`rjust()`](/mojo/std/collections/string/string/String#rjust), [`ljust()`](/mojo/std/collections/string/string/String#ljust), and [`center()`](/mojo/std/collections/string/string/String#center) methods to return a justified string based on width and fillchar. ([PR \#3278](https://github.com/modular/modular/pull/3278)) * The [`atol()`](/mojo/std/collections/string/string/atol) function now correctly supports underscores in non-base-10 integer literals (e.g. `atol("0x_ff", 0)`) when the appropriate base is specified or inferred (base 0), as per Python's [Integer Literals](https://docs.python.org/3/reference/lexical_analysis.html#integers). ([PR #3180](https://github.com/modular/modular/pull/3180)) * Added the [`unsafe_cstr_ptr()`](/mojo/std/collections/string/string/String#unsafe_cstr_ptr) method to `String` and `StringLiteral`, which returns an `UnsafePointer[c_char]` for convenient interoperability with C APIs. * Added the `byte_length()` method to [`String`](/mojo/std/collections/string/string/String#byte_length), [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice#byte_length), and [`StringLiteral`](/mojo/std/builtin/string_literal/StringLiteral#byte_length) and deprecated their private `_byte_length()` methods. Added a warning to the [`String.__len__()`](/mojo/std/collections/string/string/String#__len__) method that it will return the length in Unicode codepoints in the future, and [`StringSlice.__len__()`](/mojo/std/collections/string/string_slice/StringSlice#__len__) now returns the length in Unicode codepoints. ([PR \#2960](https://github.com/modular/modular/pull/2960)) * Added a new [`StaticString`](/mojo/std/collections/string/string_slice/#aliases) type alias.
This can be used in place of [`StringLiteral`](/mojo/std/builtin/string_literal/StringLiteral) for runtime string arguments. * Added a [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice#__init__) initializer that accepts a `StringLiteral`. * The `StringRef` constructors from `DTypePointer.int8` have been changed to take a `UnsafePointer[c_char]`, reflecting their use for compatibility with C APIs. * Continued the transition to `UnsafePointer` and unsigned byte type for strings: * [`String.unsafe_ptr()`](/mojo/std/collections/string/string/String#unsafe_ptr) now returns an `UnsafePointer[UInt8]` (was `UnsafePointer[Int8]`) * [`StringLiteral.unsafe_ptr()`](/mojo/std/builtin/string_literal/StringLiteral#unsafe_ptr) now returns an `UnsafePointer[UInt8]` (was `UnsafePointer[Int8]`) * [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) and other reference type changes: * `DTypePointer`, `LegacyPointer`, and `Pointer` have been removed. Use [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) instead. For more information on using pointers, see [Unsafe pointers](/mojo/manual/pointers/unsafe-pointers) in the Mojo Manual. Functions that previously took a `DTypePointer` now take an equivalent `UnsafePointer`. A quick rule for conversion from `DTypePointer` to `UnsafePointer` is: ```mojo DTypePointer[type] -> UnsafePointer[Scalar[type]] ``` There could be places that you have code of the form: ```mojo fn f(ptr: DTypePointer): ``` which is equivalent to `DTypePointer[*_]`. In this case you would have to add an infer-only `type` parameter to the function: ```mojo fn f[type: DType, //](ptr: UnsafePointer[Scalar[type]]): ``` because we can’t have an unbound parameter inside the struct. There could also be places where you use `DTypePointer[Scalar[DType.invalid/index]]`, and it would be natural to change these to `UnsafePointer[NoneType/Int]`. 
But since these are not an `UnsafePointer` that stores a `Scalar`, you might have to `rebind/bitcast` to appropriate types. * The `DTypePointer` [`load()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#load) and [`store()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#store) methods have been moved to `UnsafePointer`. * `UnsafePointer` now supports [`strided_load()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#strided_load), [`strided_store()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#strided_store), [`gather()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#gather), and [`scatter()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#scatter) when the underlying type is `Scalar[DType]`. * The global functions for working with `UnsafePointer` have transitioned to being methods through the use of conditional conformances: * `destroy_pointee(p)` => [`p.destroy_pointee()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#destroy_pointee) * `move_from_pointee(p)` => [`p.take_pointee()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#take_pointee) * `initialize_pointee_move(p, value)` => [`p.init_pointee_move(value)`](/mojo/std/memory/unsafe_pointer/UnsafePointer#init_pointee_move) * `initialize_pointee_copy(p, value)` => [`p.init_pointee_copy(value)`](/mojo/std/memory/unsafe_pointer/UnsafePointer#init_pointee_copy) * `move_pointee(src=p1, dst=p2)` => [`p.move_pointee_into(p2)`](/mojo/std/memory/unsafe_pointer/UnsafePointer#move_pointee_into) * The `UnsafePointer.offset()` method is deprecated and will be removed in a future release. Use [pointer arithmetic](/mojo/manual/pointers#storing-multiple-values) instead. ```mojo new_ptr = ptr.offset(1) ``` Becomes: ```mojo new_ptr = ptr + 1 ``` * `UnsafePointer` now has an [`alignment`](/mojo/std/memory/unsafe_pointer/UnsafePointer#parameters) parameter to specify the static alignment of the pointer. 
Consequently, [`UnsafePointer.alloc()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#alloc) no longer takes in an alignment parameter, and the alignment should be specified in the type. ```mojo UnsafePointer[type].alloc[alignment](x) # now becomes UnsafePointer[type, alignment].alloc(x) ``` * `UnsafePointer` has a new [`exclusive: Bool = False`](/mojo/std/memory/unsafe_pointer/UnsafePointer#parameters) parameter. Setting this parameter to true tells the compiler that the user knows this pointer and all those derived from it have exclusive access to the underlying memory allocation. The compiler is not guaranteed to do anything with this information. * It is no longer possible to cast (implicitly or explicitly) from `Reference` to `UnsafePointer`. Instead of `UnsafePointer(someRef)` please use the [`UnsafePointer.address_of(someRef[])`](/mojo/std/memory/unsafe_pointer/UnsafePointer#address_of) which makes the code explicit that the `UnsafePointer` gets the address of what the reference points to. * Python interoperability changes: * Mojo now supports Python 3.12 interoperability. * Creating a nested [`PythonObject`](/mojo/std/python/python_object/PythonObject) from a list or tuple of Python objects is possible now: ```mojo var np = Python.import_module("numpy") var a = np.array([1, 2, 3]) var b = np.array([4, 5, 6]) var arrays = PythonObject([a, b]) assert_equal(len(arrays), 2) ``` Also allowing more convenient call syntax: ```mojo var stacked = np.hstack((a, b)) assert_equal(str(stacked), "[1 2 3 4 5 6]") ``` ([PR #3264](https://github.com/modular/modular/pull/3264)) * Accessing local Python modules with [`Python.add_to_path(".")`](/mojo/std/python/python/Python#add_to_path) is no longer required. It now behaves the same as Python. 
You can access modules in the same folder as the target file: * `mojo run /tmp/main.mojo` can access `/tmp/mymodule.py` * `mojo build main.mojo -o ~/myexe && ~/myexe` can access `~/mymodule.py` * Collections: * [`List`](/mojo/std/collections/list/List) values are now equality comparable with `==` and `!=` when their element type is equality comparable. ([PR #3195](https://github.com/modular/modular/pull/3195)) * [`Optional`](/mojo/std/collections/optional/Optional) values are now equality comparable with `==` and `!=` when their element type is equality comparable. * Added a new [`Counter`](/mojo/std/collections/counter/Counter) dictionary-like type, matching most of the features of the Python one. ([PR #2910](https://github.com/modular/modular/pull/2910)) * [`Dict`](/mojo/std/collections/dict/Dict) now implements [`setdefault()`](/mojo/std/collections/dict/Dict#setdefault), which gets a value from the dictionary by key, or sets it to a default if it doesn't exist. ([PR #2803](https://github.com/modular/modular/pull/2803)) * `Dict` now supports [`popitem()`](/mojo/std/collections/dict/Dict#popitem), which removes and returns the last item in the `Dict`. ([PR #2701](https://github.com/modular/modular/pull/2701)) * Added a [`Dict.__init__()`](/mojo/std/collections/dict/Dict#__init__) overload to specify initial capacity. ([PR #3171](https://github.com/modular/modular/pull/3171)) The capacity has to be a power of two and greater than or equal to 8. It allows for faster initialization by skipping incremental growth steps. Example: ```mojo var dictionary = Dict[Int,Int](power_of_two_initial_capacity = 1024) # Insert (2/3 of 1024) entries ``` * `ListLiteral` now supports `__contains__()`. ([PR #3251](https://github.com/modular/modular/pull/3251)) * Filesystem and environment utilities: * [`Path.home()`](/mojo/std/pathlib/path/Path#home) has been added to return a path of the user's home directory. 
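Mojo's `Dict.setdefault()` and `popitem()` (added above) match the behavior of Python's `dict` methods of the same names; for comparison in Python:

```python
d = {}
# setdefault returns the existing value, or inserts and returns the default.
assert d.setdefault("a", 1) == 1
assert d.setdefault("a", 99) == 1   # "a" already present; default ignored

d["b"] = 2
# popitem removes and returns the last item.
assert d.popitem() == ("b", 2)
assert "b" not in d
```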
* [`os.path.expanduser()`](/mojo/std/os/path/path/expanduser) and [`pathlib.Path.expanduser()`](/mojo/std/pathlib/path/Path#expanduser) have been added to allow expanding a prefixed `~` in a `String` or `Path` with the user's home path:

```mojo
import os
print(os.path.expanduser("~/.modular"))
# /Users/username/.modular
print(os.path.expanduser("~root/folder"))
# /var/root/folder (on macOS)
# /root/folder (on Linux)
```

* [`os.path.split()`](/mojo/std/os/path/path/split) has been added for splitting a path into `head, tail`:

```mojo
import os
head, tail = os.path.split("/this/is/head/tail")
print("head:", head)
print("tail:", tail)
# head: /this/is/head
# tail: tail
```

* [`os.makedirs()`](/mojo/std/os/os/makedirs) and [`os.removedirs()`](/mojo/std/os/os/removedirs) have been added for creating and removing nested directories:

```mojo
import os
path = os.path.join("dir1", "dir2", "dir3")
os.makedirs(path, exist_ok=True)
os.removedirs(path)
```

* The [`pwd`](/mojo/std/pwd/pwd/) module has been added for accessing user information in `/etc/passwd` on POSIX systems. This follows the same logic as Python:

```mojo
import pwd
import os
current_user = pwd.getpwuid(os.getuid())
print(current_user)
# pwd.struct_passwd(pw_name='jack', pw_passwd='********', pw_uid=501,
# pw_gid=20, pw_gecos='Jack Clayton', pw_dir='/Users/jack',
# pw_shell='/bin/zsh')
print(current_user.pw_uid)
# 501
root = pwd.getpwnam("root")
print(root)
# pwd.struct_passwd(pw_name='root', pw_passwd='*', pw_uid=0, pw_gid=0,
# pw_gecos='System Administrator', pw_dir='/var/root', pw_shell='/bin/zsh')
```

* Other new traits and related features:
  * Added the `ExplicitlyCopyable` trait to mark types that can be copied explicitly, but which might not be implicitly copyable. This supports work to transition the standard library collection types away from implicit copyability, which can lead to unintended expensive copies.
* Added the [`Identifiable`](/mojo/std/builtin/identifiable/Identifiable) trait, used to describe types that implement the `__is__()` and `__isnot__()` trait methods. ([PR #2807](https://github.com/modular/modular/pull/2807))
* Types conforming to [`Boolable`](/mojo/std/builtin/bool/Boolable) (that is, those implementing `__bool__()`) no longer implicitly convert to `Bool`. A new [`ImplicitlyBoolable`](/mojo/std/builtin/bool/ImplicitlyBoolable) trait is introduced for types where this behavior is desired.
* Miscellaneous:
  * [`NoneType`](/mojo/std/builtin/none/NoneType) is now a normal standard library type, and not an alias for a raw MLIR type. Function signatures written as `fn() -> NoneType` should transition to being written as `fn() -> None`.
  * Mojo now has a [`UInt`](/mojo/std/builtin/uint/UInt) type for modeling unsigned (scalar) integers with a platform-dependent width. `UInt` implements most arithmetic operations that make sense for integers, with the notable exception of `__neg__()`. Builtin functions such as `min()`/`max()`, as well as `math` functions like `ceildiv()`, `align_down()`, and `align_up()`, are also implemented for `UInt`.
  * Now that there is a `UInt` type, it is used to represent the return type of a hash. In general, a hash should be an unsigned integer, which can also improve performance in certain cases.
  * Added the [`c_char`](/mojo/std/sys/ffi/#aliases) type alias in `sys.ffi`.
  * [`sort()`](/mojo/std/builtin/sort/sort) now supports a `stable` parameter. For example:

```mojo
sort[cmp_fn, stable=True](list)
```

    The stable algorithm requires $O(N)$ auxiliary memory. If the extra memory allocation fails, the program crashes.
  * `sort()` no longer takes `LegacyPointer` since that type is now removed.
  * Added the [`oct()`](/mojo/std/builtin/format_int/oct) builtin function for formatting an integer in octal.
([PR #2914](https://github.com/modular/modular/pull/2914)) * Added the [`assert_is()`](/mojo/std/testing/testing/assert_is) and [`assert_is_not()`](/mojo/std/testing/testing/assert_is_not) test functions to the `testing` module. * The [`math`](/mojo/std/math/constants/) package now includes the `pi`, `e`, and `tau` constants (Closes Issue [#2135](https://github.com/modular/modular/issues/2135)). * The [`ulp`](/mojo/std/math/math/ulp) function from `numerics` has been moved to the `math` module. * `bit` module now supports [`bit_reverse()`](/mojo/std/bit/bit/bit_reverse), [`byte_swap()`](/mojo/std/bit/bit/byte_swap), and [`pop_count()`](/mojo/std/bit/bit/pop_count) for the `Int` type. ([PR #3150](https://github.com/modular/modular/pull/3150)) * A few `bit` functions have been renamed for clarity: * `countl_zero()` -> [`count_leading_zeros()`](/mojo/std/bit/bit/count_leading_zeros) * `countr_zero()` -> [`count_trailing_zeros()`](/mojo/std/bit/bit/count_trailing_zeros) * [`Slice`](/mojo/std/builtin/builtin_slice/Slice) now uses `OptionalReg[Int]` for `start` and `end` and implements a constructor which accepts optional values. `Slice._has_end()` has also been removed since a Slice with no end is now represented by an empty `Slice.end` option. ([PR #2495](https://github.com/modular/modular/pull/2495)) ```mojo var s = Slice(1, None, 2) print(s.start.value()) # must retrieve the value from the optional ``` * The `rank` argument for [`algorithm.elementwise()`](/mojo/std/algorithm/functional/elementwise) is no longer required and is only inferred. * The `time.now()` function has been deprecated. Please use [`time.perf_counter()`](/mojo/std/time/time/perf_counter) or [`time.perf_counter_ns`](/mojo/std/time/time/perf_counter_ns) instead. * [`SIMD`](/mojo/std/builtin/simd/SIMD) construction from `Bool` has been restricted to `DType.bool` data type. 
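As a combined sketch of the new `UInt` type and the renamed `bit` functions described above (illustrative only; exact signatures may vary slightly by release):

```mojo
from bit import count_leading_zeros, count_trailing_zeros

fn main():
    # UInt models a platform-width unsigned integer; builtin min()/max()
    # are implemented for it
    var x: UInt = 42
    print(min(x, UInt(7)))  # 7

    # The renamed bit functions (formerly countl_zero/countr_zero)
    print(count_trailing_zeros(8))  # 8 == 0b1000, so 3
    print(count_leading_zeros(1))
```

Note that `UInt` deliberately omits `__neg__()`, so negation of a `UInt` value is a compile-time error rather than a silent wraparound.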
### Tooling changes

* [`mojo test`](/mojo/cli/test) new features and changes:
  * `mojo test` now uses the Mojo compiler for running unit tests. This will resolve compilation issues that sometimes appeared, and will also improve overall test times, since we will only compile unit tests once before executing all of them. These changes do not apply to doctests, due to their different semantics.
  * The `mojo test` command now accepts a `--filter` option that will narrow the set of tests collected and executed. The filter string is a POSIX extended regular expression.
  * The `mojo test` command now supports using the same compilation options as `mojo build`.
  * You can now debug unit tests using `mojo test` by passing the `--debug` flag. Most debug flags are supported; run `mojo test --help` for a full listing. Debugging doctests is not currently supported.
* Mojo debugger new features and changes:
  * The `mojo debug --rpc` command has been renamed to [`mojo debug --vscode`](/mojo/cli/debug#debug-server-options), which is now able to manage multiple VS Code windows.
  * The Mojo debugger now supports a `break-on-raise` command that instructs the debugger to stop at any `raise` statement. A similar feature has been added to the debugger on VS Code.
  * The Mojo debugger now hides the artificial function arguments `__result__` and `__error__` created by the compiler for Mojo code.
* VS Code support changes:
  * The VS Code extension now supports a vendored MAX SDK for VS Code, which is automatically downloaded by the extension and is used for all Mojo features, including the Mojo Language Server, the Mojo debugger, the Mojo formatter, and more.
  * A proxy has been added to the Mojo Language Server on VS Code that handles crashes more gracefully.
  * The Mojo Language Server no longer sets `.` as a commit character for auto-completion.

### ❌ Removed

* Support for the legacy `fn __init__(...) -> Self:` form has been removed from the compiler. Please switch to using `fn __init__(inout self, ...):` instead.
* The builtin `tensor` module has been removed. Identical functionality is available in `max.tensor`, but it is generally recommended to use structs from the [`buffer`](/mojo/std/buffer/buffer) module when possible instead.
* Removed `String.unsafe_uint8_ptr()`. `String.unsafe_ptr()` now returns the same thing.
* Removed `StringLiteral.unsafe_uint8_ptr()` and `StringLiteral.as_uint8_ptr()`.
* Removed `SIMD.splat(value: Scalar[type])`. Use the constructor for `SIMD` instead.
* Removed the `SIMD.{add,mul,sub}_with_overflow()` methods.
* Removed the `SIMD.min()` and `SIMD.max()` methods. Identical functionality is available using the builtin [`min()`](/mojo/std/builtin/math/min) and [`max()`](/mojo/std/builtin/math/max) functions.
* Removed the Mojo Language Server warnings for unused function arguments.
* The `Run Mojo File in Dedicated Terminal` action has been removed, and the `Run Mojo File` action will always open a dedicated terminal for each Mojo file to guarantee a correct environment.

### 🛠️ Fixed

* Fixed a crash in the Mojo Language Server when importing the current file.
* Fixed a crash when specifying variadic keyword arguments without a type expression in `def` functions, e.g.:

```mojo
def foo(**kwargs): ...  # now works
```

* Mojo now prints `ref` arguments and results in generated documentation correctly.
* [#1734](https://github.com/modular/modular/issues/1734) - Calling `__copyinit__` on self causes crash.
* [#3142](https://github.com/modular/modular/issues/3142) - \[QoI] Confusing `__setitem__` method is failing with a "must be mutable" error.
* [#248](https://github.com/modular/modular/issues/248) - \[Feature] Enable `__setitem__` to take variadic arguments * [#3065](https://github.com/modular/modular/issues/3065) - Fix incorrect behavior of `SIMD.__int__` on unsigned types * [#3045](https://github.com/modular/modular/issues/3045) - Disable implicit SIMD conversion routes through `Bool` * [#3126](https://github.com/modular/modular/issues/3126) - \[BUG] List doesn't work at compile time. * [#3237](https://github.com/modular/modular/issues/3237) - \[BUG] Difference between `__getitem__` and `[.]` operator. * [#3336](https://github.com/modular/modular/issues/3336) - Fix outdated references to `let` in REPL documentation. * The VS Code extension no longer caches the information of the selected MAX SDK, which was causing issues upon changes in the SDK. * The Mojo debugger now stops showing spurious warnings when parsing closures. ### Special thanks Special thanks to our community contributors: [@jjvraw](https://github.com/jjvraw), [@artemiogr97](https://github.com/artemiogr97), [@martinvuyk](https://github.com/martinvuyk), [@jayzhan211](https://github.com/jayzhan211), [@bgreni](https://github.com/bgreni), [@mzaks](https://github.com/mzaks), [@msaelices](https://github.com/msaelices), [@rd4com](https://github.com/rd4com), [@jiex-liu](https://github.com/jiex-liu), [@kszucs](https://github.com/kszucs), [@thatstoasty](https://github.com/thatstoasty) ## v24.4 (2024-06-07) ### ✨ Highlights Big themes for this release: * Improvements to the performance and ease-of-use for `def` functions. * Continued unification of standard library APIs around the `UnsafePointer` type. * Many quality-of-life improvements for the standard library collection types. * Significant performance improvements when inserting into a `Dict`. Performance on this metric is still not where we'd like it to be, but it is much improved. 
* A new `@parameter for` mechanism for expressing compile-time loops, which replaces the earlier (and less reliable) `@unroll` decorator. * New Mojo Manual pages on [Control flow](/mojo/manual/control-flow), [Testing](/mojo/tools/testing) and using [unsafe pointers](/mojo/manual/pointers/unsafe-pointers). ### Language changes * Mojo has changed how `def` function arguments are processed. Previously, by default, arguments to a `def` were treated according to the `owned` convention, which makes a copy of the value, enabling that value to be mutable in the callee. This could lead to major performance issues because of the proliferation of unnecessary copies. It also required you to declare non-copyable types as `borrowed` explicitly. Now Mojo takes a different approach: `def` functions take arguments as `borrowed` by default (consistent with `fn` functions) but will make a local copy of the value **only if the argument is mutated** in the body of the function. This improves consistency, performance, and ease of use. * Implicit variable definitions in a `def` function are more flexible: you can now implicitly declare variables as the result of a tuple return, using `a,b,c = foo()`. For example: ```mojo def return_two(i: Int) -> (Int, Int): return i, i+1 a, b = return_two(5) ``` Implicit variable declarations can also now shadow global immutable symbols (such as module names and built-ins) without getting a compiler error. For example: ```mojo slice = foo() ``` * Mojo functions can return an auto-dereferenced reference to storage with a new `ref` keyword in the result type specifier. For example: ```mojo @value struct Pair: var first: Int var second: Int fn get_first_ref(inout self) -> ref [self] Int: return self.first fn show_mutation(): var somePair = Pair(5, 6) somePair.get_first_ref() = 1 ``` This approach provides a general way to return an "automatically dereferenced" reference of a given type. Notably, this eliminates the need for `__refitem__()` to exist. 
`__refitem__()` has thus been removed and replaced with `__getitem__()` that returns a reference. * Mojo added support for *infer-only parameters*. Infer-only parameters must appear at the beginning of the parameter list and cannot be explicitly specified by the user. They are declared to the left of a `//` marker, much like positional-only parameters. This allows programmers to define functions with dependent parameters to be called without the caller specifying all the necessary parameters. For example: ```mojo fn parameter_simd[dt: DType, //, value: Scalar[dt]](): print(value) fn call_it(): parameter_simd[Int32(42)]() ``` In the above example, `Int32(42)` is passed directly into `value`, the first parameter that isn't infer-only. `dt` is inferred from the parameter itself to be `DType.int32`. This also works with structs. For example: ```mojo struct ScalarContainer[dt: DType, //, value: Scalar[dt]]: pass fn foo(x: ScalarContainer[Int32(0)]): # 'dt' is inferred as `DType.int32` pass ``` This should make working with dependent parameters more ergonomic. See [Infer-only parameters](/mojo/manual/parameters/#infer-only-parameters) in the Mojo Manual. * Mojo now allows functions overloaded on parameters to be resolved when forming references to, but not calling, those functions. For example, the following now works: ```mojo fn overloaded_parameters[value: Int32](): pass fn overloaded_parameters[value: Float32](): pass fn form_reference(): alias ref = overloaded_parameters[Int32()] # works! ``` * Mojo now supports adding a `@deprecated` decorator on structs, functions, traits, aliases, and global variables. The decorator marks the attached declaration as deprecated and causes a warning to be emitted when the deprecated declaration is referenced in user code. The decorator requires a deprecation message, specified as a string literal. 
```mojo
@deprecated("Foo is deprecated, use Bar instead")
struct Foo:
    pass

fn outdated_api(x: Foo):  # warning: Foo is deprecated, use Bar instead
    pass

@deprecated("use another function!")
fn bar():
    pass

fn techdebt():
    bar()  # warning: use another function!
```

* Mojo has introduced [`@parameter for`](/mojo/manual/decorators/parameter#parametric-for-statement), a new feature for compile-time programming. `@parameter for` defines a for loop where the sequence and the induction values in the sequence must be parameter values. For example:

```mojo
fn parameter_for[max: Int]():
    @parameter
    for i in range(max):
        @parameter
        if i == 10:
            print("found 10!")
```

  Currently, `@parameter for` requires the sequence's `__iter__()` method to return a `_StridedRangeIterator`, meaning the induction variables must be `Int`. The intention is to lift these restrictions in the future.
* The `is_mutable` parameter of `Reference` and `AnyLifetime` is now a `Bool`, not a low-level `__mlir_type.i1` value. This improves the ergonomics of spelling out a `Reference` type explicitly.
* Mojo will now link to a Python dynamic library based on the first Python found on your `PATH`. This enables you to activate a virtual environment like `conda` and have access to Python modules installed in that environment without setting `MOJO_PYTHON_LIBRARY`. Previously, Mojo would find a `libpython` dynamic library on installation and put the path in `.modular/modular.cfg`, which could result in version conflicts if you activated a virtual environment of a different Python version.
* `AnyRegType` has been renamed to `AnyTrivialRegType` and Mojo now forbids binding non-trivial register-passable types to `AnyTrivialRegType`. This closes a major safety hole in the language. Please use `AnyType` for generic code going forward.
* The `let` keyword has been completely removed from the language. We previously removed `let` declarations but still provided an error message to users.
Now, it is completely gone from the grammar. ### Standard library changes * New traits and related features: * Added built-in [`repr()`](/mojo/std/builtin/repr/repr) function and [`Representable`](/mojo/std/builtin/repr/Representable) trait. ([PR #2361](https://github.com/modular/modular/pull/2361)) * Added the [`Indexer`](/mojo/std/builtin/int/Indexer) trait to denote types that implement the `__index__()` method which allows these types to be accepted in common `__getitem__()` and `__setitem__()` implementations, as well as allow a new built-in [`index()`](/mojo/std/builtin/int/index-function) function to be called on them. Most standard library containers can now be indexed by any type that implements `Indexer`. For example: ```mojo @value struct AlwaysZero(Indexer): fn __index__(self) -> Int: return 0 struct MyList: var data: List[Int] fn __init__(inout self): self.data = List[Int](1, 2, 3, 4) fn __getitem__[T: Indexer](self, idx: T) -> Int: return self.data[index(idx)] print(MyList()[AlwaysZero()]) # prints `1` ``` Types conforming to the `Indexer` trait are implicitly convertible to Int. This means you can write generic APIs that take `Int` instead of making them take a generic type that conforms to `Indexer`. For example: ```mojo @value struct AlwaysZero(Indexer): fn __index__(self) -> Int: return 0 @value struct Incrementer: fn __getitem__(self, idx: Int) -> Int: return idx + 1 var a = Incrementer() print(a[AlwaysZero()]) # works and prints 1 ``` ([PR #2685](https://github.com/modular/modular/pull/2685)) * Added traits allowing user-defined types to be supported by various built-in and math functions. 
| Function | Trait | Required method |
| --------------------------------------------- | --------------------------------------------- | --------------- |
| [`abs()`](/mojo/std/builtin/math/abs) | [`Absable`](/mojo/std/builtin/math/Absable) | `__abs__()` |
| [`pow()`](/mojo/std/builtin/math/pow) | [`Powable`](/mojo/std/builtin/math/Powable) | `__pow__()` |
| [`round()`](/mojo/std/builtin/math/round) | [`Roundable`](/mojo/std/builtin/math/Roundable) | `__round__()` |
| [`math.ceil`](/mojo/std/math/math/ceil) | `math.Ceilable` | `__ceil__()` |
| [`math.ceildiv`](/mojo/std/math/math/ceildiv) | `math.CeilDivable`, `math.CeilDivableRaising` | `__ceildiv__()` |
| [`math.floor`](/mojo/std/math/math/floor) | `math.Floorable` | `__floor__()` |
| [`math.trunc`](/mojo/std/math/math/trunc) | `Truncable` | `__trunc__()` |

Notes:

* Conforming to the `Powable` trait also means that the type can be used with the power operator (`**`).
* For `ceildiv()`, structs can conform to either the `CeilDivable` trait or the `CeilDivableRaising` trait.
* Due to ongoing refactoring, the traits `Ceilable`, `CeilDivable`, `Floorable`, and `Truncable` do not appear in the API reference. They should be imported from the `math` module, except for `Truncable`, which is (temporarily) available as a built-in trait and does not need to be imported.

Example:

```mojo
from math import sqrt

@value
struct Complex2(Absable, Roundable):
    var re: Float64
    var im: Float64

    fn __abs__(self) -> Self:
        return Self(sqrt(self.re * self.re + self.im * self.im), 0.0)

    fn __round__(self) -> Self:
        return Self(round(self.re, 0), round(self.im, 0))

    fn __round__(self, ndigits: Int) -> Self:
        return Self(round(self.re, ndigits), round(self.im, ndigits))
```

* Benchmarking:
  * The [`bencher`](/mojo/std/benchmark/bencher/) module as part of the `benchmark` package is now public and documented.
This module provides types such as `Bencher` which provides the ability to execute a `Benchmark` and allows for benchmarking configuration via the `BenchmarkConfig` struct. * [`String`](/mojo/std/collections/string/string) and friends: * **Breaking.** Implicit conversion to `String` is now removed for builtin classes/types. Use `str()` explicitly to convert to `String`. * Added [`String.isspace()`](/mojo/std/collections/string/string/String#isspace) method conformant with Python's universal separators. This replaces the `isspace()` free function from the `string` module. (If you need the old function, it is temporarily available as `_isspace()`. It now takes a `UInt8` but is otherwise unchanged.) * [`String.split()`](/mojo/std/collections/string/string/String#split) now defaults to whitespace and has Pythonic behavior in that it removes all adjacent whitespace by default. * [`String.strip()`](/mojo/std/collections/string/string/String#strip), [`lstrip()`](/mojo/std/collections/string/string/String#lstrip) and [`rstrip()`](/mojo/std/collections/string/string/String#rstrip) can now remove custom characters other than whitespace. In addition, there are now several useful aliases for whitespace, ASCII lower/uppercase, and so on. ([PR #2555](https://github.com/modular/modular/pull/2555)) * `String` now has a [`splitlines()`](/mojo/std/collections/string/string/String#splitlines) method, which allows splitting strings at line boundaries. This method supports [universal newlines](https://docs.python.org/3/glossary.html#term-universal-newlines) and provides an option to retain or remove the line break characters. ([PR \#2810](https://github.com/modular/modular/pull/2810)) * `InlinedString` has been renamed to [`InlineString`](/mojo/std/collections/string/inline_string/InlineString) to be consistent with other types. 
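A brief sketch of the new `String` behaviors described above (illustrative only; exact signatures and return types may differ slightly by release):

```mojo
fn main():
    # strip() defaults to whitespace, but can now remove custom characters
    print(String("  hello  ").strip())  # "hello"
    print(String("xxhixx").strip("x"))  # "hi"

    # split() now defaults to whitespace and collapses adjacent runs
    var words = String("a  b\t c").split()
    print(len(words))  # 3

    # splitlines() splits at universal newlines
    var lines = String("one\ntwo\r\nthree").splitlines()
    print(len(lines))  # 3
```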
* [`StringRef`](/mojo/std/collections/string/string_slice/StringSlice) now implements [`strip()`](/mojo/std/collections/string/string_slice/StringSlice#strip), which can be used to remove leading and trailing whitespace. ([PR \#2683](https://github.com/modular/modular/pull/2683)) * `StringRef` now implements [`startswith()`](/mojo/std/collections/string/string_slice/StringSlice#startswith) and [`endswith()`](/mojo/std/collections/string/string_slice/StringSlice#endswith). ([PR #2710](https://github.com/modular/modular/pull/2710)) * Added a new [`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice) type, to replace uses of the unsafe `StringRef` type in standard library code. `StringSlice` is a non-owning reference to encoded string data. Unlike `StringRef`, a `StringSlice` is safely tied to the lifetime of the data it points to. * Added new [`as_string_slice()`](/mojo/std/collections/string/string/String#as_string_slice) methods to `String` and `StringLiteral`. * Added `StringSlice` initializer from an `UnsafePointer` and a length in bytes. * Added a new [`as_bytes_slice()`](/mojo/std/collections/string/string/String#as_bytes_slice) method to `String` and `StringLiteral`, which returns a `Span` of the bytes owned by the string. * Continued transition to [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) and unsigned byte type for strings: * Renamed `String._as_ptr()` to [`String.unsafe_ptr()`](/mojo/std/collections/string/string/String#unsafe_ptr), and changed return type to `UnsafePointer` (was `DTypePointer`). * Renamed `StringLiteral.data()` to [`StringLiteral.unsafe_ptr()`](/mojo/std/builtin/string_literal/StringLiteral#unsafe_ptr), and changed return type to `UnsafePointer` (was `DTypePointer`). * `InlineString.as_ptr()` has been renamed to [`unsafe_ptr()`](/mojo/std/collections/string/inline_string/InlineString#unsafe_ptr) and now returns an `UnsafePointer[UInt8]` (was `DTypePointer[DType.int8]`). 
* `StringRef.data` is now an `UnsafePointer` (was `DTypePointer`) and [`StringRef.unsafe_ptr()`](/mojo/std/collections/string/string_slice/StringSlice#unsafe_ptr) now returns an `UnsafePointer[UInt8]` (was `DTypePointer[DType.int8]`). * Other built-ins: * The `Slice.__len__()` function has been removed and [`Slice`](/mojo/std/builtin/builtin_slice/Slice) no longer conforms to the `Sized` trait. This clarifies the ambiguity of the semantics: the length of a slice always depends on the length of the object being sliced. Users that need the existing functionality can use the [`Slice.unsafe_indices()`](/mojo/std/builtin/builtin_slice/Slice#indices) method. This makes it explicit that this implementation does not check if the slice bounds are concrete or within any given object's length. * Added a built-in [`sort()`](/mojo/std/builtin/sort/sort) function for lists of elements that conform to the [`ComparableCollectionElement`](/mojo/std/builtin/value/ComparableCollectionElement) trait.([PR #2609](https://github.com/modular/modular/pull/2609)) * `int()` can now take a string and a specified base to parse an integer from a string: `int("ff", 16)` returns `255`. Additionally, if a base of zero is specified, the string will be parsed as if it was an integer literal, with the base determined by whether the string contains the prefix `"0x"`, `"0o"`, or `"0b"`. ([PR #2273](https://github.com/modular/modular/pull/2273), fixes [#2274](https://github.com/modular/modular/issues/2274)) * Added the [`bin()`](/mojo/std/builtin/format_int/bin) built-in function to convert integral types into their binary string representation. ([PR #2603](https://github.com/modular/modular/pull/2603)) * Added the [`atof()`](/mojo/std/collections/string/string/atof) built-in function, which can convert a `String` to a `float64`. 
([PR \#2649](https://github.com/modular/modular/pull/2649))
* You can now use the built-in [`any()`](/mojo/std/builtin/bool/any) and [`all()`](/mojo/std/builtin/bool/all) functions to check for truthy elements in a collection. Because `SIMD.__bool__()` is now constrained to `size=1`, you must explicitly use these to get the truthy value of a SIMD vector with more than one element. This avoids common bugs around implicit conversion of `SIMD` to `Bool`. ([PR #2600](https://github.com/modular/modular/pull/2600)) For example:

```mojo
fn truthy_simd():
    var vec = SIMD[DType.int32, 4](0, 1, 2, 3)
    if any(vec):
        print("any elements are truthy")
    if all(vec):
        print("all elements are truthy")
```

* `object` now implements all the bitwise operators. ([PR #2324](https://github.com/modular/modular/pull/2324))
* [`Tuple`](/mojo/std/builtin/tuple/Tuple) now supports `__contains__()`. ([PR #2709](https://github.com/modular/modular/pull/2709)) For example:

```mojo
var x = Tuple(1, 2, True)
if 1 in x:
    print("x contains 1")
```

* `ListLiteral` and `Tuple` now only require that element types be `Movable`. Consequently, `ListLiteral` and `Tuple` are themselves no longer `Copyable`.
* Added new `ImmutableStaticLifetime` and `MutableStaticLifetime` helpers.
* [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) and others:
  * Added a new [`memcpy()`](/mojo/std/memory/memory/memcpy) overload for `UnsafePointer[Scalar[_]]` pointers.
  * Removed the `get_null()` method from `UnsafePointer` and other pointer types. Please use the default constructor instead: `UnsafePointer[T]()`.
  * Many functions returning a pointer type have been unified to have a public API function of `unsafe_ptr()`.
  * The `Tensor.data()` method has been renamed to `unsafe_ptr()`. The return type is still a `DTypePointer[T]`.
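The string-parsing builtins described earlier (`int()` with an explicit base, `bin()`, and `atof()`) can be sketched together (illustrative only):

```mojo
fn main() raises:
    # int() can parse a string with an explicit base
    print(int("ff", 16))    # 255
    # with base 0, the prefix determines the base, like an integer literal
    print(int("0b101", 0))  # 5
    # bin() renders an integral value as a binary string
    print(bin(255))         # 0b11111111
    # atof() converts a String to a float64; it raises on invalid input
    print(atof("1.5"))
```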
* Collections:
  * [`List`](/mojo/std/collections/list/List) now has an [`index()`](/mojo/std/collections/list/List#index) method that allows you to find the (first) location of an element in a `List` of `EqualityComparable` types. For example:

```mojo
var my_list = List[Int](2, 3, 5, 7, 3)
print(my_list.index(3))  # prints 1
```

  * `List` can now be converted to a `String` with a simplified syntax:

```mojo
var my_list = List[Int](2, 3)
print(my_list.__str__())  # prints [2, 3]
```

    Note that `List` doesn't conform to the `Stringable` trait yet, so you cannot use `str(my_list)` yet. ([PR #2673](https://github.com/modular/modular/pull/2673))
  * `List` has a simplified syntax to call the [`count()`](/mojo/std/collections/list/List#count) method: `my_list.count(x)`. ([PR #2675](https://github.com/modular/modular/pull/2675))
  * `List` now supports `__contains__()`, so you can now use lists with the `in` operator:

```mojo
if x in my_list:
```

    ([PR #2667](https://github.com/modular/modular/pull/2667))
  * `List` now has an [`unsafe_get()`](/mojo/std/collections/list/List#unsafe_get) method to get a reference to an element without bounds checking or wraparound for negative indices. Note that this method is unsafe. Use with caution. ([PR #2800](https://github.com/modular/modular/pull/2800))
  * Added a [`fromkeys()`](/mojo/std/collections/dict/Dict#fromkeys) method to `Dict` to return a `Dict` with the specified keys and values. ([PR 2622](https://github.com/modular/modular/pull/2622))
  * Added a [`clear()`](/mojo/std/collections/dict/Dict#clear) method to `Dict`. ([PR 2627](https://github.com/modular/modular/pull/2627))
  * `Dict` now supports [`reversed()`](/mojo/std/builtin/reversed/reversed) for its `items()` and `values()` iterators. ([PR #2340](https://github.com/modular/modular/pull/2340))
  * `Dict` now has a simplified conversion to `String` with `my_dict.__str__()`. Note that `Dict` does not conform to the `Stringable` trait, so `str(my_dict)` is not possible yet.
([PR #2674](https://github.com/modular/modular/pull/2674)) * `Dict` now implements [`get(key)`](/mojo/std/collections/dict/Dict#get) and `get(key, default)` functions. ([PR #2519](https://github.com/modular/modular/pull/2519)) * Added a temporary `__get_ref(key)` method to `Dict`, allowing you to get a `Reference` to a dictionary value. * Added a new [`InlineList`](/mojo/std/collections/inline_array/InlineArray) type, a stack-allocated list with a static maximum size. ([PR 2587#](https://github.com/modular/modular/pull/2587)) ([PR #2703](https://github.com/modular/modular/pull/2703)) * Added a new [`Span`](/mojo/std/memory/span/Span) type for taking slices of contiguous collections. ([PR \#2595](https://github.com/modular/modular/pull/2595)) * [`os`](/mojo/std/os/os/) module: * The `os` module now provides functionality for adding and removing directories using [`mkdir()`](/mojo/std/os/os/mkdir) and [`rmdir()`](/mojo/std/os/os/rmdir). ([PR #2430](https://github.com/modular/modular/pull/2430)) * Added the [`os.path.getsize()`](/mojo/std/os/path/path/getsize) function, which gives the size in bytes of the file identified by the path. ([PR 2626](https://github.com/modular/modular/pull/2626)) * Added [`os.path.join()`](/mojo/std/os/path/path/join) function. ([PR 2792](https://github.com/modular/modular/pull/2792)) * Added a new [`tempfile`](/mojo/std/tempfile/tempfile/) module, with `gettempdir()` and `mkdtemp()` functions. ([PR 2742](https://github.com/modular/modular/pull/2742)) * [`SIMD`](/mojo/std/builtin/simd/SIMD) type: * Added [`SIMD.shuffle()`](/mojo/std/builtin/simd/SIMD#shuffle) with `IndexList` mask. ([PR #2315](https://github.com/modular/modular/pull/2315)) * [`SIMD.__bool__()`](/mojo/std/builtin/simd/SIMD#__bool__) is constrained such that it only works when `size` is `1`. For SIMD vectors with more than one element, use [`any()`](/mojo/std/builtin/bool/any) or [`all()`](/mojo/std/builtin/bool/all). 
([PR #2502](https://github.com/modular/modular/pull/2502))
  * The [`SIMD.reduce_or()`](/mojo/std/builtin/simd/SIMD#reduce_or) and [`SIMD.reduce_and()`](/mojo/std/builtin/simd/SIMD#reduce_and) methods are now bitwise operations, and support integer types. ([PR #2671](https://github.com/modular/modular/pull/2671))
  * Added [`SIMD.__repr__()`](/mojo/std/builtin/simd/SIMD#__repr__) to get the verbose string representation of `SIMD` types. ([PR #2728](https://github.com/modular/modular/pull/2728))
* [`math`](/mojo/std/math/math/) package:
  * The `math.bit` module has been moved to a new top-level [`bit`](/mojo/std/bit/bit/) module. The following functions in this module have been renamed:
    * `ctlz` -> `countl_zero`
    * `cttz` -> `countr_zero`
    * `bit_length` -> `bit_width`
    * `ctpop` -> `pop_count`
    * `bswap` -> `byte_swap`
    * `bitreverse` -> `bit_reverse`
  * The `math.rotate_bits_left()` and `math.rotate_bits_right()` functions have been moved to the `bit` module.
  * The `is_power_of_2()` function in the `math` module is now called `is_power_of_two()` and located in the `bit` module.
  * The `abs()`, `round()`, `min()`, `max()`, `pow()`, and `divmod()` functions have moved from `math` to `builtin`, so you no longer need to import these functions.
  * The `math.tgamma()` function has been renamed to [`math.gamma()`](/mojo/std/math/math/gamma) to conform with Python's naming.
  * The implementations of the following functions have been moved from the `math` module to the new [`utils.numerics`](/mojo/std/utils/numerics/) module: `isfinite()`, `isinf()`, `isnan()`, `nan()`, `nextafter()`, and `ulp()`. The functions continue to be exposed in the `math` module.
  * [`math.gcd()`](/mojo/std/math/math/gcd) now works on negative inputs, and, like Python's implementation, accepts a variadic list of integers. New overloads for a `List` or `Span` of integers are also added.
([PR #2777](https://github.com/modular/modular/pull/2777)) * Async and coroutines: * [`Coroutine`](/mojo/std/builtin/coroutine/Coroutine) now requires a lifetime parameter. This parameter is set automatically by the parser when calling an async function. It contains the lifetimes of all the arguments and any lifetime accesses by the arguments. This ensures that argument captures by async functions keep the arguments alive as long as the coroutine is alive. * Async function calls are no longer allowed to borrow non-trivial register-passable types. Because async functions capture their arguments but register-passable types don't have lifetimes (yet), Mojo is not able to correctly track the reference, making this unsafe. To cover this safety gap, Mojo has temporarily disallowed binding non-trivial register-passable types to borrowed arguments in async functions. * Miscellaneous: * Added an [`InlineArray`](/mojo/std/collections/inline_array/InlineArray) type that works on memory-only types. Compare with the existing [`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple) type, which is conceptually an array type, but only works on register-passable types. ([PR #2294](https://github.com/modular/modular/pull/2294)) * The [`base64`](/mojo/std/base64/) package now includes encoding and decoding support for both the Base64 and Base16 encoding schemes. ([PR #2364](https://github.com/modular/modular/pull/2364)) ([PR #2584](https://github.com/modular/modular/pull/2584)) * The `take()` function in [`Variant`](/mojo/std/utils/variant/Variant) and [`Optional`](/mojo/std/collections/optional/Optional) has been renamed to `unsafe_take()`. * The `get()` function in `Variant` has been replaced by `__getitem__()`. That is, `v.get[T]()` should be replaced with `v[T]`. * Various functions in the `algorithm` module are now built-in functions. This includes `sort()`, `swap()`, and `partition()`.
`swap()` and `partition()` will likely shuffle around as we're reworking our built-in `sort()` function and optimizing it. * `infinity` and `NaN` are now correctly handled in [`testing.assert_almost_equal()`](/mojo/std/testing/testing/assert_almost_equal) and an `inf` function has been added to `utils/numerics.mojo`. ([PR #2375](https://github.com/modular/modular/pull/2375)) ### Tooling changes * Invoking `mojo package my-package -o my-dir` on the command line, where `my-package` is a Mojo package source directory, and `my-dir` is an existing directory, now outputs a Mojo package to `my-dir/my-package.mojopkg`. Previously, this had to be spelled out, as in `-o my-dir/my-package.mojopkg`. * The Mojo Language Server now reports a warning when a local variable is unused. * Several `mojo` subcommands now support a `--diagnostic-format` option that changes the format with which errors, warnings, and other diagnostics are printed. By specifying `--diagnostic-format json` on the command line, errors and other diagnostics will be output in a structured [JSON Lines](https://jsonlines.org) format that is easier for machines to parse. The full list of subcommands that support `--diagnostic-format` is as follows: `mojo build`, `mojo doc`, `mojo run`, `mojo package`, and `mojo test`. Further, the `mojo test --json` option has been subsumed into this new option; for the same behavior, run `mojo test --diagnostic-format json`. Note that the format of the JSON output may change; we don't currently guarantee its stability across releases of Mojo. * A new `--validate-doc-strings` option has been added to `mojo` to emit errors on invalid doc strings instead of warnings. * The `--warn-missing-doc-strings` flag for `mojo` has been renamed to `--diagnose-missing-doc-strings`. * A new decorator, `@doc_private`, was added that can be used to hide a declaration from being generated in the output of `mojo doc`. 
It also removes the requirement that the declaration has documentation (for example, when used with `--diagnose-missing-doc-strings`). * Debugger users can now set breakpoints on function calls in O0 builds even if the call has been inlined by the compiler. * The Mojo Language Server now supports renaming local variables. ### Other changes #### ❌ Removed * The `@unroll` decorator has been deprecated and removed. The decorator was supposed to guarantee that a decorated loop would be unrolled, or else the compiler would error. In practice, this guarantee was eroded over time, as a compiler-based approach cannot be as robust as the Mojo parameter system. In addition, the `@unroll` decorator did not make the loop induction variables parameter values, limiting its usefulness. Please see `@parameter for` for a replacement! * The method `object.print()` has been removed. Since `object` now conforms to the `Stringable` trait, you can use `print(my_object)` instead. * The following functions have been removed from the math module: * `clamp()`; use the new `SIMD.clamp()` method instead. * `round_half_down()` and `round_half_up()`; these can be trivially implemented using the `ceil()` and `floor()` functions. * `add()`, `sub()`, `mul()`, `div()`, `mod()`, `greater()`, `greater_equal()`, `less()`, `less_equal()`, `equal()`, `not_equal()`, `logical_and()`, `logical_xor()`, and `logical_not()`; Instead, users should rely directly on the corresponding operators (`+`, `-`, `*`, `/`, `%`, `>`, `>=`, `<`, `<=`, `==`, `!=`, `&`, `^`, and `~`). * `identity()` and `reciprocal()`; users can implement these trivially. * `select()`; removed in favor of using `SIMD.select()` directly. * `is_even()` and `is_odd()`; these can be trivially implemented using bitwise `&` with `1`. * `roundeven()`; the new `SIMD.roundeven()` method now provides the identical functionality. * `div_ceil()`; use the new `ceildiv()` function. 
* `rotate_left()` and `rotate_right()`; the same functionality is available in the builtin `SIMD.rotate_{left,right}()` methods for `SIMD` types, and the `bit.rotate_bits_{left,right}()` methods for `Int`. * An overload of `math.pow()` taking an integer parameter exponent. * `align_down_residual()`; it can be trivially implemented using `align_down()`. * `all_true()`, `any_true()`, and `none_true()`; use `SIMD.reduce_and()` and `SIMD.reduce_or()` directly. * `reduce_bit_count()`; use the new `SIMD.reduce_bit_count()` directly. * `rint()` and `nearbyint()`; use `round()` or `SIMD.roundeven()` as appropriate. * The `EvaluationMethod` has been removed from `math.polynomial` and Estrin's method is no longer available. This method was limited to degree 10 or less, underutilized, and its performance was unclear. In the future, this might be reintroduced with an improved implementation if needed, when better performance benchmarking infrastructure is available. The default behavior of `math.polynomial.polynomial_evaluate()` is unchanged (Horner's method). * The `math.bit.select()` and `math.bit.bit_and()` functions have been removed. The same functionality is available in the builtin `SIMD.select()` and `SIMD.__and__()` methods, respectively. * The `math.limit` module has been removed. The same functionality is available as follows: * `math.limit.inf()`: use `utils.numerics.max_or_inf()` * `math.limit.neginf()`: use `utils.numerics.min_or_neg_inf()` * `math.limit.max_finite()`: use `utils.numerics.max_finite()` * `math.limit.min_finite()`: use `utils.numerics.min_finite()` * The `tensor.random` module has been removed. The same functionality is now accessible via the `Tensor.rand()` and `Tensor.randn()` static methods. * The builtin `SIMD` struct no longer conforms to `Indexer`; users must explicitly cast `Scalar` values using `int`. #### 🛠️ Fixed * [#1837](https://github.com/modular/modular/issues/1837) Fix self-referential variant crashing the compiler.
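For readers mapping the removed `math.limit` functions above to familiar semantics, Python's standard library exposes loose analogues of the `utils.numerics` replacements. This is an illustrative Python analogy only, not Mojo API:

```python
import math
import sys

# Loose Python analogues (for float64) of the replacements listed above:
#   utils.numerics.max_or_inf()  ~ math.inf
#   utils.numerics.max_finite()  ~ sys.float_info.max
largest_finite = sys.float_info.max
print(math.inf > largest_finite)  # True: infinity exceeds every finite value
print(largest_finite * 2)         # inf: float overflow saturates to infinity
```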
* [#2363](https://github.com/modular/modular/issues/2363) Fix LSP crashing on simple trait definitions. * [#1787](https://github.com/modular/modular/issues/1787) Fix error when using `//` on `FloatLiteral` in alias expression. * Made several improvements to dictionary performance. Dicts with integer keys are most heavily affected, but large dicts and dicts with large values will also see large improvements. * [#2692](https://github.com/modular/modular/issues/2692) Fix `assert_raises` to include calling location. ### Special thanks Special thanks to our community contributors: [@rd4com](https://github.com/rd4com), @toiletsandpaper, [@helehex](https://github.com/helehex), [@artemiogr97](https://github.com/artemiogr97), [@mikowals](https://github.com/mikowals), [@kernhanda](https://github.com/kernhanda), [@lsh](https://github.com/lsh), @LJ-9801, [@YichengDWu](https://github.com/YichengDWu), [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse), [@fknfilewalker](https://github.com/fknfilewalker), [@jayzhan211](https://github.com/jayzhan211), [@martinvuyk](https://github.com/martinvuyk), [@ChristopherLR](https://github.com/ChristopherLR), [@mzaks](https://github.com/mzaks), [@bgreni](https://github.com/bgreni), [@Brian-M-J](https://github.com/Brian-M-J), [@leandrolcampos](https://github.com/leandrolcampos) ## v24.3 (2024-05-02) ### ✨ Highlights * `AnyPointer` was renamed to [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) and is now Mojo's preferred unsafe pointer type. It has several enhancements, including: * The element type can now be any type: it doesn't require `Movable`. * Because of this, the `take_value()`, `emplace_value()`, and `move_into()` methods have been changed to top-level functions and renamed. 
The new functions are: * [`initialize_pointee_copy`](/mojo/std/memory/unsafe_pointer/UnsafePointer#init_pointee_copy) * [`initialize_pointee_move`](/mojo/std/memory/unsafe_pointer/UnsafePointer#init_pointee_move) * [`move_from_pointee()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#take_pointee) * [`move_pointee`](/mojo/std/memory/unsafe_pointer/UnsafePointer#move_pointee_into) * A new [`destroy_pointee()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#destroy_pointee) function runs the destructor on the pointee. * `UnsafePointer` can be initialized directly from a `Reference` with `UnsafePointer(someRef)` and can convert to a reference with `yourPointer[]`. Both infer element type and address space. Note that when you convert a pointer to a reference, there's no way for Mojo to track the lifetime of the original value. So the resulting reference is no safer than the original pointer. * All of the pointer types received some cleanup to make them more consistent, for example the `unsafe.bitcast()` global function is now a consistent [`bitcast()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#bitcast) method on the pointers, which can convert element type and address space. * Improvements to variadic arguments support. * Heterogeneous variadic pack arguments now work reliably even with memory types, and have a more convenient API to use, as defined by the [`VariadicPack`](/mojo/std/builtin/variadics/VariadicPack) type. For example, a simplified version of `print` can be implemented like this:

```mojo
fn print[T: Stringable, *Ts: Stringable](first: T, *rest: *Ts):
    print_string(str(first))

    @parameter
    fn print_elt[T: Stringable](a: T):
        print_string(" ")
        print_string(a)

    rest.each[print_elt]()
```

* Mojo now supports declaring functions that have both optional and variadic arguments, both positional and keyword-only. For example, this now works:

```mojo
fn variadic_arg_after_default(
    a: Int, b: Int = 3, *args: Int, c: Int, d: Int = 1, **kwargs: Int
): ...
```

Positional variadic parameters also work in the presence of optional parameters. That is:

```mojo
fn variadic_param_after_default[e: Int, f: Int = 2, *params: Int]():
    pass
```

Note that variadic keyword parameters are not supported yet. For more information, see [Variadic arguments](/mojo/manual/functions#variadic-arguments) in the Mojo Manual. * The `mojo build` and `mojo run` commands now support a `-g` option. This shorter alias is equivalent to writing `--debug-level full`. This option is also available in the `mojo debug` command, but is already the default. * Many new standard library APIs have been filled in, including many community contributions. Changes are listed in the standard library section. * The Mojo Manual has a new page on [Types](/mojo/manual/types). ### Language changes * Certain dunder methods that take indices (`__getitem__()`, `__setitem__()`, and `__refitem__()`) or names (`__getattr__()` and `__setattr__()`) can now take the index or name as a parameter value instead of an argument value. This is enabled when you define one of these methods with no argument other than `self` (for a getter) or `self` and the set value (for a setter). This enables types that can only be subscripted into with parameters, as well as things like the following example, which passes the attribute name as a parameter so that attribute names can be checked at compile time.

```mojo
struct RGB:
    fn __getattr__[name: StringLiteral](self) -> Int:
        @parameter
        if name == "r":
            return ...
        elif name == "g":
            return ...
        else:
            constrained[name == "b", "can only access with r, g, or b members"]()
            return ...

var rgb = RGB()
print(rgb.b)  # Works
print(rgb.q)  # Compile error
```

* Mojo now allows users to capture the source location of code and call location of functions dynamically using the `__source_location()` and `__call_location()` functions.
For example:

```mojo
from builtin._location import __call_location

@always_inline
fn my_assert(cond: Bool, msg: String):
    if not cond:
        var call_loc = __call_location()
        print("In", call_loc.file_name, "on line", str(call_loc.line) + ":", msg)

fn main():
    my_assert(False, "always fails")  # some_file.mojo, line 193
```

This prints "`In /path/to/some_file.mojo on line 193: always fails`". Note that `__call_location()` only works in `@always_inline` or `@always_inline("nodebug")` functions. It gives incorrect results if placed in an `@always_inline` function that's called *from* an `@always_inline("nodebug")` function. This feature is still evolving and for the time being you need to explicitly import these APIs, as shown above. In the future, these will probably be built-in functions and not require an import statement. Neither `__source_location()` nor `__call_location()` work when called in a parameter context. For example:

```mojo
from builtin._location import __call_location

@always_inline
fn mystery_location() -> String:
    var loc = __call_location()
    return str(loc.file_name)

def main():
    alias doesnt_work = mystery_location()  #
```

### Standard library changes #### ⭐️ New * [`List`](/mojo/std/collections/list/List) has several new methods: * `pop(index)` for removing an element at a particular index. By default, `List.pop()` removes the last element in the list. (@LJ-9801, fixes [#2017](https://github.com/modular/modular/issues/2017)) * `resize(new_size)` for resizing the list without the need to specify an additional value. ([@mikowals](https://github.com/mikowals), fixes [#2133](https://github.com/modular/modular/issues/2133)) * `insert(index, value)` for inserting a value at a specified index into the `List`.
([@whym1here](https://github.com/whym1here), fixes [#2134](https://github.com/modular/modular/issues/2134)) * A new constructor `List(ptr, size, capacity)` to avoid needing to do a deep copy of an existing contiguous memory allocation when constructing a new `List`. ([@StandinKP](https://github.com/StandinKP), fixes [#2170](https://github.com/modular/modular/issues/2170)) * [`Dict`](/mojo/std/collections/dict/Dict) now has an `update()` method to update keys/values from another `Dict`. ([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) * [`Set`](/mojo/std/collections/set/Set) now has named methods for set operations: * `difference()` mapping to `-` * `difference_update()` mapping to `-=` * `intersection_update()` mapping to `&=` * `update()` mapping to `|=` ([@arvindavoudi](https://github.com/arvindavoudi)) * `Dict`, `List`, and `Set` all conform to the `Boolable` trait. The collections evaluate to `True` if they contain any elements, `False` otherwise:

```mojo
def list_names(names: List[String]):
    if names:
        for name in names:
            print(name[])
    else:
        print("No names to list.")
```

([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) * Added [`reversed()`](/mojo/std/builtin/reversed/reversed) function for creating reversed iterators. Several range types, `List`, and `Dict` now support iterating in reverse.

```mojo
var numbers = List(1, 2, 3, 4, 5)
for number in reversed(numbers):
    print(number)
```

([@helehex](https://github.com/helehex) and [@jayzhan211](https://github.com/jayzhan211), contributes towards [#2325](https://github.com/modular/modular/issues/2325)) * [`Optional`](/mojo/std/collections/optional/Optional) now implements `__is__` and `__isnot__` methods so that you can compare an `Optional` with `None`.
For example:

```mojo
var opt = Optional(1)
if opt is not None:
    print(opt.value()[])
```

([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) * [`Tuple`](/mojo/std/builtin/tuple/Tuple) now works with memory-only element types like `String` and allows you to directly index into it with a parameter expression. This means you can now simply use `x = tup[1]` like Python instead of `x = tup.get[1, Int]()`. You can also assign into tuple elements now as well with `tup[1] = x`.

```mojo
var tuple = ("Green", 9.3)
var name = tuple[0]
var value = tuple[1]
```

Note that because the subscript must be a parameter expression, you can't iterate through a `Tuple` using an ordinary `for` loop. * The `Reference` type has several changes, including: * It has moved to the `memory.reference` module instead of `memory.unsafe`. * `Reference` now has an [`unsafe_bitcast()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#bitcast) method, similar to the pointer types. * Several unsafe methods were removed, including `offset()`, `destroy_element_unsafe()` and `emplace_ref_unsafe()`. This is because `Reference` is a safe type—use `UnsafePointer` to do unsafe operations. * [`Bool`](/mojo/std/builtin/bool/Bool) can now be implicitly converted from any type conforming to the [`Boolable`](/mojo/std/builtin/bool/Boolable) trait. This means that you no longer need to write code like this:

```mojo
@value
struct MyBoolable:
    fn __bool__(self) -> Bool: ...

fn takes_boolable[T: Boolable](cond: T): ...

takes_boolable(MyBoolable())
```

Instead, you can simply write:

```mojo
fn takes_bool(cond: Bool): ...

takes_bool(MyBoolable())
```

Note that calls to `takes_bool()` will perform the implicit conversion, so in some cases it is still better to explicitly declare a type parameter, e.g.:

```mojo
fn takes_two_boolables[T: Boolable](a: T, b: T):
    # Short circuit means `b.__bool__()` might not be evaluated.
    if a.__bool__() and b.__bool__():
        ...
```

* [`PythonObject`](/mojo/std/python/python_object/PythonObject) now conforms to the [`KeyElement`](/mojo/std/collections/dict/#keyelement) trait, meaning that it can be used as key type for [`Dict`](/mojo/std/collections/dict/Dict). This allows you to easily build and interact with Python dictionaries in Mojo:

```mojo
def main():
    d = PythonObject(Dict[PythonObject, PythonObject]())
    d["foo"] = 12
    d[7] = "bar"
    d["foo"] = [1, 2, "something else"]
    print(d)  # prints `{'foo': [1, 2, 'something else'], 7: 'bar'}`
```

* [`FileHandle.seek()`](/mojo/std/builtin/file/FileHandle#seek) now has a `whence` argument that defaults to `os.SEEK_SET` to seek from the beginning of the file. You can now set to `os.SEEK_CUR` to offset by the current `FileHandle` seek position:

```mojo
var f = open("/tmp/example.txt")
# Skip 32 bytes
f.seek(os.SEEK_CUR, 32)
```

Or `os.SEEK_END` to offset from the end of file:

```mojo
# Start from 32 bytes before the end of the file
f.seek(os.SEEK_END, -32)
```

* [`FileHandle.read()`](/mojo/std/builtin/file/FileHandle#read) can now read straight into a [`DTypePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer):

```mojo
var file = open("/tmp/example.txt", "r")

# Allocate and load 8 elements
var ptr = DTypePointer[DType.float32].alloc(8)
var bytes = file.read(ptr, 8)
print("bytes read", bytes)
print(ptr.load[width=8]())
```

* The `sys` module now contains an `exit()` function that exits a Mojo program with the specified error code.

```mojo
from sys import exit

exit(0)
```

* The constructors for `Tensor` have been changed to be more consistent. As a result, constructors take the shape as the first argument (instead of the second) when constructing a tensor with pointer data. If you pass a single scalar value to the `Tensor` constructor, it now broadcasts the value to all elements in the tensor. For example, `Tensor[DType.float32](TensorShape(2,2), 0)` constructs a `2x2` tensor initialized with all zeros.
This provides an easy way to fill in the data of a tensor. * [`String`](/mojo/std/collections/string/string/String) now has `removeprefix()` and `removesuffix()` methods. ([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) * The [`ord`](/mojo/std/collections/string/string/ord) and [`chr`](/mojo/std/collections/string/string/chr) functions have been improved to accept any Unicode character. ([@mzaks](https://github.com/mzaks), contributes towards [#1616](https://github.com/modular/modular/issues/1616)) * [`atol()`](/mojo/std/collections/string/string/atol) now handles whitespace. The `atol()` function is used internally by `String.__int__()`, so `int(String(" 10 "))` now returns `10` instead of raising an error. ([@artemiogr97](https://github.com/artemiogr97)) * [`SIMD`](/mojo/std/builtin/simd/SIMD) now implements the `__rmod__()` method. ([@bgreni](https://github.com/bgreni), fixes [#1482](https://github.com/modular/modular/issues/1482)) * [`bool(None)`](/mojo/std/builtin/bool/bool-function) is now implemented. (@zhoujingya) * The [`DTypePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) type now implements `gather()` for gathering a `SIMD` vector from offsets of a current pointer. Similarly, support for `scatter()` was added to scatter a `SIMD` vector into offsets of the current pointer. ([@leandrolcampos](https://github.com/leandrolcampos)) * The [`len()`](/mojo/std/builtin/len/len) function now handles a [`range()`](/mojo/std/builtin/range/range) specified with a negative end value, so that things like `len(range(-1))` work correctly. ([@soraros](https://github.com/soraros)) * [`debug_assert()`](/mojo/std/builtin/debug_assert/debug_assert) now prints its location (filename, line, and column where it was called) in its error message. Similarly, the `assert` helpers in the [`testing`](/mojo/std/testing/testing/) module now include location information in their messages.
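The `atol()` whitespace change noted above matches CPython's `int()` parsing, which also accepts surrounding whitespace. A quick Python illustration of the behavior Mojo now mirrors:

```python
# Python's int() strips leading/trailing whitespace before parsing,
# the same behavior Mojo's atol() / String.__int__() now provides.
value = int(" 10 ")
print(value)  # 10
```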
* The [`testing.assert_equal[SIMD]()`](/mojo/std/testing/testing/assert_equal) function now raises if any of the elements mismatch in the two `SIMD` arguments being compared. ([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) * The [`testing.assert_almost_equal()`](/mojo/std/testing/testing/assert_almost_equal) and [`math.isclose()`](/mojo/std/math/math/isclose) functions now have an `equal_nan` flag. When set to `True`, NaNs are considered equal. * The `object` type now supports the division, modulo, and left and right shift operators, including the in-place and reverse variants. (@LJ-9801, fixes [#2224](https://github.com/modular/modular/issues/2224)) * Added checked arithmetic operations for `SIMD` integers. `SIMD` integer types (including the sized integer scalars like `Int64`) can now perform checked additions, subtractions, and multiplications using the following new methods: * `add_with_overflow()` * `sub_with_overflow()` * `mul_with_overflow()` Checked arithmetic allows the caller to determine if an operation exceeded the numeric limits of the type. For example:

```mojo
var simd = SIMD[DType.int8, 4](7, 11, 13, 17)
var product: SIMD[DType.int8, 4]
var overflow: SIMD[DType.bool, 4]
(product, overflow) = simd.mul_with_overflow(simd)
for i in range(len(product)):
    if overflow[i]:
        print("")
    else:
        print(product[i])
```

([@lsh](https://github.com/lsh)) * Added [`os.remove()`](/mojo/std/os/os/remove) and [`os.unlink()`](/mojo/std/os/os/unlink) for deleting files. ([@artemiogr97](https://github.com/artemiogr97), fixes [#2306](https://github.com/modular/modular/issues/2306)) #### 🦋 Changed * The [`parallel_memcpy()`](/mojo/std/algorithm/memory/parallel_memcpy) function has moved from the `buffer` package to the `algorithm` package. Please update your imports accordingly. * [`Optional.value()`](/mojo/std/collections/optional/Optional#value) now returns a reference instead of a copy of the contained value.
To perform a copy manually, dereference the result:

```mojo
var result = Optional(123)
var value = result.value()[]
```

([@lsh](https://github.com/lsh), fixes [#2179](https://github.com/modular/modular/issues/2179)) * Per the accepted community proposal, [Standardize the representation of byte sequence as a sequence of unsigned 8-bit integers](https://github.com/modular/modular/blob/main/mojo/proposals/byte-as-uint8.md), began transition to using `UInt8` by changing the data pointer of `Error` to `DTypePointer[DType.uint8]`. ([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse), contributes towards [#2317](https://github.com/modular/modular/issues/2317)) * Continued transition to `UnsafePointer` from the legacy `Pointer` type in various standard library APIs and internals. ([@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) ### Tooling changes * The behavior of `mojo build` when invoked without an output `-o` argument has changed slightly: `mojo build ./test-dir/program.mojo` now outputs an executable to the path `./program`, whereas before it would output to the path `./test-dir/program`. * The `mojo package` command no longer supports the `-D` flag. All compilation environment flags should be provided at the point of package use (e.g. `mojo run` or `mojo build`). * The REPL no longer allows top-level variable declarations to be uninitialized, e.g. it will reject `var s: String`. This is because it does not do proper lifetime tracking (yet!) across cells, and so such code would lead to a crash. You can work around this by initializing to a dummy value and overwriting later. This limitation only applies to top-level variables; variables in functions work as they always have. ### Other changes #### Low-level language changes * A low-level `__get_mvalue_as_litref(x)` builtin was added to give access to the underlying memory representation as a `!lit.ref` value without checking initialization status of the underlying value.
This is useful in very low-level logic but isn't designed for general usability and will likely change in the future. * Properties can now be specified on inline MLIR ops:

```mojo
_ = __mlir_op.`kgen.source_loc`[
    _type = (
        __mlir_type.index, __mlir_type.index, __mlir_type.`!kgen.string`
    ),
    _properties = __mlir_attr.`{inlineCount = 1 : i64}`,
]()
```

As the example above shows, the protected `_properties` attribute can be passed during op construction, with an MLIR `DictionaryAttr` value. #### ❌ Removed * Support for "register only" variadic packs has been removed. Instead of `AnyRegType`, please upgrade your code to `AnyType` in examples like this:

```mojo
fn your_function[*Types: AnyRegType](*args: *Types): ...
```

This move gives you access to a nicer API and has the benefit of being memory safe and correct for non-trivial types. If you need specific APIs on the types, please use the correct trait instead of `AnyType`. * `List.pop_back()` has been removed. Use `List.pop()` instead, which defaults to popping the last element in the list. * `SIMD.to_int(value)` has been removed. Use `int(value)` instead. * The `__get_lvalue_as_address(x)` magic function has been removed. To get a reference to a value use `Reference(x)` and if you need an unsafe pointer, you can use `UnsafePointer.address_of(x)`. #### 🛠️ Fixed * [#516](https://github.com/modular/modular/issues/516) and [#1817](https://github.com/modular/modular/issues/1817) and many others, e.g. "Can't create a function that returns two strings." * [#1178](https://github.com/modular/modular/issues/1178) (os/kern) failure (5). * [#1609](https://github.com/modular/modular/issues/1609) alias with `DynamicVector[Tuple[Int]]` fails. * [#1987](https://github.com/modular/modular/issues/1987) Defining `main` in a Mojo package is an error, for now. This is not intended to work yet, erroring for now will help to prevent accidental undefined behavior.
* [#1215](https://github.com/modular/modular/issues/1215) and [#1949](https://github.com/modular/modular/issues/1949) The Mojo LSP server no longer cuts off hover previews for functions with functional arguments, parameters, or results. * [#1901](https://github.com/modular/modular/issues/1901) Fixed Mojo LSP and documentation generation handling of inout arguments. * [#1913](https://github.com/modular/modular/issues/1913) - `0__` no longer crashes the Mojo parser. * [#1924](https://github.com/modular/modular/issues/1924) JIT debugging on Mac has been fixed. * [#1941](https://github.com/modular/modular/issues/1941) Mojo variadic arguments don't work with non-trivial register-only types. * [#1963](https://github.com/modular/modular/issues/1963) `a!=0` is now parsed and formatted correctly by `mojo format`. * [#1676](https://github.com/modular/modular/issues/1676) Fix a crash related to `@value` decorator and structs with empty body. * [#1917](https://github.com/modular/modular/issues/1917) Fix a crash after syntax error during tuple creation. * [#2006](https://github.com/modular/modular/issues/2006) The Mojo LSP now properly supports signature types with named arguments and parameters. * [#2007](https://github.com/modular/modular/issues/2007) and [#1997](https://github.com/modular/modular/issues/1997) The Mojo LSP no longer crashes on certain types of closures. * [#1675](https://github.com/modular/modular/issues/1675) Ensure `@value` decorator fails gracefully after duplicate field error. * [#2068](https://github.com/modular/modular/issues/2068) Fix `SIMD.reduce()` for `size_out == 2`. ([@soraros](https://github.com/soraros)) ## v24.2.1 (2024-04-11) This release doesn't include any changes to Mojo. ## v24.2 (2024-03-28) ### 🔥 Legendary * The Mojo standard library is now open source! Check out the [README](https://github.com/modular/modular/blob/main/mojo/stdlib/README.md) for everything you need to get started.
* Structs and other nominal types are now allowed to implicitly conform to traits. A struct implicitly conforms to a trait if it implements all the requirements for the trait. For example, any struct that implements the `__str__()` method implicitly conforms to `Stringable`, and is usable with the `str()` built-in function.

```mojo
@value
struct Foo:
    fn __str__(self) -> String:
        return "foo!"

fn main():
    print(str(Foo()))  # prints 'foo!'
```

We still strongly encourage you to explicitly list the traits a struct conforms to when possible:

```mojo
@value
struct Foo(Stringable): ...
```

Not only is this useful for documentation and for communicating intentions, but in the future, explicit conformance will be useful for features like default methods and extensions. * Mojo's Python interoperability now supports passing keyword arguments to Python functions:

```mojo
from python import Python

def main():
    plt = Python.import_module("matplotlib.pyplot")
    plt.plot((5, 10), (10, 15), color="red")
    plt.show()
```

### Language changes #### ⭐️ New * Mojo now has support for variadic keyword arguments, often referred to as `**kwargs`. This means you can now declare and call functions like this:

```mojo
fn print_nicely(**kwargs: Int) raises:
    for key in kwargs.keys():
        print(key[], "=", kwargs[key[]])

# prints:
# `a = 7`
# `y = 8`
print_nicely(a=7, y=8)
```

For more details (and a list of current limitations), see [Variadic keyword arguments](/mojo/manual/functions#variadic-keyword-arguments) in the Mojo manual. #### 🦋 Changed or removed * `let` declarations now produce a compile time error instead of a warning, our next step in [removing let declarations](https://github.com/modular/modular/blob/main/mojo/proposals/remove-let-decls.md). The compiler still recognizes the `let` keyword for now in order to produce a good error message, but that will be removed in subsequent releases.
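The `**kwargs` support described above is modeled on Python's variadic keyword arguments. For comparison, an equivalent Python version of the `print_nicely` pattern (illustrative analogue, not Mojo code):

```python
def print_nicely(**kwargs: int) -> None:
    # Python dicts (and so **kwargs) preserve call-site order, as in the
    # Mojo example above.
    for key in kwargs:
        print(key, "=", kwargs[key])

print_nicely(a=7, y=8)
# prints:
# a = 7
# y = 8
```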
* Mojo now warns about unused values in both `def` and `fn` declarations, instead of completely disabling the warning in `def`s. It never warns about unused `object` or `PythonObject` values, tying the warning to these types instead of the kind of function they are unused in. This will help catch API usage bugs in `def`s and make imported Python APIs more ergonomic in `fn`s. * For the time being, dynamic type values will be disabled in the language. For example, the following will now fail with an error: ```mojo var t = Int # dynamic type values not allowed struct SomeType: ... takes_type(SomeType) # dynamic type values not allowed ``` We want to take a step back and (re)design type valued variables, existentials, and other dynamic features. This does not affect type valued **parameters**, so the following works as before: ```mojo alias t = Int # still 🔥 struct SomeType: ... takes_type[SomeType]() # already 🔥 fn uses_trait[T: SomeTrait](value: T): ... # still 🔥 ``` * The `*_` expression in parameter expressions is now required to occur at the end of a positional parameter list, instead of being allowed in the middle. ```mojo # No longer supported alias FirstUnbound = SomeStruct[*_, 42] alias MidUnbound = SomeStruct[7, *_, 6] # Still supported alias LastUnbound = SomeStruct[42, *_] ``` We narrowed this because we want to encourage type designers to get the order of parameters right, and want to extend `*_` to support keyword parameters as well in the future. ### Standard library changes #### ⭐️ New * `DynamicVector` has been renamed to [`List`](/mojo/std/collections/list/List), and has moved from the `collections.vector` module to the `collections.list` module. In addition: * You can now construct a `List` from a variadic number of values. For example: ```mojo var numbers = List[Int](1, 2, 3) ``` * `List` and [`InlinedFixedVector`](/mojo/std/collections/inline_array/InlineArray) types now support negative indexing.
This means that you can write `vec[-1]` which is equivalent to `vec[len(vec)-1]`. * `List.push_back()` has been removed. Please use the `append()` function instead. * The [`print()`](/mojo/std/builtin/io/print) function now takes `sep` and `end` keyword arguments. This means that you can write: ```mojo print("Hello", "Mojo", sep=", ", end="!!!\n") # prints Hello, Mojo!!! ``` `sep` defaults to the empty string and `end` defaults to "\n". Also, the `print_no_newline()` function has been removed. Please use `print(end="")` instead. * The [`FloatLiteral`](/mojo/std/builtin/float_literal/FloatLiteral) type is now an infinite-precision nonmaterializable type. This means you can do compile-time calculations using `FloatLiteral` without rounding errors. When materialized at runtime, a `FloatLiteral` value is converted to a [`Float64`](/mojo/std/builtin/simd). ```mojo # third is an infinite-precision FloatLiteral value alias third = 1.0 / 3.0 # t is a Float64 var t = third ``` * String types all conform to the [`IntableRaising`](/mojo/std/builtin/int/IntableRaising) trait. This means that you can now call `int("123")` to get the integer `123`. If the integer cannot be parsed from the string, then an error is raised. * The `Tensor` type now has `argmax()` and `argmin()` functions to compute the position of the max or min value. Note: this should return a `Tensor[Int]` but currently the output tensor is the same type as the input tensor. This will be fixed in a future release. * Added a new [`collections.OptionalReg`](/mojo/std/collections/optional/OptionalReg) type, a register-passable alternative to [`Optional`](/mojo/std/collections/optional/Optional). * The [`ulp()`](/mojo/std/utils/numerics/ulp) function has been added to the `math` module. This allows you to get the units of least precision (or units of last place) of a floating point value. 
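The string-to-integer conversion described above can be sketched as follows; because parsing can fail, the calling function must be declared `raises` (a minimal usage sketch, not from the original changelog):

```mojo
fn main() raises:
    # int() parses the string and raises an error for invalid input,
    # which is why main() is declared with `raises`.
    var n = int("123")
    print(n + 1)  # 124
```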
#### 🦋 Changed * The `simd_load()`, `simd_store()`, `aligned_simd_load()`, and `aligned_simd_store()` methods on [`DTypePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer), `Buffer`, and `NDBuffer` have been merged into a more expressive set of `load()` and `store()` methods with keyword-only `width` and `alignment` parameters: ```mojo # Doesn't work my_simd = my_buffer.simd_load[simd_width](index) # Works my_simd = my_buffer.load[width=simd_width](index) # Doesn't work my_buffer.aligned_simd_store[width, alignment](my_simd) # Works my_buffer.store[width=width, alignment=alignment](my_simd) ``` * The [`EqualityComparable`](/mojo/std/builtin/equality_comparable/EqualityComparable) trait now requires the `__ne__()` method for conformance in addition to the previously required `__eq__()` method. * Many types now declare conformance to the `EqualityComparable` trait. * [`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple) parameter order has changed to `StaticTuple[type, size]` for consistency with `SIMD` and similar collection types. * The signature of the [`elementwise()`](/mojo/std/algorithm/functional/elementwise) function has been changed. The new order is `function`, `simd_width`, and then `rank`. As a result, the rank parameter can now be inferred and one can call `elementwise()` without it: ```mojo elementwise[func, simd_width](shape) ``` * `PythonObject` is now register-passable. * `PythonObject.__iter__()` now works correctly on more types of iterable Python objects. Attempting to iterate over non-iterable objects will now raise an exception instead of behaving as if iterating over an empty sequence. `__iter__()` also now borrows `self` rather than requiring `inout`, allowing code like: ```mojo for value in my_dict.values(): ... ``` #### 🚚 Moved * We took the opportunity to rehome some modules into their correct package as we were going through the process of open-sourcing the Mojo standard library.
Specifically, the following are some breaking changes worth calling out. Please update your import statements accordingly. * `Buffer`, `NDBuffer`, and friends have moved from the `memory` package into a new `buffer` package. ```mojo from buffer import Buffer, NDBuffer ``` * `utils.list`, including the [`Dim`](/mojo/std/buffer/dimlist/Dim) and [`DimList`](/mojo/std/buffer/dimlist/DimList) types, has moved to the `buffer` package. ```mojo from buffer import Dim, DimList ``` * The [`parallel_memcpy()`](/mojo/std/algorithm/memory/parallel_memcpy) function has moved from the `memory` package into the `buffer` package. ```mojo from buffer import parallel_memcpy ``` * The [`rand()`](/mojo/kernels/extensibility/tensor/tensor/Tensor/#rand) and [`randn()`](/mojo/kernels/extensibility/tensor/tensor/Tensor/#randn) functions from the `random` package that return a `Tensor` have moved to the `tensor` package. Note that the overloads that write to a `DTypePointer` remain in the `random` package. If you happen to be using both versions in the same source file, you can import them both using the `import as` syntax: ```mojo from tensor import rand from random import rand as rand_dt ``` * The `trap()` function has been renamed to [`abort()`](/mojo/std/os/os/abort). It has also moved from the `debug` module to the `os` module. ```mojo from os import abort ``` * The [`isinf()`](/mojo/std/utils/numerics/isfinite) and [`isfinite()`](/mojo/std/utils/numerics/isfinite) methods have been moved from `math.limits` to the `math` module. ```mojo from math import isinf, isfinite ``` ### Tooling changes #### ⭐️ New * Docstring code blocks can now use `%#` to hide lines of code from documentation generation. For example: ```mojo var value = 5 %# print(value) ``` Will generate documentation of the form: ```mojo var value = 5 ``` Hidden lines are processed as if they were normal code lines during test execution.
This allows for writing additional code within a docstring example that is only used to ensure the example is runnable/testable. * The Mojo LSP server now allow you to specify additional search paths to use when resolving imported modules in a document. You can specify search paths on the command line, using the `-I` option, or you can add them to the `mojo.lsp.includeDirs` setting in the VS Code extension. ### Other changes #### ❌ Removed * The `__get_address_as_lvalue` magic function has been removed. You can now get an LValue from a `Pointer` or `Reference` by using the dereference operator (`[]`): ```mojo var ptr: Pointer[MyRecord] ... # Doesn't work __get_address_as_lvalue(ptr.value) = MyRecord(3, 5) # Works ptr[] = MyRecord(3, 5) ``` * The type parameter for the `memcpy` function is now automatically inferred. This means that calls to `memcpy` of the form `memcpy[Dtype.xyz](...)` will no longer work and the user would have to change the code to `memcpy(...)`. * The [`memcpy()`](/mojo/std/memory/memory/memcpy) overload that worked on `Buffer` types has been removed in favor of just overloads for [`Pointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) and [`DTypePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer): ```mojo # Doesn't work memcpy(destBuffer, srcBuffer, count) # Works memcpy(destBuffer.data, srcBuffer.data, count) ``` * The functions `max_or_inf()`, `min_or_neginf()` have been removed from `math.limit`. These functions were only used by the SIMD type. * As mentioned previously, the `print_no_newline()` function has been removed. Please use `print(end="")` instead. #### 🛠️ Fixed * [#1362](https://github.com/modular/modular/issues/1362) - Parameter inference now recursively matches function types. * [#951](https://github.com/modular/modular/issues/951) - Functions that were both `async` and `@always_inline` incorrectly errored. * [#1858](https://github.com/modular/modular/issues/1858) - Trait with parametric methods regression. 
* [#1892](https://github.com/modular/modular/issues/1892) - Forbid unsupported decorators on traits. * [#1735](https://github.com/modular/modular/issues/1735) - Trait-typed values are incorrectly considered equal. * [#1909](https://github.com/modular/modular/issues/1909) - Crash due to nested import in unreachable block. * [#1921](https://github.com/modular/modular/issues/1921) - Parser crashes binding `Reference` to lvalue with subtype lifetime. * [#1945](https://github.com/modular/modular/issues/1945) - `Optional[T].or_else()` should return `T` instead of `Optional[T]`. * [#1940](https://github.com/modular/modular/issues/1940) - Constrain `math.copysign` to floating point or integral types. * [#1838](https://github.com/modular/modular/issues/1838) - Variadic `print` does not work when specifying `end=""` * [#1826](https://github.com/modular/modular/issues/1826) - The `SIMD.reduce` methods correctly handle edge cases where `size_out >= size`. ## v24.1.1 (2024-03-18) This release includes installer improvements and enhanced error reporting for installation issues. Otherwise it is functionally identical to Mojo 24.1. ## v24.1 (2024-02-29) ### 🔥 Legendary * Mojo is now bundled with [the MAX platform](/max)! As such, the Mojo package version now matches the MAX version, which follows a `YY.MAJOR.MINOR` version scheme. Because this is our first release in 2024, that makes this version `24.1`. * Mojo debugging support is here! The Mojo VS Code extension includes debugger support. For details, see [Debugging](/mojo/tools/debugging) in the Mojo Manual. ### ⭐️ New * We now have a [`Set`](/mojo/std/collections/set/Set) type in our collections! `Set` is backed by a `Dict`, so it has fast add, remove, and `in` checks, and requires member elements to conform to the `KeyElement` trait. 
```mojo from collections import Set var set = Set[Int](1, 2, 3) print(len(set)) # 3 set.add(4) for element in set: print(element[]) set -= Set[Int](3, 4, 5) print(set == Set[Int](1, 2)) # True print(set | Set[Int](0, 1) == Set[Int](0, 1, 2)) # True var element = set.pop() print(len(set)) # 1 ``` * Mojo now supports the `x in y` expression as syntax sugar for `y.__contains__(x)`, as well as `x not in y`. * Mojo now has support for keyword-only arguments and parameters. For example: ```mojo fn my_product(a: Int, b: Int = 1, *, c: Int, d: Int = 2): print(a * b * c * d) my_product(3, c=5) # prints '30' my_product(3, 5, d=7) # error: missing 1 required keyword-only argument: 'c' ``` This includes support for declaring signatures that use both variadic and keyword-only arguments/parameters. For example, the following is now possible: ```mojo fn prod_with_offset(*args: Int, offset: Int = 0) -> Int: var res = 1 for i in range(len(args)): res *= args[i] return res + offset print(prod_with_offset(2, 3, 4, 10)) # prints 240 print(prod_with_offset(2, 3, 4, offset=10)) # prints 34 ``` Note that variadic keyword-only arguments/parameters (for example, `**kwargs`) are not supported yet. That is, the following is not allowed: ```mojo fn variadic_kw_only(a: Int, **kwargs): ... ``` For more information, see [Positional-only and keyword-only arguments](/mojo/manual/functions#positional-only-and-keyword-only-arguments) in the Mojo Manual. * The `print()` function now accepts a keyword-only `end` argument, which is useful for controlling whether a newline is printed after the elements. As before, `end` defaults to "\n". * The Mojo SDK can now be installed on AWS Graviton instances. * A new version of the Mojo Playground is available. The new playground is a simple interactive editor for Mojo code, similar to the Rust Playground or Go Playground. The old JupyterLab-based playground will remain online until March 20th.
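The `x in y` sugar described above calls the container's `__contains__()` method, so user-defined types get the syntax by implementing it. A minimal sketch with a hypothetical `IntRange` struct (not from the original changelog):

```mojo
@value
struct IntRange:
    var lo: Int
    var hi: Int

    fn __contains__(self, x: Int) -> Bool:
        return self.lo <= x and x < self.hi

fn main():
    var r = IntRange(0, 10)
    print(3 in r)       # desugars to r.__contains__(3)
    print(42 not in r)  # negation of r.__contains__(42)
```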
* The Mojo LSP server will now generate fixits for populating empty documentation strings: ```mojo fn foo(arg: Int): """""" # Unexpected empty documentation string ``` Applying the fixit from above will generate: ```mojo fn foo(arg: Int): """[summary]. Args: arg: [description]. """ ``` * Added new `*_` syntax that allows users to explicitly unbind any number of positional parameters. For example: ```mojo struct StructWithDefault[a: Int, b: Int, c: Int = 8, d: Int = 9]: pass alias all_unbound = StructWithDefault[*_] # equivalent to alias all_unbound = StructWithDefault[_, _, _, _] alias first_bound = StructWithDefault[5, *_] # equivalent to alias first_bound = StructWithDefault[5, _, _, _] alias last_bound = StructWithDefault[*_, 6] # equivalent to alias last_bound = StructWithDefault[_, _, _, 6] alias mid_unbound = StructWithDefault[3, *_, 4] # equivalent to alias mid_unbound = StructWithDefault[3, _, _, 4] ``` As demonstrated above, this syntax can be used to explicitly unbind an arbitrary number of parameters, at the beginning, at the end, or in the middle of the operand list. Since these unbound parameters must be explicitly specified at some point, default values for these parameters are not applied. For example: ```mojo alias last_bound = StructWithDefault[*_, 6] # When using last_bound, you must specify a, b, and c. last_bound # doesn't have a default value for `c`. var s = last_bound[1, 2, 3]() ``` For more information see the Mojo Manual sections on [partially-bound types](/mojo/manual/parameters/#fully-bound-partially-bound-and-unbound-types) and [automatic parameterization of functions](/mojo/manual/parameters/#automatic-parameterization). * [`DynamicVector`](/mojo/std/collections/list/List) now supports iteration. 
Iteration values are instances of `Reference` and require dereferencing: ```mojo var v = DynamicVector[String]() v.append("Alice") v.append("Bob") v.append("Charlie") for x in v: x[] = str("Hello, ") + x[] for x in v: print(x[]) ``` * `DynamicVector` now has [`reverse()`](/mojo/std/collections/list/List#reverse) and [`extend()`](/mojo/std/collections/list/List#extend) methods. * The `mojo package` command now produces compilation-agnostic packages. Compilation options such as O0 or --debug-level are no longer needed or accepted. As a result, packages are now smaller and extremely portable. * Initializers for `@register_passable` values can (and should!) now be specified with `inout self` arguments just like memory-only types: ```mojo @register_passable struct YourPair: var a: Int var b: Int fn __init__(inout self): self.a = 42 self.b = 17 fn __copyinit__(inout self, existing: Self): self.a = existing.a self.b = existing.b ``` This form makes the language more consistent, more similar to Python, and easier to implement advanced features for. There is also no performance impact of using this new form: the compiler arranges to automatically return the value in a register without requiring you to worry about it. The older `-> Self` syntax is still supported in this release, but will be removed in a subsequent one, so please migrate your code. One thing to watch out for: a given struct should use one style or the other; mixing some of each won't work well. * The `inout self` initializer form is **required** for initializers of `@register_passable` types that may raise errors: ```mojo @register_passable struct RaisingCtor: fn __init__(inout self) raises: raise ``` * `async` functions that may raise errors have been temporarily disabled in this build. The implementation of Mojo async is undergoing a rework 🚧. * The standard library `slice` type has been renamed to [`Slice`](/mojo/std/builtin/builtin_slice/Slice), and a `slice` function has been introduced.
This makes Mojo closer to Python and makes the `Slice` type follow the naming conventions of other types like `Int`. * "Slice" syntax in subscripts is no longer hard coded to the builtin `slice` type: it now works with any type accepted by a container's `__getitem__()` method. For example: ```mojo @value struct UnusualSlice: var a: Int var b: Float64 var c: String struct YourContainer: fn __getitem__(self, slice: UnusualSlice) -> T: ... ``` Given this implementation, you can subscript into an instance of `YourContainer` like `yc[42:3.14:"🔥"]` and the three values are passed to the `UnusualSlice` constructor. * The `__refitem__()` accessor method may now return a `Reference` instead of having to return an MLIR internal reference type. * Added [`AnyPointer.move_into()`](/mojo/std/memory/unsafe_pointer/UnsafePointer#move_pointee_into) method, for moving a value from one pointer memory location to another. * Added built-in [`hex()`](/mojo/std/builtin/format_int/hex) function, which can be used to format any value whose type implements the [`Intable`](/mojo/std/builtin/int/Intable) trait as a hexadecimal string. * [`PythonObject`](/mojo/std/python/python_object/PythonObject) now implements `__is__` and `__isnot__` so that you can use expressions of the form `x is y` and `x is not y` with `PythonObject`. * [`PythonObject`](/mojo/std/python/python_object/PythonObject) now conforms to the `SizedRaising` trait. This means the built-in [`len()`](/mojo/std/builtin/len/len) function now works on `PythonObject`. * The `os` package now contains the [`stat()`](/mojo/std/os/fstat/stat) and [`lstat()`](/mojo/std/os/fstat/lstat) functions. * A new [`os.path`](/mojo/std/os/path/path) package now allows you to query properties on paths. * The `os` package now has a [`PathLike`](/mojo/std/os/pathlike/PathLike) trait. A struct conforms to the `PathLike` trait by implementing the `__fspath__()` function. 
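Because `hex()` accepts any value whose type implements `Intable`, a user-defined type only needs `__int__()` to be formatted as hexadecimal. A minimal sketch with a hypothetical `Answer` struct (an illustration, not from the original changelog):

```mojo
@value
struct Answer(Intable):
    var value: Int

    fn __int__(self) -> Int:
        return self.value

fn main():
    print(hex(255))         # formats the Int as hexadecimal
    print(hex(Answer(42)))  # works because Answer conforms to Intable
```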
* The [`pathlib.Path`](/mojo/std/pathlib/path/Path) now has functions to query properties of the path. * The [`listdir()`](/mojo/std/pathlib/path/Path#listdir) method now exists on [`pathlib.Path`](/mojo/std/pathlib/path) and also exists in the `os` module to work on `PathLike` structs. For example, the following sample lists all the directories in the `/tmp` directory: ```mojo from pathlib import Path fn walktree(top: Path, inout files: DynamicVector[Path]): try: var ls = top.listdir() for i in range(len(ls)): var child = top / ls[i] if child.is_dir(): walktree(child, files) elif child.is_file(): files.append(child) else: print("Skipping '" + str(child) + "'") except: return fn main(): var files = DynamicVector[Path]() walktree(Path("/tmp"), files) for i in range(len(files)): print(files[i]) ``` * The [`find()`](/mojo/std/builtin/string_literal/StringLiteral#find), [`rfind()`](/mojo/std/builtin/string_literal/StringLiteral#rfind), [`count()`](/mojo/std/collections/string/string_slice/StringSlice#count), and [`__contains__()`](/mojo/std/builtin/string_literal/StringLiteral#__contains__) methods now work on string literals. This means that you can write: ```mojo if "Mojo" in "Hello Mojo": ... ``` * Breakpoints can now be inserted programmatically within the code using the builtin [`breakpoint()`](/mojo/std/builtin/breakpoint/breakpoint) function. Note: on Graviton instances, the debugger might not be able to resume after hitting this kind of breakpoint. * Added a builtin [`Boolable`](/mojo/std/builtin/bool/Boolable) trait that describes a type that can be represented as a boolean value. To conform to the trait, a type must implement the `__bool__()` method. * Modules within packages can now use purely relative `from` imports: ```mojo from . 
import another_module ``` * Trivial types, like MLIR types and function types, can now be bound implicitly to traits that require copy constructors or move constructors, such as [`Movable`](/mojo/std/builtin/value/Movable), [`Copyable`](/mojo/std/builtin/value/Copyable), and [`CollectionElement`](/mojo/std/builtin/value/CollectionElement). * A new magic `__origin_of(expr)` call will yield the lifetime of a memory value. We hope and expect that this will eventually be replaced by `Reference(expr).lifetime` as the parameter system evolves, but this is important in the meantime for use in function signatures. * A new magic `__type_of(expr)` call will yield the type of a value. This allows one to refer to types of other variables. For example: ```mojo fn my_function(x: Int, y: __type_of(x)) -> Int: let z: __type_of(x) = y return z ``` ### 🦋 Changed * As another step towards [removing let declarations](https://github.com/modular/modular/blob/main/mojo/proposals/remove-let-decls.md) we have removed support for let declarations inside the compiler. To ease migration, we parse `let` declarations as a `var` declaration so your code won't break. We emit a warning about this, but please switch your code to using `var` explicitly, because this migration support will be removed in a subsequent update. ```mojo fn test(): # treated as a var, but please update your code! let x = 42 # warning: 'let' is being removed, please use 'var' instead x = 9 ``` * It is no longer possible to explicitly specify implicit argument parameters in [automatically parameterized functions](/mojo/manual/parameters/#automatic-parameterization). This ability was an oversight and this is now an error: ```mojo fn autoparameterized(x: SIMD): pass autoparameterized[DType.int32, 1](3) # error: too many parameters ``` * `vectorize_unroll` has been removed, and [`vectorize`](/mojo/std/algorithm/functional/vectorize) now has a parameter named `unroll_factor` with a default value of 1. 
Increasing `unroll_factor` may improve performance at the cost of binary size. See the [loop unrolling blog post](https://www.modular.com/blog/what-is-loop-unrolling-how-you-can-speed-up-mojo) for more details. * The `vectorize` signatures have changed, with the closure `func` moved to the first parameter: ```mojo vectorize[func, width, unroll_factor = 1](size) vectorize[func, width, size, unroll_factor = 1]() ``` The doc string has been updated with examples demonstrating the difference between the two signatures. * The `unroll` signatures have changed, with the closure `func` moved to the first parameter: ```mojo unroll[func, unroll_count]() ``` * The signatures of the `NDBuffer` and `Buffer` types have changed. Now, both take the type as the first parameter and no longer require the shape parameter. This allows you to use these types and have sensible defaults. For example: ```mojo NDBuffer[DType.float32, 3] ``` is equivalent to ```mojo NDBuffer[DType.float32, 3, DimList.create_unknown[3]()] ``` Users can still specify the static shape (if known) to the type: ```mojo NDBuffer[DType.float32, 3, DimList(128, 128, 3)] ``` * The error message for missing function arguments is improved: instead of describing the number of arguments (e.g. `callee expects at least 3 arguments, but 1 was specified`) the missing arguments are now described by name (e.g. `missing 2 required positional arguments: 'b', 'c'`). * The [`CollectionElement`](/mojo/std/builtin/value/CollectionElement) trait is now a built-in trait and has been removed from `collections.vector`. * The `DynamicVector(capacity: Int)` constructor has been changed to take `capacity` as a keyword-only argument to prevent implicit conversion from `Int`. * [`Variant.get[T]()`](/mojo/std/utils/variant/Variant#__getitem__) now returns a `Reference` to the value rather than a copy.
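The keyword-only `capacity` change above means the constructor argument must now be passed by name. A minimal sketch, assuming the `collections.vector` import path of this release:

```mojo
from collections.vector import DynamicVector

fn main():
    # `DynamicVector[Int](16)` no longer compiles; naming `capacity`
    # prevents an Int from implicitly converting to a vector.
    var v = DynamicVector[Int](capacity=16)
    v.append(1)
    print(len(v))  # 1
```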
* The [`String`](/mojo/std/collections/string/string/String) methods `tolower()` and `toupper()` have been renamed to `str.lower()` and `str.upper()`. * The `ref` and `mutref` identifiers are no longer reserved as Mojo keywords. We originally thought about using those as language sugar for references, but we believe that generic language features combined with the [`Reference`](/mojo/std/memory/pointer/Pointer) type will provide a good experience without dedicated sugar. ### 🛠️ Fixed * [#435](https://github.com/modular/modular/issues/435) Structs with Self type don't always work. * [#1540](https://github.com/modular/modular/issues/1540) Crash in register\_passable self referencing struct. * [#1664](https://github.com/modular/modular/issues/1664) - Improve error message when `StaticTuple` is constructed with a negative size for the number of elements. * [#1679](https://github.com/modular/modular/issues/1679) - crash on SIMD of zero elements. * Various crashes on invalid code: [#1230](https://github.com/modular/modular/issues/1230), [#1699](https://github.com/modular/modular/issues/1699), [#1708](https://github.com/modular/modular/issues/1708) * [#1223](https://github.com/modular/modular/issues/1223) - Crash when parametric function is passed as (runtime) argument. The parser now errors out instead. * [#1530](https://github.com/modular/modular/issues/1530) - Crash during diagnostic emission for parameter deduction failure. * [#1538](https://github.com/modular/modular/issues/1538) and [#1607](https://github.com/modular/modular/issues/1607) - Crash when returning type value instead of instance of expected type. This is a common mistake and the error now includes a hint to point users to the problem. * [#1613](https://github.com/modular/modular/issues/1613) - Wrong type name in error for incorrect `self` argument type in trait method declaration. * [#1670](https://github.com/modular/modular/issues/1670) - Crash on implicit conversion in a global variable declaration. 
* [#1741](https://github.com/modular/modular/issues/1741) - Mojo documentation generation doesn't show `inout`/`owned` on variadic arguments. * [#1621](https://github.com/modular/modular/issues/1621) - VS Code does not highlight `raises` and `capturing` in functional type expressions. * [#1617](https://github.com/modular/modular/issues/1617) - VS Code does not highlight `fn` in specific contexts. * [#1740](https://github.com/modular/modular/issues/1740) - LSP shows unrelated info when hovering over a struct. * [#1238](https://github.com/modular/modular/issues/1238) - File shadows Mojo package path. * [#1429](https://github.com/modular/modular/issues/1429) - Crash when using nested import statement. * [#1322](https://github.com/modular/modular/issues/1322) - Crash when missing types in variadic argument. * [#1314](https://github.com/modular/modular/issues/1314) - Typecheck error when binding alias to parametric function with default argument. * [#1248](https://github.com/modular/modular/issues/1248) - Crash when importing from file the same name as another file in the search path. * [#1354](https://github.com/modular/modular/issues/1354) - Crash when importing from local package. * [#1488](https://github.com/modular/modular/issues/1488) - Crash when setting generic element field. * [#1476](https://github.com/modular/modular/issues/1476) - Crash in interpreter when calling functions in parameter context. * [#1537](https://github.com/modular/modular/issues/1537) - Crash when copying parameter value. * [#1546](https://github.com/modular/modular/issues/1546) - Modify nested vector element crashes parser. * [#1558](https://github.com/modular/modular/issues/1558) - Invalid import causes parser to crash. * [#1562](https://github.com/modular/modular/issues/1562) - Crash when calling parametric type member function. * [#1577](https://github.com/modular/modular/issues/1577) - Crash when using unresolved package as a variable. 
* [#1579](https://github.com/modular/modular/issues/1579) - Member access into type instances causes a crash. * [#1602](https://github.com/modular/modular/issues/1602) - Interpreter failure when constructing strings at compile time. * [#1696](https://github.com/modular/modular/issues/1696) - Fixed an issue that caused syntax highlighting to occasionally fail. * [#1549](https://github.com/modular/modular/issues/1549) - Fixed an issue when the shift amount is out of range in `SIMD.shift_left` and `SIMD.shift_right`. ## v0.7.0 (2024-01-25) ### ⭐️ New * A new Mojo-native dictionary type, [`Dict`](/mojo/std/collections/dict) for storing key-value pairs. `Dict` stores values that conform to the [`CollectionElement`](/mojo/std/builtin/value/CollectionElement) trait. Keys need to conform to the new [`KeyElement`](/mojo/std/collections/dict/#keyelement) trait, which is not yet implemented by other standard library types. In the short term, you can create your own wrapper types to use as keys. For example, the following sample defines a `StringKey` type and uses it to create a dictionary that maps strings to `Int` values: ```mojo from collections.dict import Dict, KeyElement @value struct StringKey(KeyElement): var s: String fn __init__(inout self, owned s: String): self.s = s ^ fn __init__(inout self, s: StringLiteral): self.s = String(s) fn __hash__(self) -> Int: return hash(self.s) fn __eq__(self, other: Self) -> Bool: return self.s == other.s def main(): var d = Dict[StringKey, Int]() d["cats"] = 1 d["dogs"] = 2 print(len(d)) # prints 2 print(d["cats"]) # prints 1 print(d.pop("dogs")) # prints 2 print(len(d)) # prints 1 ``` We plan to add `KeyElement` conformance to standard library types in subsequent releases. * Users can opt-in to assertions used in the standard library code by specifying `-D MOJO_ENABLE_ASSERTIONS` when invoking `mojo` to compile your source file(s). 
If an assertion fails, the assertion message will be printed along with the stack trace before the program exits. By default, assertions are *not enabled* in the standard library right now for performance reasons. * The Mojo Language Server now implements the References request. IDEs use this to provide support for **Go to References** and **Find All References**. A current limitation is that references outside of the current document are not supported, which will be addressed in the future. * The [`sys.info`](/mojo/std/sys/info) module now includes `num_physical_cores()`, `num_logical_cores()`, and `num_performance_cores()` functions. * Homogeneous variadic arguments consisting of memory-only types, such as `String`, are more powerful and easier to use. These arguments are projected into a [`VariadicListMem`](/mojo/std/builtin/variadics/VariadicListMem). (Previous releases made it easier to use variadic lists of register-passable types, like `Int`.) Subscripting into a `VariadicListMem` now returns the element instead of an obscure internal type. In addition, we now support `inout` and `owned` variadic arguments: ```mojo fn make_worldly(inout *strs: String): # This "just works" as you'd expect! for i in range(len(strs)): strs[i] += " world" fn main(): var s1: String = "hello" var s2: String = "konnichiwa" var s3: String = "bonjour" make_worldly(s1, s2, s3) print(s1) # hello world print(s2) # konnichiwa world print(s3) # bonjour world ``` (Previous releases made it easier to use variadic lists, but subscripting into a `VariadicListMem` returned a low-level pointer, which required the user to call `__get_address_as_lvalue()` to access the element.) Note that subscripting the variadic list works nicely as above, but iterating over the variadic list directly with a `for` loop produces a `Reference` (described below) instead of the desired value, so an extra subscript is required; we intend to fix this in the future.
  ```mojo
  fn make_worldly(inout *strs: String):
      # Requires extra [] to dereference the reference for now.
      for i in strs:
          i[] += " world"
  ```

  Heterogeneous variadic arguments have not yet been moved to the new model, but will be in future updates.

  Note that for variadic arguments of register-passable types like `Int`, the variadic list contains values, not references, so the dereference operator (`[]`) is not required. This code continues to work as it did previously:

  ```mojo
  fn print_ints(*nums: Int):
      for num in nums:
          print(num)
      print(len(nums))
  ```

* Mojo now has a prototype version of a safe [`Reference`](/mojo/std/memory/pointer/Pointer) type. The compiler's lifetime tracking pass can reason about references to safely extend local variable lifetime, and check indirect access safety. The `Reference` type is brand new (and currently has no syntactic sugar), so it must be explicitly dereferenced with an empty subscript: `ref[]` provides access to the underlying value.

  ```mojo
  fn main():
      var a: String = "hello"
      var b: String = " references"

      var aref = Reference(a)
      aref[] += b
      print(a)  # prints "hello references"

      aref[] += b
      # ^last use of b, it is destroyed here.

      print(aref[])  # prints "hello references references"
      # ^last use of a, it is destroyed here.
  ```

  While the `Reference` type has the same in-memory representation as a C pointer or the Mojo `Pointer` type, it also tracks a symbolic "lifetime" value so the compiler can reason about the potentially accessed set of values. This lifetime is part of the static type of the reference, so it propagates through generic algorithms and abstractions built around it.

  The `Reference` type can form references to both mutable and immutable memory objects, e.g. those on the stack or borrowed/inout/owned function arguments.
  It is fully parametric over mutability, eliminating the [problems with code duplication due to mutability specifiers](https://duckki.github.io/2024/01/01/inferred-mutability.html), and provides the base for unified user-level types. For example, it could be used to implement an array slice object that handles both mutable and immutable array slices.

  While this is a major step forward for the lifetimes system in Mojo, it is still *very* early and awkward to use. Notably, there is no syntactic sugar for using references, such as automatic dereferencing, and several aspects of it still need to be baked further. It is getting exercised by variadic memory arguments, which is why they are starting to behave better now.

  Note: the safe `Reference` type and the unsafe pointer types are defined in the same module, currently named `memory.unsafe`. We expect to restructure this module in a future release.

* Mojo now allows types to implement `__refattr__()` and `__refitem__()` to enable attribute and subscript syntax with computed accessors that return references. For common situations where these address a value in memory, this provides a more convenient and significantly more performant alternative to implementing the traditional get/set pairs. Note: this may be changed in the future when references auto-dereference; at that point we may switch to just returning a reference from `__getattr__()`.

* Parametric closures can now capture register-passable typed values by copy using the `__copy_capture` decorator. For example, the following code will print `5`, not `2`:

  ```mojo
  fn foo(x: Int):
      var z = x

      @__copy_capture(z)
      @parameter
      fn formatter() -> Int:
          return z

      z = 2
      print(formatter())

  fn main():
      foo(5)
  ```

* `String` now implements `KeyElement` and may be used as a key in `Dict`.

* More robust support for structs with fields of self-referencing types.
  For example, the following code will work and print `0`:

  ```mojo
  struct Foo(CollectionElement):
      var vec: DynamicVector[Self]

      fn __init__(inout self: Self):
          self.vec = DynamicVector[Self]()

      fn __moveinit__(inout self: Self, owned existing: Self):
          self.vec = existing.vec^

      fn __copyinit__(inout self: Self, existing: Self):
          self.vec = existing.vec

  fn main():
      var foo = Foo()
      print(len(foo.vec))
  ```

### ❌ Removed

* The `__takeinit__` special constructor form has been removed from the language. This "non-destructive move" operation was previously wired into the `x^` transfer operator, but had unpredictable behavior that wasn't consistent. Now that Mojo has traits, it is better to model this as an explicit `.take()` operation on a type, which would transfer out the contents of the type without ending its lifetime. For example, for a type that holds a pointer, `take()` might return a new instance pointing to the same data, and null out its own internal pointer.

  This change makes it clear when a lifetime is ended versus when the contents of an LValue are explicitly taken.

* The current implementation of autotuning has been deprecated, as Mojo's autotuning implementation is undergoing a redesign. Tutorials around the current implementation have also been removed as they are being rewritten. Consequently, the `autotune()`, `autotune_fork()`, and `search()` functions have been removed from the standard library.

* The `_OldDynamicVector` type that worked only on register-passable element types has been removed. Please migrate uses to [`DynamicVector`](/mojo/std/collections/list/List), which works on both register-passable and memory types.

* The `UnsafeFixedVector` in `utils.vector` has been removed. We recommend using either [`DynamicVector`](/mojo/std/collections/list/List) or [`InlinedFixedVector`](/mojo/std/collections/inline_array/InlineArray) instead.

* The `@adaptive` decorator has been removed from the language.
  Any uses of the decorator in a non-search context can be replaced with `@parameter if`. For example:

  ```mojo
  @adaptive
  fn foo[a: Bool]():
      constrained[a]()
      body1()

  @adaptive
  fn foo[a: Bool]():
      constrained[not a]()
      body2()
  ```

  can be rewritten as:

  ```mojo
  fn foo[a: Bool]():
      @parameter
      if a:
          body1()
      else:
          body2()
  ```

  Consequently, the special `__adaptive_set` attribute has been removed as well.

* Result parameters have been removed from Mojo. Result parameter declarations in function parameter lists are no longer allowed, nor are forward alias declarations. This includes removing the `param_return` statement.

* The `@noncapturing` and `@closure` decorators have been removed due to refinements and improvements to the closure model. See below for more details!

### 🦋 Changed

* The Mojo closure model has been refined to be more straightforward and safe. Mojo has two closure types: parameter closures and runtime closures. Parameter closures can be used in higher-order functions and are the backbone of functions like `vectorize` and `parallelize`. They are always denoted by `@parameter` and have type `fn() capturing -> T` (where `T` is the return type).

  On the other hand, runtime closures are always dynamic values, capture values by invoking their copy constructor, and retain ownership of their capture state. You can define a runtime closure by writing a nested function that captures values:

  ```mojo
  fn outer(b: Bool, x: String) -> fn() escaping -> None:
      fn closure():
          print(x)  # 'x' is captured by calling String.__copyinit__

      fn bare_function():
          print("hello")  # nothing is captured

      if b:
          # closure can be safely returned because it owns its state
          return closure^

      # function pointers can be converted to runtime closures
      return bare_function
  ```

  The type of a runtime closure is of the form `fn() escaping -> T`. You can pass equivalent function pointers as runtime closures.
  Stay tuned for capture-list syntax for move capture and capture by reference, and a more unified closure model!

* The `@unroll(n)` decorator can now take a parameter expression for the unroll factor, i.e. `n` can be a parameter expression of integer type.

* The `cpython` module in the `python` package has been moved to be an internal module, i.e., `_cpython`.

* `AnyType` and `Destructable` have been unified into a single trait, `AnyType`. Every nominal type (i.e. all structs) now automatically conforms to `AnyType`.

* Previously, the `mojo package` command would output a Mojo package that included both partly-compiled Mojo code and fully-compiled machine code for a specific computer architecture: the architecture of the machine being used to invoke the `mojo package` command. Now, `mojo package` only includes partly-compiled Mojo code. It is fully compiled for the specific computer architecture being used only at the point that the package is first `import`-ed. As a result, Mojo packages are smaller and more portable.

* The `simd_width` and `dtype` parameters of `polynomial_evaluate` have been switched. Based on the request in [#1587](https://github.com/modular/modular/issues/1587), the `polynomial_evaluate` function has also been extended so that the `coefficients` parameter can take either a [`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple) or a [`VariadicList`](/mojo/std/builtin/variadics/VariadicList).

* As a tiny step towards removing `let` declarations, this release removes the warning: `'var' was never mutated, consider switching to a 'let'`.

### 🛠️ Fixed

* [#1595](https://github.com/modular/modular/issues/1595) - Improve error message when trying to materialize `IntLiteral` in runtime code.

* Raising an error from the initializer of a memory-only type now works correctly in the presence of complex control flow. Previously, Mojo could run the destructor on `self` before it was initialized when exiting with an error.
* [#1096](https://github.com/modular/modular/issues/1096) - Improve warning messages for dead code in conditionals like `or` expressions.
* [#1419](https://github.com/modular/modular/issues/1419) - Fix assertion failure with uninitialized lattice values.
* [#1402](https://github.com/modular/modular/issues/1402) - Fix movable trait not detected on recursive struct implemented with `AnyPointer`.
* [#1399](https://github.com/modular/modular/issues/1399) - Fix parser crash when a parameter type in a struct that implements a trait is misspelled.
* [#1152](https://github.com/modular/modular/issues/1152) - Allow mutable `self` argument when overloading operators using dunder methods.
* [#1493](https://github.com/modular/modular/issues/1493) - Fix crash in `DynamicVector` copy constructor in certain situations.
* [#1316](https://github.com/modular/modular/issues/1316) - The `benchmark.keep` function now properly handles vector types.
* [#1505](https://github.com/modular/modular/issues/1505) - The `simd.shuffle` operation now works on 64-element permutations.
* [#1355](https://github.com/modular/modular/issues/1355) - Fix `String.find()` returning wrong value when starting index is non-zero.
* [#1367](https://github.com/modular/modular/issues/1367) - Fix `String.replace()` returning incorrect results for multi-character search strings.
* [#1535](https://github.com/modular/modular/issues/1535) - Invalid error `field 'w.x.y' destroyed out of the middle of a value, preventing the overall value from being destroyed`.
* [#1475](https://github.com/modular/modular/issues/1475) - Assertion failure in nested loop.
* [#1591](https://github.com/modular/modular/issues/1591) - Assertion failure when using `AnyType` struct member.
* [#1503](https://github.com/modular/modular/issues/1503) - Rename the Mojo build of LLDB to `mojo-lldb`, to prevent name collisions with the system's LLDB.
* [#1542](https://github.com/modular/modular/issues/1542) - `@unroll` does not accept alias as unroll factor.
* [#1443](https://github.com/modular/modular/issues/1443) - Compiler crash on variadic list of traits.
* [#1604](https://github.com/modular/modular/issues/1604) - Variable of trivial type not destroyed by transferring ownership.
* [#1341](https://github.com/modular/modular/issues/1341) - Segmentation fault when passing closures around.
* [#217](https://github.com/modular/modular/issues/217) - Closure state is stack allocated.

## v0.6.1 (2023-12-18)

### ⭐️ New

* The Mojo REPL now provides limited support for the `%cd` magic command. This command automatically maintains an internal stack of directories you visit during the REPL session. Usage:

  * `%cd 'dir'`: change to directory `dir` and push it on the directory stack.
  * `%cd -`: pop the directory stack and change to the last visited directory.

* Structs decorated with `@value` now automatically conform to the [`Movable`](/mojo/std/builtin/value/Movable) and [`Copyable`](/mojo/std/builtin/value/Copyable) built-in traits.

* [`String`](/mojo/std/collections/string/string/String) now has new [`toupper()`](/mojo/std/collections/string/string/String#upper) and [`tolower()`](/mojo/std/collections/string/string/String#lower) methods analogous, respectively, to Python's `str.upper()` and `str.lower()`.

* Added a [`hash()`](/mojo/std/hashlib/hash/hash) built-in function and [`Hashable`](/mojo/std/hashlib/hash/Hashable) trait for types implementing the `__hash__()` method. Future releases will add `Hashable` support to standard library types. In the meantime, the `hash` module includes a version of the `hash()` function that works on arbitrary byte strings.
  To generate hashes for [`SIMD`](/mojo/std/builtin/simd/SIMD) types, you use the internal `_hash_simd()` function:

  ```mojo
  from builtin.hash import _hash_simd

  fn gen_simd_hash():
      let vector = SIMD[DType.int64, 4](1, 2, 3, 4)
      let hash = _hash_simd(vector)
  ```

* Several standard library types now conform to the [`CollectionElement`](/mojo/std/builtin/value/CollectionElement) trait. These types include [`Bool`](/mojo/std/builtin/bool/Bool), [`StringLiteral`](/mojo/std/builtin/string_literal/StringLiteral), [`DynamicVector`](/mojo/std/collections/list/List), `Tensor`, `TensorShape`, and `TensorSpec`.

### 🦋 Changed

* `utils.vector` has been moved to a new `collections` package to make space for new collections. This means that if you had previous code that did `from utils.vector import DynamicVector`, it now needs to be `from collections.vector import DynamicVector`.

* The special destructor method `__del__()` has been changed to enforce that it cannot raise an error. Raising destructors are not supported properly at the moment.

### 🛠️ Fixed

* [#1421](https://github.com/modular/modular/issues/1421) - Fixed a crash when using Tuples in the REPL.
* [#222](https://github.com/modular/modular/issues/222) - Generate an error for obviously self-recursive functions.
* [#1408](https://github.com/modular/modular/issues/1408) - Fix overload resolution when candidates can return generic types.
* [#1413](https://github.com/modular/modular/issues/1413) and [#1395](https://github.com/modular/modular/issues/1395) - Do not crash when re-declaring a builtin declaration.
* [#1307](https://github.com/modular/modular/issues/1307) - Fix compatibility of function signatures that only differ in default argument values.
* [#1380](https://github.com/modular/modular/issues/1380) - Fix printing of empty `String`.

## v0.6.0 (2023-12-04)

### 🔥 Legendary

* Traits have arrived!

  You can now define a *trait*, which consists of a required set of method prototypes.
  A struct can *conform to* the trait by implementing these methods. This lets you write generic functions that work on any structs that conform to a given trait. The following section gives a brief overview of traits; see the [Mojo Manual](/mojo/manual/traits) and this [traits blog post](https://modul.ar/traits-blog) for more details!

  Traits are declared with the `trait` keyword. The bodies of traits should contain method signatures declared with `...` as their bodies. Default method implementations are not supported yet.

  ```mojo
  trait SomeTrait:
      fn required_method(self, x: Int): ...
  ```

  The trait can be implemented on a struct by inheriting from it:

  ```mojo
  struct SomeStruct(SomeTrait):
      fn required_method(self, x: Int):
          print("hello traits", x)
  ```

  You can then write a generic function that accepts any type that conforms to the trait. You do this by creating a parameterized function with a trait-typed parameter:

  ```mojo
  fn fun_with_traits[T: SomeTrait](x: T):
      x.required_method(42)
  ```

  This can be invoked with instances of types that conform to the trait:

  ```mojo
  var thing = SomeStruct()

  # Infer the parameter `T`!
  fun_with_traits(thing)
  ```

  Traits can also inherit from other traits, which simply requires that implementers of the child trait also conform to all parent traits:

  ```mojo
  trait Parent:
      fn parent_func(self): ...

  trait Child(Parent):
      fn child_func(self): ...
  ```

  Then, both child and parent trait methods can be invoked on instances of the trait `Child`. As well, an instance of the child trait can be converted to an instance of the parent trait:

  ```mojo
  fn the_parents[T: Parent](x: T):
      x.parent_func()

  fn the_children[T: Child](x: T):
      x.child_func()
      x.parent_func()

      # Upcast `x` from instance of `Child` to `Parent`.
      the_parents(x)
  ```

  For more information, see the [Traits page](/mojo/manual/traits) in the Mojo Manual.

* A fundamental `Destructable` trait has been added to the language. This is a core trait that every trait automatically conforms to.
  This enables destruction of generic types and generic collections.

  **Note:** We're aware that this trait might be better spelled `Destructible`. We're planning on removing it in the future and moving its functionality to `AnyType`, so that any type that doesn't provide its own destructor will have a default, no-op destructor.

* We've added some traits to the standard library that you can implement on your own types:

  * [`Destructable`](/mojo/std/builtin/anytype/AnyType)
  * [`Copyable`](/mojo/std/builtin/value/Copyable)
  * [`Movable`](/mojo/std/builtin/value/Movable)
  * [`Stringable`](/mojo/std/builtin/str/Stringable)
  * [`Intable`](/mojo/std/builtin/int/Intable)
  * [`Sized`](/mojo/std/builtin/len/Sized)
  * [`CollectionElement`](/mojo/std/builtin/value/CollectionElement)

* We added built-in [`len()`](/mojo/std/builtin/len/len), `str()`, and `int()` functions, which work with types that implement the `Sized`, `Stringable`, and `Intable` traits, respectively.

* [`DynamicVector`](/mojo/std/collections/list/List) is now a proper generic collection that can use any type that implements the `Movable` and `Copyable` traits. This means you can now write, for example, `DynamicVector[String]`. Also, `DynamicVector` now invokes its element destructors upon destruction, so `_del_old` has been deleted.

* `print` now works on any types that implement `Stringable` by invoking their `__str__()` method:

  ```mojo
  @value
  struct BoxedInt(Stringable):
      var value: Int

      fn __str__(self) -> String:
          return self.value

  print(BoxedInt(11), "hello traits!", BoxedInt(42))
  ```

### ⭐️ New

* The [Mojo Manual](/mojo/manual/) is an all-new, complete Mojo user guide. It doesn't include *everything* about Mojo yet, but it includes a lot, and more than the original programming manual (now deprecated).

  Plus, the entire Mojo Manual and other Mojo docs are now [open-sourced on GitHub](https://github.com/modular/modular/tree/main/mojo/docs), and we'd love to accept contributions to help us improve them!
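As a small illustration of how the trait-bound generics above compose with the new built-in `len()` function, here is a minimal sketch using this release's syntax conventions; the `Pair` type and `describe()` function are hypothetical examples, not part of the standard library:

```mojo
# Hypothetical type conforming to the new Sized trait.
@value
struct Pair(Sized):
    var a: Int
    var b: Int

    fn __len__(self) -> Int:
        return 2

# Generic over any Sized type: the built-in len() dispatches to __len__().
fn describe[T: Sized](x: T):
    print(len(x))

fn main():
    describe(Pair(1, 2))  # prints 2
```

The same pattern applies to `Stringable`/`str()` and `Intable`/`int()`: constrain a parameter on the trait, and the corresponding built-in function works on any conforming argument.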
* Mojo now supports partial automatic parameterization: when a function is declared with an argument of a partially bound type, the unbound parameters of that type are implicitly added to the function's input parameters. For example:

  ```mojo
  @value
  struct Fudge[a: Int, b: Int, c: Int = 7]: ...

  # These function declarations are roughly equivalent:
  fn eat(f: Fudge[5]): ...               # implicitly parameterized
  fn eat[_b: Int](f: Fudge[5, _b]): ...  # explicitly parameterized
  ```

  In the first signature for `eat()`, the `b` parameter isn't bound, so it's *implicitly* added as an input parameter on the function. In the second signature for `eat()`, the author has explicitly defined an input parameter (`_b`), which is bound to the second parameter on the argument type (which happens to be `b`).

  Both functions can be called like this:

  ```mojo
  eat(Fudge[5, 8]())
  ```

  Mojo infers the value of the `b` parameter from the argument (in this case, 8). With the second signature, you can also pass the `_b` parameter value explicitly:

  ```mojo
  eat[3](Fudge[5, 3]())
  ```

  Moreover, Mojo now allows you to explicitly mark parameters as unbound, using `_` as a "placeholder for an unbound parameter." For example:

  ```mojo
  # These function declarations are roughly equivalent:
  fn eat(f: Fudge[5, _, c=_]): ...                    # implicitly parameterized
  fn eat(f: Fudge[c=_, a=5, b=_]): ...                # implicitly parameterized
  fn eat[_b: Int, _c: Int](f: Fudge[5, _b, _c]): ...  # explicitly parameterized
  ```

  The first two signatures explicitly unbind the `b` and `c` parameters. In the last signature, the `_b` and `_c` parameters are explicitly declared by the author and bound to the `b` and `c` parameters in the argument type.

  Any of these signatures can be called like this:

  ```mojo
  eat(Fudge[5, 8]())
  eat(Fudge[5, 8, 9]())
  ```

  Note that the default parameter values of struct parameters are bound, unless explicitly unbound by the user.
  For more information, see the [Mojo Manual](/mojo/manual/parameters/#fully-bound-partially-bound-and-unbound-types).

* Parametric types can now be partially bound in certain contexts. For example, a new `Scalar` type alias has been added, defined as:

  ```mojo
  alias Scalar = SIMD[size=1]
  ```

  This creates a parametric type alias `Scalar` with a single parameter of type `DType`. Types can also be partially or fully bound in other contexts. For instance, `alias` declarations of type values inside functions now work properly:

  ```mojo
  fn type_aliases():
      alias T = SIMD
      print(T[DType.float32, 1]())
      alias Partial = T[type=DType.int32]
      print(Partial[2]())
  ```

* The `__mlir_op` feature now supports operations that return multiple results. To use them, you write the `_type` field as a `Tuple` of types. For example:

  ```mojo
  # The `ret` variable has type `Tuple[Int, Int]`.
  let ret = __mlir_op.`multi_result_op`[_type=(Int, Int)]()
  ```

* Mojo now has the ability to read raw bytes from a file using the [`read_bytes()`](/mojo/std/builtin/file/FileHandle#read_bytes) method. For example:

  ```mojo
  with open("file.binary", "r") as f:
      data = f.read_bytes()
  ```

* A size argument was added to the [`read()`](/mojo/std/builtin/file/FileHandle#read) and [`read_bytes()`](/mojo/std/builtin/file/FileHandle#read_bytes) methods on the builtin `file.FileHandle`. The size argument defaults to -1 and maintains the previous "read to EOF" behavior when size is negative.

  ```mojo
  with open("file.binary", "r") as f:
      data1 = f.read_bytes(1024)
      data2 = f.read_bytes(256)
  ```

* [`Path`](/mojo/std/pathlib/path/Path) now has `read_bytes()` and `read_text()` methods to read file contents from a path:

  ```mojo
  let text_path = Path("file.txt")
  let text = text_path.read_text()

  let binary_path = Path("file.binary")
  let data = binary_path.read_bytes()
  ```

* `Tensor` has new `save()` and `load()` methods to save and load to file. These methods preserve shape and datatype information.
  For example:

  ```mojo
  let tensor = Tensor[DType.float32]()
  tensor.save(path)

  let tensor_from_file = Tensor[DType.float32].load(path)
  ```

* Subscripting added to [`DTypePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) and [`Pointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer):

  ```mojo
  let p = DTypePointer[DType.float16].alloc(4)
  for i in range(4):
      p[i] = i
      print(p[i])
  ```

* `file.FileHandle` now has a `seek()` method.

* [`String`](/mojo/std/collections/string/string/String) now has an [`rfind()`](/mojo/std/collections/string/string/String#rfind) method analogous to Python's `str.rfind()`.

* `String` now has a [`split()`](/mojo/std/collections/string/string/String#split) method analogous to Python's `str.split()`.

* [`Path`](/mojo/std/pathlib/path/Path) now has a [`suffix()`](/mojo/std/pathlib/path/Path#suffix) method analogous to Python's `pathlib.Path.suffix`.

* The Mojo REPL now supports indented expressions, making it a bit easier to execute expressions copied from an indented block (such as a doc string).

* The Mojo Language Server now implements the Document Symbols request. IDEs use this to provide support for **Outline View** and **Go to Symbol**. This addresses [Issue #960](https://github.com/modular/modular/issues/960).

* The Mojo Language Server now shows documentation when code-completing modules or packages in `import` statements.

* The Mojo Language Server now supports processing code examples, defined as markdown Mojo code blocks, inside of doc strings. This enables IDE features while writing examples in API documentation.

* The Mojo Language Server now provides semantic token information, providing better highlighting for symbols whose semantics are not statically analyzable.

* The Mojo Language Server now classifies doc strings as folding ranges, making them easier to collapse, reducing vertical space while editing.
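The new `String.rfind()` and `String.split()` methods above are described as analogous to their Python namesakes; a minimal usage sketch, assuming Python-style semantics and this release's syntax:

```mojo
fn main():
    let s: String = "a,b,c"

    # rfind() searches from the end, like Python's str.rfind(),
    # returning the index of the last occurrence.
    print(s.rfind(","))

    # split() separates the string on the given delimiter.
    let parts = s.split(",")
    print(len(parts))
```

As with the Python versions, `rfind()` is expected to return -1 when the substring is absent.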
* Command-line options for the `mojo` driver that take arguments can now be written in either of two ways: both `--foo FOO` and `--foo=FOO`. Previously, only the former was valid.

### 🦋 Changed

* Variadic list types [`VariadicList`](/mojo/std/builtin/variadics/VariadicList) and [`VariadicListMem`](/mojo/std/builtin/variadics/VariadicListMem) are now iterable. Variadic arguments are automatically projected into one of these types inside the function body, so var args can be iterated:

  ```mojo
  fn print_ints(*nums: Int):
      for num in nums:
          print(num)
      print(len(nums))
  ```

* The assert functions in the [`testing`](/mojo/std/testing/testing) package now raise an `Error` when the assertion fails instead of returning a `Bool` for whether the assertion succeeded or not.

* Parameters of [`AnyType`](/mojo/std/builtin/type_aliases) type are no longer (implicitly) assumed to be register-passable. A new `AnyRegType` type is used to represent generic types that are register-passable.

* Changing the units in a [`benchmark`](/mojo/std/benchmark/benchmark) report is now an argument instead of a parameter:

  ```mojo
  let report = benchmark.run[timer]()
  report.print(Unit.ms)
  ```

* Default values on `inout` arguments are no longer permitted, i.e. the following will now raise an error:

  ```mojo
  fn inout_default(inout x: Int = 2): ...
  ```

* The `to_string()` function has been removed from [`PythonObject`](/mojo/std/python/python_object/PythonObject) in favor of the new `__str__()` function. This composes better with traits, so it can be used with the generic `str()` function.

### 🛠️ Fixed

* [#734](https://github.com/modular/modular/issues/734) - Consumption of struct works only for types with a `__del__` method.

* [#910](https://github.com/modular/modular/issues/910) - Parser crash when using memory-only generic type as return of function that `raise`s.
* [#1060](https://github.com/modular/modular/issues/1060) - Mojo happily parses code that has messed-up indentation.
* [#1159](https://github.com/modular/modular/issues/1159) - The language server doesn't warn about bad return type.
* [#1166](https://github.com/modular/modular/issues/1166) - Warning: unreachable code after return statement with context manager.
* [#1098](https://github.com/modular/modular/issues/1098) - The language server doesn't highlight properties of PythonObjects correctly.
* [#1153](https://github.com/modular/modular/issues/1153) - The language server crashes when parsing an invalid multi-nested module import.
* [#1236](https://github.com/modular/modular/issues/1236) - The language server doesn't show autocomplete in if statements.
* [#1246](https://github.com/modular/modular/issues/1246) - Warning diagnostics are transient in the presence of caching.

### Known Issue

* There is an issue affecting Jupyter notebooks that use autotuning and traits. This issue only manifests on macOS, and the same code runs without issue outside of the notebooks. This issue affects the *Matrix multiplication in Mojo* notebook.

## v0.5.0 (2023-11-2)

### ⭐️ New

* The [`SIMD`](/mojo/std/builtin/simd/SIMD) type now defaults to the architectural SIMD width of the type. This means you can write `SIMD[DType.float32]`, which is equivalent to `SIMD[DType.float32, simdwidthof[DType.float32]()]`.

* The [`SIMD`](/mojo/std/builtin/simd/SIMD) type now contains a `join()` function that allows you to concatenate two `SIMD` values together and produce a new `SIMD` value.

* Mojo now supports compile-time *keyword parameters*, in addition to existing support for [keyword arguments](/mojo/manual/parameters/#optional-parameters-and-keyword-parameters).
  For example:

  ```mojo
  fn foo[a: Int, b: Int = 42]():
      print(a, "+", b)

  foo[a=5]()        # prints '5 + 42'
  foo[a=7, b=13]()  # prints '7 + 13'
  foo[b=20, a=6]()  # prints '6 + 20'
  ```

  Keyword parameters are also supported in structs:

  ```mojo
  struct KwParamStruct[a: Int, msg: String = "🔥mojo🔥"]:
      fn __init__(inout self):
          print(msg, a)

  fn use_kw_params():
      KwParamStruct[a=42]()               # prints '🔥mojo🔥 42'
      KwParamStruct[5, msg="hello"]()     # prints 'hello 5'
      KwParamStruct[msg="hello", a=42]()  # prints 'hello 42'
  ```

  For more detail, see the [Mojo Manual](/mojo/manual/parameters/#optional-parameters-and-keyword-parameters).

  For the time being, the following notable limitations apply:

  * Keyword-only parameters are **not supported** yet:

    ```mojo
    fn baz[*args: Int, b: Int](): pass  # fails
    fn baz[a: Int, *, b: Int](): pass   # fails
    ```

    (The analogous keyword-only arguments in Python are described in [PEP 3102](https://peps.python.org/pep-3102/).)

  * Variadic keyword parameters are **not supported** yet:

    ```mojo
    fn baz[a: Int, **kwargs: Int](): pass  # fails
    ```

* Mojo now supports "automatic" parameterization of functions. What this means is that if a function argument type is parametric but has no bound parameters, they are automatically added as input parameters on the function. This works with existing features to allow you to write parametric functions with less boilerplate.

  ```mojo
  @value
  struct Thing[x: Int, y: Int]:
      pass

  fn foo(v: Thing):
      print(v.x)
      print(v.y)

  fn main():
      let v = Thing[2, 3]()
      foo(v)
  ```

  However, partial autoparameterization is **not supported** yet:

  ```mojo
  fn foo(v: Thing[y=7]):  # Partially bound type not allowed yet.
      ...
  ```

* Keyword argument passing is supported when invoking `__getitem__` using the bracket syntax:

  ```mojo
  @value
  struct MyStruct:
      fn __getitem__(self, x: Int, y: Int, z: Int) -> Int:
          return x * y + z

  MyStruct()[z=7, x=3, y=5]  # returns 22
  ```

  However, keyword argument passing to `__setitem__` using the bracket syntax is **not supported** yet:

  ```mojo
  @value
  struct OtherStruct:
      fn __setitem__(self, x: Int, y: Int): pass

  OtherStruct()[x=1] = 4  # fails
  ```

* Function argument input parameters can now be referenced within the signature of the function:

  ```mojo
  fn foo(x: SIMD, y: SIMD[x.type, x.size]): pass
  ```

* The [`benchmark`](/mojo/std/benchmark/benchmark) module has been simplified and improved, so you can now run:

  ```mojo
  import benchmark
  from time import sleep

  fn sleeper():
      sleep(.01)

  fn main():
      let report = benchmark.run[sleeper]()
      print(report.mean())
  ```

  It no longer requires a capturing `fn`, so it can benchmark functions outside the same scope. You can print a report with:

  ```mojo
  report.print()
  ```

  ```plaintext
  ---------------------
  Benchmark Report (s)
  ---------------------
  Mean: 0.012314264957264957
  Total: 1.440769
  Iters: 117
  Warmup Mean: 0.0119335
  Warmup Total: 0.023866999999999999
  Warmup Iters: 2
  Fastest Mean: 0.012227958333333334
  Slowest Mean: 0.012442699999999999
  ```

  Units for all functions default to seconds, but can be changed with:

  ```mojo
  from benchmark import Unit
  report.print[Unit.ms]()
  ```

* Mojo now supports struct parameter deduction (a.k.a. class template argument deduction, or CTAD) for partially bound types. Struct parameter deduction is also possible from static methods.
  For example:

  ```mojo
  @value
  struct Thing[v: Int]:
      pass

  struct CtadStructWithDefault[a: Int, b: Int, c: Int = 8]:
      fn __init__(inout self, x: Thing[a]):
          print("hello", a, b, c)

      @staticmethod
      fn foo(x: Thing[a]):
          print("🔥", a, b, c)

  fn main():
      _ = CtadStructWithDefault[b=7](Thing[6]())  # prints 'hello 6 7 8'
      CtadStructWithDefault[b=7].foo(Thing[6]())  # prints '🔥 6 7 8'
  ```

* `Tensor` has new `fromfile()` and `tofile()` methods to save and load as bytes from a file.

* The built-in `print()` function now works on the `Tensor` type.

* `TensorShape` and `TensorSpec` now have constructors that take [`DynamicVector[Int]`](/mojo/std/collections/list/List) and [`IndexList`](/mojo/std/utils/index_/IndexList) to initialize shapes.

* The [`String`](/mojo/std/collections/string/string/String) type now has `count()` and `find()` methods to enable counting the number of occurrences of a substring or finding its offset index in a string.

* The `String` type now has a `replace()` method which allows you to replace a substring with another string.

### 🦋 Changed

* [`VariadicList`](/mojo/std/builtin/variadics/VariadicList) and [`VariadicListMem`](/mojo/std/builtin/variadics/VariadicListMem) have moved under builtins, and no longer need to be imported.

* Variadic arguments are now automatically projected into a `VariadicList` or `VariadicListMem` inside the function body. This allows for more flexibility in using var args. For example:

  ```mojo
  fn print_ints(*nums: Int):
      let len = len(nums)
      for i in range(len):
          print(nums[i])
      print(len)
  ```

* The parameters for [`InlinedFixedVector`](/mojo/std/collections/inline_array/InlineArray) have been switched. The parameters are now `[type, size]` instead of `[size, type]`. `InlinedFixedVector` now has a default size, which means you can write just `InlinedFixedVector[Float32]` and the default size is used.

* The `write_file()` method in `Buffer` and `NDBuffer` is renamed to `tofile()` to match the Python naming.
* Mojo will now utilize all available cores across all NUMA sockets on the host machine by default. The prior default behavior was to use all the cores on the first socket. ### ❌ Removed * The `math.numerics` module is now private, because its types (`FPUtils` and `FlushDenormals`) should not be used externally. ### 🛠️ Fixed * [#532](https://github.com/modular/modular/issues/532) - Compiler optimizing while True loop away * [#760](https://github.com/modular/modular/issues/760) - Compilation error: 'hlcf.for.yield' op specifies 0 branch inputs but target expected 1 along control-flow edge from here * [#849](https://github.com/modular/modular/issues/849) - The `Tensor` type is now initialized with zeros at construction time. * [#912](https://github.com/modular/modular/issues/912) - Invalid load for `__get_address_as_lvalue`. * [#916](https://github.com/modular/modular/issues/916) - Parser crash when specifying default values for `inout` arguments. * [#943](https://github.com/modular/modular/issues/943) - Mojo hangs if you use continue in the nested loop * [#957](https://github.com/modular/modular/issues/957) - Parser crash when a function call with variadic arguments of a memory-only type is evaluated at compile time. * [#990](https://github.com/modular/modular/issues/990) - Fixes rounding issue with floor division with negative numerator. * [#1018](https://github.com/modular/modular/issues/1018) - In some cases the sort function was returning invalid results. This release fixes some of these corner cases. * [#1010](https://github.com/modular/modular/issues/1010) - Initializing tensor in alias declaration results in crash. * [#1110](https://github.com/modular/modular/issues/1110) - The `time.now()` function now returns nanoseconds across all operating systems. * [#1115](https://github.com/modular/modular/issues/1115) - cannot load non-register passable type into SSA register. ## v0.4.0 for Mac (2023-10-19) ### 🔥 Legendary * Mojo for Mac! 
The Mojo SDK now works on macOS (Apple silicon). This is the same version previously released for Linux. Get the latest version of the SDK for your Mac system: [Download Now!](https://developer.modular.com/download) ## v0.4.0 (2023-10-05) ### ⭐️ New * Mojo now supports default parameter values. For example: ```mojo fn foo[a: Int = 3, msg: StringLiteral = "woof"](): print(msg, a) fn main(): foo() # prints 'woof 3' foo[5]() # prints 'woof 5' foo[7, "meow"]() # prints 'meow 7' ``` Inferred parameter values take precedence over defaults: ```mojo @value struct Bar[v: Int]: pass fn foo[a: Int = 42, msg: StringLiteral = "quack"](bar: Bar[a]): print(msg, a) fn main(): foo(Bar[9]()) # prints 'quack 9' ``` Structs also support default parameters: ```mojo @value struct DefaultParams[msg: StringLiteral = "woof"]: alias message = msg fn main(): print(DefaultParams[]().message) # prints 'woof' print(DefaultParams["meow"]().message) # prints 'meow' ``` * The new [`file`](/mojo/std/builtin/file) module adds basic file I/O support. You can now write: ```mojo var f = open("my_file.txt", "r") print(f.read()) f.close() ``` or ```mojo with open("my_file.txt", "r") as f: print(f.read()) ``` * Mojo now allows context managers to support an `__enter__` method without implementing support for an `__exit__` method, enabling idioms like this: ```mojo # This context manager consumes itself and returns it as the value. fn __enter__(owned self) -> Self: return self^ ``` Here Mojo *cannot* invoke a noop `__exit__` method because the context manager is consumed by the `__enter__` method. This can be used for types (like file descriptors) that are traditionally used with `with` statements, even though Mojo's guaranteed early destruction doesn't require that. * A very basic version of `pathlib` has been implemented in Mojo. The module will be improved to achieve functional parity with Python in the next few releases. * The `memory.unsafe` module now contains a `bitcast` function. 
This is a low-level operation that enables bitcasting between pointers and scalars. * The input parameters of a parametric type can now be directly accessed as attribute references on the type or variables of the type. For example: ```mojo @value struct Thing[param: Int]: pass fn main(): print(Thing[2].param) # prints '2' let x = Thing[9]() print(x.param) # prints '9' ``` Input parameters on values can even be accessed in parameter contexts. For example: ```mojo fn foo[value: Int](): print(value) let y = Thing[12]() alias constant = y.param + 4 foo[constant]() # prints '16' ``` * The Mojo REPL now supports code completion. Press Tab while typing to query potential completion results. * Error messages from Python are now exposed in Mojo. For example, the following should print `No module named 'my_uninstalled_module'`: ```mojo fn main(): try: let my_module = Python.import_module("my_uninstalled_module") except e: print(e) ``` * Error messages can now store dynamic messages. For example, the following should print "Failed on: Hello": ```mojo fn foo(x: String) raises: raise Error("Failed on: " + x) fn main(): try: foo("Hello") except e: print(e) ``` ### 🦋 Changed * We have improved and simplified the `parallelize` function. The function now elides some overhead by caching the Mojo parallel runtime. * The Mojo REPL and Jupyter environments no longer implicitly expose `Python`, `PythonObject`, or `Pointer`. These symbols must now be imported explicitly, for example: ```mojo from python import Python from python.object import PythonObject from memory.unsafe import Pointer ``` * The syntax for specifying attributes with the `__mlir_op` prefix has changed to mimic Python's keyword argument passing syntax. That is, `=` should be used instead of `:`, e.g.: ```mojo # Old syntax, now fails. __mlir_op.`index.bool.constant`[value : __mlir_attr.false]() # New syntax. __mlir_op.`index.bool.constant`[value=__mlir_attr.false]() ``` * You can now print the `Error` object directly.
The `message()` method has been removed. ### 🛠️ Fixed * [#794](https://github.com/modular/modular/issues/794) - Parser crash when using the `in` operator. * [#936](https://github.com/modular/modular/issues/936) - The `Int` constructor now accepts other `Int` instances. * [#921](https://github.com/modular/modular/issues/921) - Better error message when running `mojo` on a module with no `main` function. * [#556](https://github.com/modular/modular/issues/556) - UInt64s are now printed correctly. * [#804](https://github.com/modular/modular/issues/804) - Emit an error instead of crashing when passing variadic arguments of unsupported types. * [#833](https://github.com/modular/modular/issues/833) - Parser crash when assigning module value. * [#752](https://github.com/modular/modular/issues/752) - Parser crash when calling async def. * [#711](https://github.com/modular/modular/issues/711) - The overload resolution logic now correctly prioritizes instance methods over static methods (if candidates are an equally good match otherwise), and no longer crashes if a static method has a `Self` type as its first argument. * [#859](https://github.com/modular/modular/issues/859) - Fix confusing error and documentation of the `rebind` builtin. * [#753](https://github.com/modular/modular/issues/753) - Direct use of the LLVM dialect produces strange errors in the compiler. * [#926](https://github.com/modular/modular/issues/926) - Fixes an issue that occurred when a function with a return type of `StringRef` raised an error: the function incorrectly returned the string value of that error. * [#536](https://github.com/modular/modular/issues/536) - Report more information on Python exceptions. ## v0.3.1 (2023-09-28) Our first-ever patch release of the Mojo SDK is here! Release v0.3.1 primarily includes installation-related fixes. If you’ve had trouble installing the previous versions of the SDK, this release may be for you.
### 🛠️ Fixed * [#538](https://github.com/modular/modular/issues/538) - Installation hangs during the testing phase. This issue occurs on machines with a low number of CPU cores, such as free AWS EC2 instances and GitHub Codespaces. * [#590](https://github.com/modular/modular/issues/590) - Installation fails with a “failed to run python” message. * [#672](https://github.com/modular/modular/issues/672) - Language server hangs on code completion. Related to #538, this occurs on machines with a low number of CPU cores. * [#913](https://github.com/modular/modular/issues/913) - In the REPL and Jupyter notebooks, inline comments were being parsed incorrectly. ## v0.3.0 (2023-09-21) There's more Mojo to love in this, the second release of the Mojo SDK! This release includes new features, an API change, and bug fixes. There's also an updated version of the [Mojo extension for VS Code](https://marketplace.visualstudio.com/items?itemName=modular-mojotools.vscode-mojo). ### ⭐️ New * Mojo now has partial support for passing keyword arguments to functions and methods. For example the following should work: ```mojo fn foo(a: Int, b: Int = 3) -> Int: return a * b fn main(): print(foo(6, b=7)) # prints '42' print(foo(a=6, b=7)) # prints '42' print(foo(b=7, a=6)) # prints '42' ``` Parameters can also be inferred from keyword arguments, for example: ```mojo fn bar[A: AnyType, B: AnyType](a: A, b: B): print("Hello 🔥") fn bar[B: AnyType](a: StringLiteral, b: B): print(a) fn main(): bar(1, 2) # prints `Hello 🔥` bar(b=2, a="Yay!") # prints `Yay!` ``` For the time being, the following notable limitations apply: * Keyword-only arguments are not supported: ```mojo fn baz(*args: Int, b: Int): pass # fails fn baz(a: Int, *, b: Int): pass # fails ``` (Keyword-only arguments are described in [PEP 3102](https://peps.python.org/pep-3102/).) 
* Variadic keyword arguments are not supported: ```mojo fn baz(a: Int, **kwargs: Int): pass # fails ``` * Mojo now supports the `@nonmaterializable` decorator. The purpose is to mark data types that should only exist in the parameter domain. To use it, a struct is decorated with `@nonmaterializable(TargetType)`. Any time the nonmaterializable type is converted from the parameter domain, it is automatically converted to `TargetType`. A nonmaterializable struct should have all of its methods annotated as `@always_inline`, and must be computable in the parameter domain. In the following example, the `NmStruct` type can be added in the parameter domain, but is converted to `HasBool` when materialized. ```mojo @value @register_passable("trivial") struct HasBool: var x: Bool fn __init__(x: Bool) -> Self: return Self {x: x} @always_inline("nodebug") fn __init__(nms: NmStruct) -> Self: return Self {x: True if (nms.x == 77) else False} @value @nonmaterializable(HasBool) @register_passable("trivial") struct NmStruct: var x: Int @always_inline("nodebug") fn __add__(self: Self, rhs: Self) -> Self: return NmStruct(self.x + rhs.x) alias stillNmStruct = NmStruct(1) + NmStruct(2) # When materializing to a run-time variable, it is automatically converted, # even without a type annotation. let convertedToHasBool = stillNmStruct ``` * Mojo integer literals now produce the `IntLiteral` infinite-precision integer type when used in the parameter domain. `IntLiteral` is materialized to the `Int` type for runtime computation, but intermediate computations at compile time, using supported operators, can now exceed the bit width of the `Int` type. * The Mojo Language Server now supports top-level code completions, enabling completion when typing a reference to a variable, type, etc. This resolves [#679](https://github.com/modular/modular/issues/679). * The Mojo REPL now colorizes the resultant variables to help distinguish input expressions from the output variables.
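For illustration, a minimal sketch of the `IntLiteral` behavior described above, assuming shift operators are among the supported compile-time operators:

```mojo
# The intermediate value 1 << 80 exceeds Int's 64-bit width, but the
# computation happens at compile time with infinite precision before
# the result materializes as an Int at run time.
alias narrowed = (1 << 80) >> 76  # mathematically 2**4 = 16

fn main():
    print(narrowed)  # materialized as an Int
```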
### 🦋 Changed * Mojo allows types to implement two forms of move constructors: one that is invoked when the lifetime of one value ends, and one that is invoked if the compiler cannot prove that the lifetime has ended. These were previously both named `__moveinit__`, with the following two signatures: ```mojo fn __moveinit__(inout self, owned existing: Self): ... fn __moveinit__(inout self, inout existing: Self): ... ``` We've given the second form its own name to make it clear that these are two separate operations: it has been renamed to `__takeinit__`: ```mojo fn __moveinit__(inout self, owned existing: Self): ... fn __takeinit__(inout self, inout existing: Self): ... ``` The name is intended to connote that the operation takes the conceptual value from the source (without destroying it), unlike the first one, which "moves" a value from one location to another. For more information, see the Mojo Manual section on [move constructors](/mojo/manual/lifecycle/life#move-constructor). * The `Error` type in Mojo has changed: instead of extracting the error message using `error.value`, you now extract it using `error.message()`. ### 🛠️ Fixed * [#503](https://github.com/modular/modular/issues/503) - Improve error message for failure lowering `kgen.param.constant`. * [#554](https://github.com/modular/modular/issues/554) - Alias of static tuple fails to expand. * [#500](https://github.com/modular/modular/issues/500) - Call expansion failed due to verifier error. * [#422](https://github.com/modular/modular/issues/422) - Incorrect comment detection in multiline strings. * [#729](https://github.com/modular/modular/issues/740) - Improve messaging on how to exit the REPL. * [#756](https://github.com/modular/modular/issues/756) - Fix initialization errors of the VS Code extension. * [#575](https://github.com/modular/modular/issues/575) - Build LLDB/REPL with libedit for a nicer editing experience in the terminal.
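The `Error` API change in this release might be used like this (a minimal sketch; as of v0.3.0 the message is retrieved with `error.message()` instead of `error.value`):

```mojo
fn might_fail() raises:
    raise Error("something went wrong")

fn main():
    try:
        might_fail()
    except e:
        # Previously: e.value
        print(e.message())
```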
## v0.2.1 (2023-09-07) The first versioned release of Mojo! 🔥 All earlier releases were considered version 0.1. ### 🔥 Legendary * First release of the Mojo SDK! You can now develop with Mojo locally. The Mojo SDK is currently available for Ubuntu Linux systems, and support for Windows and macOS is coming soon. You can still develop from a Windows or Mac computer using a container or remote Linux system. The Mojo SDK includes the Mojo standard library and the [Mojo command-line interface](/mojo/cli/) (CLI), which allows you to run, compile, and package Mojo code. It also provides a REPL programming environment. [Get the Mojo SDK!](https://developer.modular.com/download) * First release of the [Mojo extension for VS Code](https://marketplace.visualstudio.com/items?itemName=modular-mojotools.vscode-mojo). This provides essential Mojo language features in Visual Studio Code, such as code completion, code quick fixes, docs tooltips, and more. Even when developing on a remote system, using VS Code with this extension provides a native-like IDE experience. ### ⭐️ New * A new `clobber_memory` function has been added to the [`benchmark`](/mojo/std/benchmark/benchmark) module. The `clobber_memory` function tells the system to flush all memory operations at the specified program point. This allows you to benchmark operations without the compiler reordering memory operations. * A new `keep` function has been added to the [`benchmark`](/mojo/std/benchmark/benchmark) module. The `keep` function tries to tell the compiler not to optimize a variable away if it is unused. This allows you to avoid the compiler's dead-code elimination mechanism, with a low-footprint side effect. * New `shift_right` and `shift_left` functions have been added to the [`simd`](/mojo/std/builtin/simd) module. They shift the elements in a SIMD vector right/left, filling elements with zeros as needed.
* A new `cumsum` function has been added to the [`reduction`](/mojo/std/algorithm/reduction) module that computes the cumulative sum (also known as scan) of input elements. * Mojo Jupyter kernel now supports code completion. ### 🦋 Changed * Extends `rotate_bits_left`, `rotate_left`, `rotate_bits_right`, and `rotate_right` to operate on Int values. The ordering of parameters has also been changed to enable type inference. Now it's possible to write `rotate_right[shift_val](simd_val)` and have the `dtype` and `simd_width` inferred from the argument. This addresses [Issue #528](https://github.com/modular/modular/issues/528). ### 🛠️ Fixed * Fixed a bug causing the parser to crash when the `with` statement was written without a colon. This addresses [Issue #529](https://github.com/modular/modular/issues/529). * Incorrect imports no longer crash when there are other errors at the top level of a module. This fixes [Issue \#531](https://github.com/modular/modular/issues/531). ## August 2023 ### 2023-08-24 * Fixed issue where the `with expr as x` statement within `fn` behaved as if it were in a `def`, binding `x` with function scope instead of using lexical scope. #### ⭐️ New * Major refactoring of the standard library to enable packaging and better import ergonomics: * The packages are built as binaries to improve startup speed. * Package and module names are now lowercase to align with the Python style. * Modules have been moved to better reflect the purpose of the underlying functions (e.g. `Pointer` is now within the `unsafe` module in the `memory` package). * The following modules are now included as built-ins: `SIMD`, `DType`, `IO`, `Object`, and `String`. This means it's no longer necessary to explicitly import these modules. Instead, these modules will be implicitly imported for the user. Private methods within the module are still accessible using the `builtin.module_name._private_method` import syntax. 
* New `math` package has been added to contain the `bit`, `math`, `numerics`, and `polynomial` modules. The contents of the `math.math` module are re-exported into the `math` package. * Mojo now supports using memory-only types in parameter expressions and as function or type parameters: ```mojo @value struct IntPair: var first: Int var second: Int fn add_them[value: IntPair]() -> Int: return value.first + value.second fn main(): print(add_them[IntPair(1, 2)]()) # prints '3' ``` * In addition, Mojo supports evaluating code that uses heap-allocated memory at compile-time and materializing compile-time values with heap-allocated memory into dynamic values: ```mojo fn fillVector(lowerBound: Int, upperBound: Int, step: Int) -> DynamicVector[Int]: var result = DynamicVector[Int]() for i in range(lowerBound, upperBound, step): result.push_back(i) return result fn main(): alias values = fillVector(5, 23, 7) for i in range(0, values.__len__()): print(values[i]) # prints '5', '12', and then '19' ``` #### 🦋 Changed * `def main():`, without the explicit `None` type, can now be used to define the entry point to a Mojo program. * The `assert_param` function has been renamed to `constrained` and is now a built-in function. * The `print` function now works on `Complex` values. #### 🛠️ Fixed * Fixed issues with print formatting for `DType.uint16` and `DType.int16`. * [Issue #499](https://github.com/modular/modular/issues/499) - Two new `rotate_right` and `rotate_left` functions have been added to the SIMD module. * [Issue #429](https://github.com/modular/modular/issues/429) - You can now construct a `Bool` from a `SIMD` type whose element-type is `DType.bool`. * [Issue #350](https://github.com/modular/modular/issues/350) - Confusing Matrix implementation * [Issue #349](https://github.com/modular/modular/issues/349) - Missing load\_tr in struct Matrix * [Issue #501](https://github.com/modular/modular/issues/501) - Missing syntax error messages in Python expressions. 
### 2023-08-09 #### 🦋 Changed * The `ref` and `mutref` identifiers are now treated as keywords, which means they cannot be used as variable, attribute, or function names. These keywords are used by the "lifetimes" feature, which is still in development. We can consider renaming these (as well as other related keywords) when the development work gels, support is enabled in public Mojo builds, and we have experience using them. * The argument handling in `def` functions has changed: previously, they had special behavior that involved mutable copies in the callee. Now, we have a simple rule: `def` arguments default to the `owned` convention (`fn` arguments still default to the `borrowed` convention). This change is mostly an internal cleanup and simplification of the compiler and argument model, but it does enable one niche use case: you can now pass non-copyable types to `def` arguments by transferring ownership of a value into the `def` call. Before, that was not possible because the copy was made on the callee side, not the caller's side. This also allows the explicit use of the `borrowed` keyword with a `def` that wants to opt in to that behavior. ### 2023-08-03 #### ⭐️ New * A new `Tensor` type has been introduced. This tensor type manages its own data (unlike `NDBuffer` and `Buffer`, which are just views). Therefore, the tensor type performs its own allocation and free. Here is a simple example of using the tensor type to represent an RGB image and convert it to grayscale: ```mojo from tensor import Tensor, TensorShape from utils.index import Index from random import rand let height = 256 let width = 256 let channels = 3 # Create the tensor of dimensions height, width, channels and fill with # random values. let image = rand[DType.float32](height, width, channels) # Declare the grayscale image. var gray_scale_image = Tensor[DType.float32](height, width) # Perform the RGB to grayscale transform.
for y in range(height): for x in range(width): let r = image[y, x, 0] let g = image[y, x, 1] let b = image[y, x, 2] gray_scale_image[Index(y, x)] = 0.299 * r + 0.587 * g + 0.114 * b ``` #### 🛠️ Fixed * [Issue #53](https://github.com/modular/modular/issues/53) - `Int` now implements true division with the `/` operator. Similar to Python, this returns a 64-bit floating point number. The corresponding in-place operator, `/=`, has the same semantics as `//=`. ## July 2023 ### 2023-07-26 #### ⭐️ New * Types that define both `__getitem__` and `__setitem__` (i.e. where subscripting instances creates computed LValues) can now be indexed in parameter expressions. * Unroll decorator for loops with constant bounds and steps: * `@unroll`: Fully unroll a loop. * `@unroll(n)`: Unroll a loop by a factor of n, where `n` is a positive integer. * The unroll decorator requires the loop bounds and iteration step to be compile-time constant values; otherwise, unrolling fails with a compilation error. It also does not make the loop induction variable a parameter. ```mojo # Fully unroll the loop. @unroll for i in range(5): print(i) # Unroll the loop by a factor of 4 (with remainder iterations of 2). @unroll(4) for i in range(10): print(i) ``` * The Mojo REPL now prints the values of variables defined in the REPL. There is full support for scalars and structs. Non-scalar SIMD vectors are not supported at this time. #### 🛠️ Fixed * [Issue #437](https://github.com/modular/modular/issues/437) - Range can now be instantiated with a PythonObject. * [Issue #288](https://github.com/modular/modular/issues/288) - Python strings can now be safely copied. ### 2023-07-20 #### ⭐️ New * Mojo now includes a `Limits` module, which contains functions to get the max and min values representable by a type, as requested in [Issue \#51](https://github.com/modular/modular/issues/51). The following functions moved from `Math` to `Limits`: `inf()`, `neginf()`, `isinf()`, `isfinite()`.
* Mojo decorators are now distinguished between "signature" and "body" decorators and are ordered. Signature decorators, like `@register_passable` and `@parameter`, modify the type of the declaration before the body is parsed. Body decorators, like `@value`, modify the body of the declaration after it is fully parsed. Due to ordering, a signature decorator cannot be applied after a body decorator. That means the following is now invalid: ```mojo @register_passable # error: cannot apply signature decorator after a body one! @value struct Foo: pass ``` * Global variables can now be exported in Mojo compiled archives, using the `@export` decorator. Exported global variables are public symbols in compiled archives and use the variable name as the linkage name by default. A custom linkage name can be specified with `@export("new_name")`. This does not affect variable names in Mojo code. * Mojo now supports packages! A Mojo package is defined by placing an `__init__.mojo` or `__init__.🔥` within a directory. Other files in the same directory form modules within the package (this works exactly like it does [in Python](https://docs.python.org/3/tutorial/modules.html#packages)). Example: ```bash main.🔥 my_package/ __init__.🔥 module.🔥 my_other_package/ __init__.🔥 stuff.🔥 ``` ```mojo # main.🔥 from my_package.module import some_function from my_package.my_other_package.stuff import SomeType fn main(): var x: SomeType = some_function() ``` * Mojo now supports direct module and package imports! Modules and packages can be imported and bound to names. Module and package elements, like functions, types, global variables, and other modules, can be accessed using attribute references, like `my_module.foo`. Note that modules lack runtime representations, meaning module references cannot be instantiated.
```mojo import builtin.io as io import SIMD io.print("hello world") var x: SIMD.Float32 = 1.2 ``` #### 🦋 Changed * Reverted the feature from 2023-02-13 that allowed unqualified struct members. Use the `Self` keyword to conveniently access struct members with bound parameters instead. This was required to fix [Issue #260](https://github.com/modular/modular/issues/260). * Updated the RayTracing notebook: added step 5 to create specular lighting for more realistic images and step 6 to add a background image. #### 🛠️ Fixed * [Issue #260](https://github.com/modular/modular/issues/260) - Definitions inside structs no longer shadow definitions outside of struct definitions. ### 2023-07-12 #### ⭐️ New * Mojo now has support for global variables! This enables `var` and `let` declaration at the top-level scope in Mojo files. Global variable initializers are run when code modules are loaded by the platform according to the order of dependencies between global variables, and their destructors are called in the reverse order. * The Mojo programming manual is now written as a Jupyter notebook, and available in its entirety in the Mojo Playground (`programming-manual.ipynb`). (Previously, `HelloMojo.ipynb` included most of the same material, but it was not up-to-date.) * As a result, we've also re-written `HelloMojo.ipynb` to be much shorter and provide a more gentle first-user experience. * [`Coroutine` module documentation](/mojo/std/builtin/coroutine) is now available. Coroutines form the basis of Mojo's support for asynchronous execution. Calls to `async fn`s can be stored into a `Coroutine`, from which they can be resumed, awaited upon, and have their results retrieved upon completion. #### 🦋 Changed * `simd_bit_width` in the `TargetInfo` module has been renamed to `simdbitwidth` to better align with `simdwidthof`, `bitwidthof`, etc. #### 🛠️ Fixed * The walrus operator now works in if/while statements without parentheses, e.g. `if x := function():`. 
* [Issue #428](https://github.com/modular/modular/issues/428) - The `FloatLiteral` and `SIMD` types now support conversion to `Int` via the `to_int` or `__int__` method calls. The behavior matches that of Python, which rounds towards zero. ### 2023-07-05 #### ⭐️ New * Tuple expressions now work without parentheses. For example, `a, b = b, a` works as you'd expect in Python. * Chained assignments (e.g. `a = b = 42`) and the walrus operator (e.g. `some_function(b := 17)`) are now supported. #### 🦋 Changed * The `simd_width` and `dtype_simd_width` functions in the [`TargetInfo`](/mojo/std/sys/info) module have been renamed to `simdwidthof`. * The `dtype_` prefix has been dropped from `alignof`, `sizeof`, and `bitwidthof`. You can now use these functions (e.g. `alignof`) with any argument type, including `DType`. * The `inf`, `neginf`, `nan`, `isinf`, `isfinite`, and `isnan` functions were moved from the `Numerics` module to the [`Math`](/mojo/std/math/math/) module, to better align with Python's library structure. #### 🛠️ Fixed * [Issue #253](https://github.com/modular/modular/issues/253) - Issue when accessing a struct member alias without providing parameters. * [Issue #404](https://github.com/modular/modular/issues/404) - The docs now use `snake_case` for variable names, which more closely conforms to Python's style. * [Issue #379](https://github.com/modular/modular/issues/379) - Tuple limitations have been addressed and multiple return values are now supported, even without parentheses. * [Issue #347](https://github.com/modular/modular/issues/347) - Tuples no longer require parentheses. * [Issue #320](https://github.com/modular/modular/issues/320) - Python objects are now traversable via `for` loops. ## June 2023 ### 2023-06-29 #### ⭐️ New * You can now share `.ipynb` notebook files in Mojo Playground. Just save a file in the `shared` directory, and then right-click the file and select **Copy Sharable link**. 
To open a shared notebook, you must already have access to Mojo Playground; when you open a shared notebook, click **Import** at the top of the notebook to save your own copy. For more details about this feature, see the instructions inside the `help` directory, in the Mojo Playground file browser. #### 🦋 Changed * The `unroll2()` and `unroll3()` functions in the [`Functional`](/mojo/std/algorithm/functional) module have been renamed to overload the `unroll()` function. These functions unroll 2D and 3D loops and `unroll()` can determine the intent based on the number of input parameters. #### 🛠️ Fixed * [Issue #229](https://github.com/modular/modular/issues/229) - Issue when throwing an exception from `__init__` before all fields are initialized. * [Issue #74](https://github.com/modular/modular/issues/74) - Struct definition with recursive reference crashes. * [Issue #285](https://github.com/modular/modular/issues/285) - The [`TargetInfo`](/mojo/std/sys/info) module now includes `is_little_endian()` and `is_big_endian()` to check if the target host uses either little or big endian. * [Issue #254](https://github.com/modular/modular/issues/254) - Parameter name shadowing in nested scopes is now handled correctly. ### 2023-06-21 #### ⭐️ New * Added support for overloading on parameter signature. For example, it is now possible to write the following: ```mojo fn foo[a: Int](x: Int): pass fn foo[a: Int, b: Int](x: Int): pass ``` For details on the overload resolution logic, see the Mojo Manual section on [parameters](/mojo/manual/parameters/#overloading-on-parameters). * A new `cost_of()` function has been added to `Autotune`. This meta-function must be invoked at compile time, and it returns the number of MLIR operations in a function (at a certain stage in compilation), which can be used to build basic heuristics in higher-order generators. 
```mojo from autotune import cost_of fn generator[f: fn(Int) -> Int]() -> Int: @parameter if cost_of[fn(Int) -> Int, f]() < 10: return f() else: # Do something else for slower functions... ``` * Added a new example notebook with a basic Ray Tracing algorithm. #### 🦋 Changed * The `constrained_msg()` in the `Assert` module has been renamed to `constrained()`. #### 🛠️ Fixed * Overloads marked with `@adaptive` now correctly handle signatures that differ only in declared parameter names, e.g. the following now works correctly: ```mojo @adaptive fn foobar[w: Int, T: DType]() -> SIMD[T, w]: ... @adaptive fn foobar[w: Int, S: DType]() -> SIMD[S, w]: ... ``` * [Issue #219](https://github.com/modular/modular/issues/219) - Issue when redefining a function and a struct defined in the same cell. * [Issue #355](https://github.com/modular/modular/issues/355) - The loop order in the Matmul notebook for Python and naive Mojo has been reordered for consistency. The loop order now follows (M, K, N) ordering. * [Issue #309](https://github.com/modular/modular/issues/309) - Use snake case naming within the testing package and move the asserts out of the TestSuite struct. ### 2023-06-14 #### ⭐️ New * Tuple type syntax is now supported, e.g. the following works: ```mojo fn return_tuple() -> (Int, Int): return (1, 2) ``` #### 🦋 Changed * The `TupleLiteral` type was renamed to just `Tuple`, e.g. `Tuple[Int, Float]`. #### 🛠️ Fixed * [Issue #354](https://github.com/modular/modular/issues/354) - Returning a tuple doesn't work even with parens. * [Issue #365](https://github.com/modular/modular/issues/365) - Copy-paste error in `FloatLiteral` docs. * [Issue #357](https://github.com/modular/modular/issues/357) - Crash when missing input parameter to variadic parameter struct member function. ### 2023-06-07 #### ⭐️ New * Tuple syntax now works on the left-hand side of assignments (in "lvalue" positions), enabling things like `(a, b) = (b, a)`.
There are several caveats: the element types must exactly match (no implicit conversions); this only works with values of `TupleLiteral` type (notably, it will not work with `PythonObject` yet); and parentheses are required for tuple syntax. #### ❌ Removed * Mojo Playground no longer includes the following Python packages (due to size, compute costs, and [environment complications](https://github.com/modular/modular/issues/300)): `torch`, `tensorflow`, `keras`, `transformers`. #### 🦋 Changed * The data types and scalar names now conform to the naming convention used by numpy. So we use `Int32` instead of `SI32`, and similarly `Float32` instead of `F32`. Closes [Issue #152](https://github.com/modular/modular/issues/152). #### 🛠️ Fixed * [Issue #287](https://github.com/modular/modular/issues/287) - computed lvalues don't handle raising functions correctly * [Issue #318](https://github.com/modular/modular/issues/318) - Large integers are not being printed correctly * [Issue #326](https://github.com/modular/modular/issues/326) - Float modulo operator is not working as expected * [Issue #282](https://github.com/modular/modular/issues/282) - Default arguments are not working as expected * [Issue #271](https://github.com/modular/modular/issues/271) - Confusing error message when converting between function types with different result semantics ## May 2023 ### 2023-05-31 #### ⭐️ New * Mojo Playground now includes the following Python packages (in response to [popular demand](https://github.com/modular/modular/discussions/173)): `torch`, `tensorflow`, `polars`, `opencv-python`, `keras`, `Pillow`, `plotly`, `seaborn`, `sympy`, `transformers`. * A new optimization is applied to non-trivial copyable values that are passed as an owned value without using the transfer (`^`) operator. Consider code like this: ```mojo var someValue: T = ... ... takeValueAsOwned(someValue) ...
``` When `takeValueAsOwned()` takes its argument as an [`owned`](/mojo/manual/values/ownership#transfer-arguments-var-and-) value (this is common in initializers for example), it is allowed to do whatever it wants with the value and destroy it when it is finished. In order to support this, the Mojo compiler is forced to make a temporary copy of the `someValue` value, and pass that value instead of `someValue`, because there may be other uses of `someValue` after the call. The Mojo compiler is now smart enough to detect when there are no uses of `someValue` later, and it will elide the copy just as if you had manually specified the transfer operator like `takeValueAsOwned(someValue^)`. This provides a nice "it just works" behavior for non-trivial types without requiring manual management of transfers. If you'd like to take full control and expose full ownership for your type, just don't make it copyable. Move-only types require the explicit transfer operator so you can see in your code where all ownership transfers happen. * Similarly, the Mojo compiler now transforms calls to `__copyinit__` methods into calls to `__moveinit__` when that is the last use of the source value along a control flow path. This allows types which are both copyable and movable to get transparent move optimization. For example, the following code is compiled into moves instead of copies even without the use of the transfer operator: ```mojo var someValue = somethingCopyableAndMovable() use(someValue) ... let otherValue = someValue # Last use of someValue use(otherValue) ... var yetAnother = otherValue # Last use of otherValue mutate(yetAnother) ``` This is a significant performance optimization for things like `PythonObject` (and more complex value semantic types) that are commonly used in a fluid programming style. Such types don't want extraneous reference-counting operations performed by their copy constructors.
If you want explicit control over copying, it is recommended to use a non-dunder `.copy()` method instead of `__copyinit__`, and recall that non-copyable types must always use the transfer operator when fully explicit behavior is wanted. #### 🛠️ Fixed * [Issue #231](https://github.com/modular/modular/issues/231) - Unexpected error when a Python expression raises an exception * [Issue #119](https://github.com/modular/modular/issues/119) - The REPL fails when a Python variable is redefined ### 2023-05-24 #### ⭐️ New * `finally` clauses are now supported on `try` statements. In addition, `try` statements no longer require `except` clauses, allowing `try-finally` blocks. `finally` clauses contain code that is always executed when control flow leaves any of the other clauses of a `try` statement by any means. #### 🦋 Changed * `with` statement emission changed to use the new `finally` logic so that ```mojo with ContextMgr(): return ``` will correctly execute `ContextMgr.__exit__` before returning. #### 🛠️ Fixed * [Issue #204](https://github.com/modular/modular/issues/204) - Mojo REPL crash when returning a String at compile-time * [Issue #143](https://github.com/modular/modular/issues/143) - synthesized init in `@register_passable` type doesn't get correct convention. * [Issue #201](https://github.com/modular/modular/issues/201) - String literal concatenation is too eager. * [Issue #209](https://github.com/modular/modular/issues/209) - \[QoI] Terrible error message trying to convert a type to itself.
* [Issue #32](https://github.com/modular/modular/issues/32) - Include struct fields in docgen * [Issue #50](https://github.com/modular/modular/issues/50) - Int to string conversion crashes due to buffer overflow * [Issue #132](https://github.com/modular/modular/issues/132) - PythonObject `to_int` method has a misleading name * [Issue #189](https://github.com/modular/modular/issues/189) - PythonObject bool conversion is incorrect * [Issue #65](https://github.com/modular/modular/issues/65) - Add SIMD constructor from Bool * [Issue #153](https://github.com/modular/modular/issues/153) - Meaning of `Time.now` function result is unclear * [Issue #165](https://github.com/modular/modular/issues/165) - Typo in `Pointer.free` documentation * [Issue #210](https://github.com/modular/modular/issues/210) - Parameter results cannot be declared outside top-level in function * [Issue #214](https://github.com/modular/modular/issues/214) - Pointer offset calculations at compile-time are incorrect * [Issue #115](https://github.com/modular/modular/issues/115) - Float printing does not include the right number of digits * [Issue #202](https://github.com/modular/modular/issues/202) - `kgen.unreachable` inside nested functions is illegal * [Issue #235](https://github.com/modular/modular/issues/235) - Crash when register passable struct field is not register passable * [Issue #237](https://github.com/modular/modular/issues/237) - Parameter closure sharp edges are not documented ### 2023-05-16 #### ⭐️ New * Added missing dunder methods to `PythonObject`, enabling the use of common arithmetic and logical operators on imported Python values. * `PythonObject` is now printable from Mojo, instead of requiring you to import Python's print function. #### 🛠️ Fixed * [Issue #98](https://github.com/modular/modular/issues/98): Incorrect error with lifetime tracking in loop. * [Issue #49](https://github.com/modular/modular/issues/49): Type inference issue (?)
in 'ternary assignment' operation (FloatLiteral vs. 'SIMD\[f32, 1]'). * [Issue #48](https://github.com/modular/modular/issues/48): and/or don't work with memory-only types. * [Issue #11](https://github.com/modular/modular/issues/11): `setitem` Support for `PythonObject`. ### 2023-05-11 #### ⭐️ New * `NDBuffer` and `Buffer` are now constructable via `Pointer` and `DTypePointer`. * `String` now supports indexing with either integers or slices. * Added factorial function to the `Math` module. #### 🦋 Changed * The "byref" syntax with the `&` sigil has changed to use an `inout` keyword to be more similar to the `borrowed` and `owned` syntax in arguments. Please see [Issue #7](https://github.com/modular/modular/issues/7) for more information. * Optimized the Matrix multiplication implementation in the notebook. Initially we were optimizing for expandability rather than performance. We have found a way to get the best of both worlds and now the performance of the optimized Matmul implementation is 3x faster. * Renamed the [`^` postfix operator](/mojo/manual/values/ownership#transfer-arguments-var-and-) from "consume" to "transfer." #### 🛠️ Fixed * Fixed missing overloads for `Testing.assertEqual` so that they work on `Integer` and `String` values. * [Issue #6](https://github.com/modular/modular/issues/6): Playground stops evaluating cells when a simple generic is defined. * [Issue #18](https://github.com/modular/modular/issues/18): Memory leak in Python interoperability was removed. ### 2023-05-02 #### 📢 Released * Mojo publicly launched! This was epic, with lots of great coverage online including a [wonderful post by Jeremy Howard](https://www.fast.ai/posts/2023-05-03-mojo-launch.html). The team is busy this week. #### ⭐️ New * Added a Base64 encoding function to perform base64 encoding on strings. #### 🦋 Changed * Decreased memory usage of serialization of integers to strings. * Speedup the sort function. #### 🛠️ Fixed * Fixed time unit in the `sleep` function. 
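The `try`/`finally` support described in the 2023-05-24 entry above can be sketched as follows. This is a minimal illustration, not code from any release; `might_fail` and `demo` are hypothetical names introduced here for the example:

```mojo
fn might_fail(cond: Bool) raises:
    # Hypothetical helper that raises when cond is True.
    if cond:
        raise Error("failure")

fn demo(cond: Bool) raises:
    try:
        might_fail(cond)
        print("no error raised")
    finally:
        # This clause runs no matter how control leaves the try body:
        # normal completion, a raised error, or an early return.
        print("cleanup always runs")
```

Because `except` clauses are no longer required, a bare `try`-`finally` like this is valid; any error raised in the `try` body still propagates to the caller after the `finally` clause runs.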
## April 2023 ### Week of 2023-04-24 * 📢 The default behavior of nested functions has been changed. Mojo nested functions that capture are by default non-parametric, runtime closures, meaning that: ```mojo def foo(x): # This: def bar(y): return x * y # Is the same as: let bar = lambda y: x * y ``` These closures cannot have input or result parameters, because they are always materialized as runtime values. Values captured in the closure (`x` in the above example) are captured by copy: values without copy constructors cannot be captured, and captures are immutable in the closure. Nested functions that don't capture anything are by default "parametric" closures: they can have parameters and they can be used as parameter values. To restore the previous behavior for capturing closures, "parametric, capture-by-unsafe-reference closures", tag the nested function with the `@parameter` decorator. * 📢 Mojo now has full support for "runtime" closures: nested functions that capture state materialized as runtime values. This includes taking the address of functions, indirect calls, and passing closures around through function arguments. Note that capture-by-reference is still unsafe! You can also take references to member functions with instances of that class using `foo.member_function`, which creates a closure with `foo` bound to the `self` argument. * 📢 Mojo now supports Python style `with` statements and context managers. These are very helpful for implementing things like our trace region support and runtime support. A context manager in Mojo implements three methods: ```mojo fn __enter__(self) -> T: fn __exit__(self): fn __exit__(self, err: Error) -> Bool: ``` The first is invoked when the context is entered, and returns a value that may optionally be bound to a target for use in the with body. If the with block exits normally, the second method is invoked to clean it up. If an error is raised, the third method is invoked with the Error value.
If that method returns true, the error is considered handled; if it returns false, the error is re-thrown so propagation continues out of the 'with' block. * 📢 Mojo functions now support variable scopes! Explicit `var` and `let` declarations inside functions can shadow declarations from higher "scopes", where a scope is defined as any new indentation block. In addition, the `for` loop iteration variable is now scoped to the loop body, so it is finally possible to write ```mojo for i in range(1): pass for i in range(2): pass ``` * 📢 Mojo now supports an `@value` decorator on structs to reduce boilerplate and encourage best practices in value semantics. The `@value` decorator looks to see whether the struct has a fieldwise initializer (which has arguments for each field of the struct), a `__copyinit__` method, and a `__moveinit__` method, and synthesizes the missing ones if possible. For example, if you write: ```mojo @value struct MyPet: var name: String var age: Int ``` The `@value` decorator will synthesize the following members for you: ```mojo fn __init__(inout self, owned name: String, age: Int): self.name = name^ self.age = age fn __copyinit__(inout self, existing: Self): self.name = existing.name self.age = existing.age fn __moveinit__(inout self, owned existing: Self): self.name = existing.name^ self.age = existing.age ``` This decorator can greatly reduce the boilerplate needed to define common aggregates, and gives you best practices in ownership management automatically. The `@value` decorator can be used with types that need custom copy constructors (your definition wins). We can explore having the decorator take arguments to further customize its behavior in the future. * 📚 Memcpy and memcmp now consistently use count as the byte count. * 📚 Add a variadic string join on strings. * 📚 Introduce a `reduce_bit_count` method to count the number of 1 bits across all elements in a SIMD vector. * 📚 Optimize the `pow` function if the exponent is integral.
* 📚 Add a `len` function which dispatches to `__len__` across the different structs that support it. ### Week of 2023-04-17 * 📢 Error messages have been significantly improved, thanks to prettier printing for Mojo types in diagnostics. * 📢 Variadic values can now be indexed directly without wrapping them in a `VariadicList`! * 📢 `let` declarations in a function can now be lazily initialized, and `var` declarations that are never mutated get a warning suggesting they be converted to a `let` declaration. Lazy initialization allows more flexible patterns of initialization than requiring the initializer be inline, e.g.: ```mojo let x: Int if cond: x = foo() else: x = bar() use(x) ``` * 📢 Functions defined with `def` now return `object` by default, instead of `None`. This means you can return values (convertible to `object`) inside `def` functions without specifying a return type. * 📢 The `@raises` decorator has been removed. Raising `fn` should be declared by specifying `raises` after the function argument list. The rationale is that `raises` is part of the type system, instead of a function modifier. * 📢 The `BoolLiteral` type has been removed. Mojo now emits `True` and `False` directly as `Bool`. * 📢 Syntax for function types has been added. You can now write function types with `fn(Int) -> String` or `async def(&String, *Int) -> None`. No more writing `!kgen.signature` types by hand! * 📢 Float literals are now emitted as `FloatLiteral` instead of an MLIR `f64` type! * 📢 Automatic destructors are now supported by Mojo types, currently spelled `fn __del___(owned self):` (the extra underscore will be dropped shortly). These destructors work like Python object destructors and similar to C++ destructors, with the major difference being that they run "as soon as possible" after the last use of a value. This means they are not suitable for use in C++-style RAII patterns (use the `with` statement for that, which is currently unsupported).
These should be generally reliable for both memory-only and register-passable types, with the caveat that closures are known to *not* capture values correctly. Be very careful with interesting types in the vicinity of a closure! * A new (extremely dangerous!) builtin function is available for low-level ownership muckery. The `__get_address_as_owned_value(x)` builtin takes a low-level address value (of `!kgen.pointer` type) and returns an `owned` value for the memory that is pointed to. This value is assumed live at the invocation of the builtin, but is "owned" so it needs to be consumed by the caller, otherwise it will be automatically destroyed. This is an effective way to do a "placement delete" on a pointer. ```mojo # "Placement delete": destroy the initialized object being pointed to. _ = __get_address_as_owned_value(somePointer.value) # Result value can be consumed by anything that takes it as an 'owned' # argument as well. consume(__get_address_as_owned_value(somePointer.value)) ``` * Another magic operator, named `__get_address_as_uninit_lvalue(x)`, joins the magic LValue operator family. This operator projects a pointer to an LValue like `__get_address_as_lvalue(x)`. The difference is that `__get_address_as_uninit_lvalue(x)` tells the compiler that the pointee is uninitialized on entry and initialized on exit, which means that you can use it as a "placement new" in the C++ sense. `__get_address_as_lvalue(x)` tells the compiler that the pointee is initialized already, so reassigning over it will run the destructor. ```mojo # "*Re*placement new": destroy the existing SomeHeavy value in the memory, # then initialize a new value into the slot. __get_address_as_lvalue(somePointer.value) = SomeHeavy(4, 5) # Ok to use an lvalue, convert to borrow etc. use(__get_address_as_lvalue(somePointer.value)) # "Placement new": Initialize a new value into uninitialized memory.
__get_address_as_uninit_lvalue(somePointer.value) = SomeHeavy(4, 5) # Error, cannot read from uninitialized memory. use(__get_address_as_uninit_lvalue(somePointer.value)) ``` Note that `__get_address_as_lvalue` assumes that there is already a value at the specified address, so the assignment above will run the `SomeHeavy` destructor (if any) before reassigning over the value. * 📢 Implement full support for `__moveinit__` (aka move constructors) This implements the ability for memory-only types to define two different types of move ctors if they'd like: 1. `fn __moveinit__(inout self, owned existing: Self)`: Traditional Rust-style move constructors that shuffle data around while taking ownership of the source binding. 2. `fn __moveinit__(inout self, inout existing: Self):`: C++ style "stealing" move constructors that can be used to take from an arbitrary LValue. This gives us great expressive capability (better than Rust/C++/Swift) and composes naturally into our lifetime tracking and value categorization system. * The `__call__` method of a callable type has been relaxed to take `self` by borrow, allowing non-copyable callees to be called. * Implicit conversions are now invoked in `raise` statements properly, allowing strings to be converted to the `Error` type. * Automatic destructors are turned on for `__del__` instead of `__del___`. * 📚 Add the builtin FloatLiteral type. * 📚 Add integral `floordiv` and `mod` for the SIMD type that handle negative values. * 📚 Add an F64 to String converter. * 📚 Make the `print` function take variadic inputs. ### Week of 2023-04-10 * 📢 Introduce the consume operator `x^`. This introduces the postfix consume operator, which produces an RValue given a lifetime tracked object (and, someday, a movable LValue). * Mojo now automatically synthesizes empty destructor methods for certain types when needed. * The `object` type has been built out into a fully-dynamic type, with dynamic function dispatch, with full error handling support.
```mojo def foo(a) -> object: return (a + 3.45) < [1, 2, 3] # raises a TypeError ``` * 📢 The `@always_inline` decorator is no longer required for passing capturing closures as parameters, both for the functions themselves and for functions with capturing closures in their parameters. These functions are still inlined but it is an implementation detail of capturing parameter closures. Mojo now distinguishes between capturing and non-capturing closures. Nested functions are capturing by default and can be made non-capturing with the `@noncapturing` decorator. A top-level function can be passed as a capturing closure by marking it with the `@closure` decorator. * 📢 Support for list literals has been added. List literals `[1, 2, 3]` generate a variadic heterogeneous list type. * Variadics have been extended to work with memory-primary types. * Slice syntax is now fully-supported with a new builtin `slice` object, added to the compiler builtins. Slice indexing with `a[1:2:3]` now emits calls to `__setitem__` and `__getitem__` with a slice object. * Call syntax has been wired up to `__call__`. You can now `f()` on custom types! * Closures are now explicitly typed as capturing or non-capturing. If a function intends to accept a capturing closure, it must specify the `capturing` function effect. * 📚 Add a `Tile2D` function to enable generic `2D` tiling optimizations. * 📚 Add the `slice` struct to enable getting/setting spans of elements via `getitem`/`setitem`. * 📚 Add syntax sugar to autotuning for specifying the autotuned values, searching, and declaring the evaluation function. ### Week of 2023-04-03 * The `AnyType` and `NoneType` aliases were added and auto-imported in all files. * 📢 The Mojo VS Code extension has been improved with docstring validation. It will now warn when a function's docstring has a wrong argument name, for example. * 📢 A new built-in literal type `TupleLiteral` was added in `_CompilerBuiltin`.
It represents literal tuple values such as `(1, 2.0)` or `()`. * 📢 The `Int` type has been moved to a new `Builtin` module and is auto-imported in all code. The type of integer literals has been changed from the MLIR `index` type to the `Int` type. * Mojo now has a powerful flow-sensitive uninitialized variable checker. This means that you need to initialize values before using them, even if you overwrite all subcomponents. This enables the compiler to reason about the true lifetime of values, which is an important stepping stone to getting automatic value destruction in place. * 📢 Call syntax support has been added. Now you can directly call an object that implements the `__call__` method, like `foo(5)`. * 📢 The name for copy constructors got renamed from `__copy__` to `__copyinit__`. Furthermore, non-`@register_passable` types now implement it like they do an init method where you fill in a by-reference self, for example: ```mojo fn __copyinit__(inout self, existing: Self): self.first = existing.first self.second = existing.second ``` This makes copy construction work more similarly to initialization, and still keeps copies `x = y` distinct from initialization `x = T(y)`. * 📢 Initializers for memory-primary types are now required to be in the form `__init__(inout self, ...):` with a None result type, but for register primary types, it remains in the form `__init__(...) -> Self:`. The `T{}` initializer syntax has been removed for memory-primary types. * Mojo String literals now emit a builtin `StringLiteral` type! One less MLIR type to worry about. * New `__getattr__` and `__setattr__` dunder methods were added. Mojo calls these methods on a type when attempting member lookup of a non-static member. This allows writing dynamic objects like `x.foo()` where `foo` is not a member of `x`. * Early destructor support has been added. Types can now define a special destructor method `__del___` (note three underscores). 
This is an early feature and it is still being built out. There are many caveats, bugs, and missing pieces. Stay tuned! * 📚 Integer division and mod have been corrected for rounding in the presence of negative numbers. * 📚 Add scalar types (UI8, SI32, F32, F64, etc.) which are aliases to `SIMD[1, type]`. ## March 2023 ### Week of 2023-03-27 * 📢 Parameter names are no longer load-bearing in function signatures. This gives more flexibility in defining higher-order functions, because the functions passed as parameters do not need their parameter names to match. ```mojo # Define a higher-order function... fn generator[ func: __mlir_type[`!kgen.signature<`, Int, `>() -> !kgen.none`] ](): pass # Int parameter is named "foo". fn f0[foo: Int](): pass # Int parameter is named "bar". fn f1[bar: Int](): pass fn main(): # Both can be used as `func`! generator[f0]() generator[f1]() ``` Stay tuned for improved function type syntax... * 📢 Two magic operators, named `__get_lvalue_as_address(x)` and `__get_address_as_lvalue`, convert stored LValues to and from `!kgen.pointer` types (respectively). This is most useful when using the `Pointer[T]` library type. The `Pointer(to=lvalue)` method uses the first one internally. The second one must currently be used explicitly, and can be used to project a pointer to a reference that you can pass around and use as a self value, for example: ```mojo # "Replacement new" SomeHeavy value into the memory pointed to by a # Pointer[SomeHeavy]. __get_address_as_lvalue(somePointer.value) = SomeHeavy(4, 5) ``` Note that `__get_address_as_lvalue` assumes that there is already a value at the specified address, so the assignment above will run the `SomeHeavy` destructor (if any) before reassigning over the value. * The `(((x)))` syntax in `__mlir_op` has been removed in favor of `__get_lvalue_as_address`, which solves the same problem and is more general.
* 📢 When using a mutable `self` argument to a struct `__init__` method, it now must be declared with `&`, like any other mutable method. This clarifies the mutation model by making `__init__` consistent with other mutating methods. * 📚 Add variadic string join function. * 📚 Default initialize values with 0 or null if possible. * 📚 Add compressed, aligned, and mask store intrinsics. ### Week of 2023-03-20 * An initial `String` type is added to the standard library with some very basic methods. * Add `DimList` to remove the need to use an MLIR list type throughout the standard library. * 📢 The `__clone__` method for copying a value is now named `__copy__` to better follow the Python term of art. * 📢 The `__copy__` method now takes its self argument as a "read" value, instead of taking it by reference. This makes it easier to write, works for `@register_passable` types, and exposes more optimization opportunities to the early optimizer and dataflow analysis passes. ```mojo # Before: fn __clone__(inout self) -> Self: ... # After: fn __copy__(self) -> Self: ... ``` * 📢 A new `@register_passable("trivial")` may be applied to structs that have no need for a custom `__copy__` or `__del__` method, and whose state is only made up of `@register_passable("trivial")` types. This eliminates the need to define `__copy__` boilerplate and reduces the amount of IR generated by the compiler for trivial types like `Int`. * You can now write back to attributes of structs that are produced by a computed lvalue expression. For example `a[i].x = ..` works when `a[i]` is produced with a `__getitem__`/`__setitem__` call. This is implemented by performing a read of `a[i]`, updating the temporary, then doing a writeback. * The remaining hurdles to using non-parametric, `@register_passable` types as parameter values have been cleared. Types like `Int` should enjoy full use as parameter values. * Parameter pack inference has been added to function calls.
Calls to functions with parameter packs can now elide the pack types: ```mojo fn foo[*Ts: AnyType](*args: *Ts): pass foo(1, 1.2, True, "hello") ``` Note that the syntax for parameter packs has been changed as well. * 📚 Add the runtime string type. * 📚 Introduce the DimList struct to remove the need to use low-level MLIR operations. ### Week of 2023-03-13 * 📢 Initializers for structs now use `__init__` instead of `__new__`, following standard practice in Python. You can write them in one of two styles, either traditional where you mutate self: ```mojo fn __init__(self, x: Int): self.x = x ``` or as a function that returns an instance: ```mojo fn __init__(x: Int) -> Self: return Self {x: x} ``` Note that `@register_passable` types must use the latter style. * 📢 The default argument convention is now the `borrowed` convention. A "read" argument is passed like a C++ `const&` so it doesn't need to invoke the copy constructor (aka the `__clone__` method) when passing a value to the function. There are two differences from C++ `const&`: 1. A future borrow checker will make sure there are no mutable aliases with an immutable borrow. 2. `@register_passable` values are passed directly in an SSA register (and thus, usually in a machine register) instead of using an extra reference wrapper. This is more efficient and is the 'right default' for `@register_passable` values like integers and pointers. This also paves the way to remove the reference requirement from `__clone__` method arguments, which will allow us to fill in more support for them. * Support for variadic pack arguments has been added to Mojo. You can now write heterogeneous variadic packs like: ```mojo fn foo[*Ts: AnyType](args*: Ts): pass foo[Int, F32, String, Bool](1, 1.5, "hello", True) ``` * The `owned` argument convention has been added. This argument convention indicates that the function takes ownership of the argument and is responsible for managing its lifetime.
* The `borrowed` argument convention has been added. This convention signifies the callee gets an immutable shared reference to a value in the caller's context. * 📚 Add the `getenv` function to the `OS` module to enable getting environment variables. * 📚 Enable the use of dynamic strides in `NDBuffer`. ### Week of 2023-03-06 * 📢 Support added for using capturing async functions as parameters. * 📢 Returning result parameters has been moved from `return` statements to a new `param_return` statement. This allows returning result parameters from throwing functions: ```mojo @raises fn foo[() -> out: Int](): param_return[42] raise Error() ``` And returning different parameters along `@parameter if` branches: ```mojo fn bar[in: Bool -> out: Int](): @parameter if in: param_return[1] else: param_return[2] ``` * 📢 Mojo now supports omitting returns at the end of functions when they would not be reachable. For instance, ```mojo fn foo(cond: Bool) -> Int: if cond: return 0 else: return 1 fn bar() -> Int: while True: pass ``` * String literals now support concatenation, so `"hello " "world"` is treated the same as `"hello world"`. * Empty bodies on functions, structs, and control flow statements are no longer allowed. Please use `pass` in them to explicitly mark that they are empty, just like in Python. * 📢 Structs in Mojo now default to living in memory instead of being passed around in registers. This is the right default for generality (large structures, structures whose pointer identity matters, etc.) and is a key technology that enables the borrow model. For simple types like `Int` and `SIMD`, they can be marked as `@register_passable`. Note that memory-only types currently have some limitations: they cannot be used in generic algorithms that take and return a `!mlirtype` argument, and they cannot be used in parameter expressions. Because of this, a lot of types have to be marked `@register_passable` just to work around the limitations.
We expect to enable these use-cases over time. * 📢 Mojo now supports computed lvalues, which means you can finally assign to subscript expressions instead of having to call `__setitem__` explicitly. Some details on this: Mojo allows you to define multiple `__setitem__` overloads, but will pick the one that matches your `__getitem__` type if present. It allows you to pass computed lvalues into inout arguments by introducing a temporary copy of the value in question. * Mojo now has much better support for using register-primary struct types in parameter expressions and as the types of parameter values. This will allow migration of many standard library types away from using bare MLIR types like `__mlir_type.index` and towards using `Int`. This moves us towards getting rid of MLIR types everywhere and makes struct types first-class citizens in the parameter system. * 📚 Add a `sort` function. * 📚 Add non-temporal store to enable cache bypass. ## February 2023 ### Week of 2023-02-27 * 📢 The `@interface`, `@implements`, and `@evaluator` trio of decorators have been removed, replaced by the `@parameter if` and `@adaptive` features. * 📢 Parameter inference can now infer the type of variadic lists. * 📢 Memory primary types are now supported in function results. A result slot is allocated in the caller, and the callee writes the result of the function into that slot. This is more efficient for large types that don't fit into registers neatly! And initializers for memory-primary types now initialize the value in-place, instead of emitting a copy! * Support for `let` decls of memory primary types has been implemented. These are constant, read-only values of memory primary types that are allocated on the function stack. * Overload conversion resolution and parameter inference have been improved: 1. Inference now works with `let` decls in some scenarios that weren't working before. 2. Parameter bindings can now infer types into parameter expressions.
This helps resolve higher-order functions in parameter expressions.
* 📚 Optimize floor, ceil, and ldexp on X86 hardware.
* 📚 Implement the log math function.

### Week of 2023-02-20

* 📢 A new `@__memory_primary` struct decorator has been introduced. Memory-primary types must always have an address. For instance, they are always stack-allocated when declared in a function and their values are passed into function calls by address instead of by copy. This is in contrast with register-primary types, which may not have an address, and which are passed by value in function calls. Memory-primary fields are not allowed inside register-primary structs, because struct elements are stored in-line.
* 📢 A new `_CompilerBuiltin` module was added. This module defines core types and functions of the language that are referenced by the parser, and hence is auto-imported by all other modules. For example, new types for literal values like the boolean True/False will be included in `_CompilerBuiltin`.
* 📢 A special `__adaptive_set` property can be accessed on a function reference marked as `@adaptive`. The property returns the adaptive overload set of that function. The return type is a `!kgen.variadic`. This feature is useful for implementing a generic `evaluate` function in the standard library.
* 📢 A new built-in literal type `BoolLiteral` was added in `_CompilerBuiltin`. It represents the literal boolean values `True` and `False`. This is the first Mojo literal to be emitted as a standard library type!
* 📚 Add the prefetch intrinsic to enable HW prefetching of a cache line.
* 📚 Add the `InlinedFixedVector`, which is optimized for small vectors and stores values on both the stack and the heap.

### Week of 2023-02-13

* Unqualified lookups of struct members apply contextual parameters. This means, for instance, that you can refer to static methods without binding the struct parameters.
  ```mojo
  struct Foo[x: Int]:
      @staticmethod
      bar(): pass

      foo(self):
          bar()         # implicitly binds to Foo[x].bar()
          Foo[2].bar()  # explicitly bind to another parameter
  ```

* 📢 A new `Self` type refers to the enclosing type with all parameters bound to their current values. This is useful when working with complex parametric types, e.g.:

  ```mojo
  struct MyArray[size: Int, element_type: type]:
      fn __new__() -> Self:
          return Self {...}
  ```

  which is a lot nicer than having to say `MyArray[size, element_type]` over and over again.

* 📢 Mojo now supports an `@adaptive` decorator. This decorator will supersede interfaces, and it represents an overloaded function that is allowed to resolve to multiple valid candidates. In that case, the call is emitted as a fork, resulting in multiple function candidates to search over.

  ```mojo
  @adaptive
  fn sort(arr: ArraySlice[Int]):
      bubble_sort(arr)

  @adaptive
  fn sort(arr: ArraySlice[Int]):
      merge_sort(arr)

  fn concat_and_sort(lhs: ArraySlice[Int], rhs: ArraySlice[Int]):
      let arr = lhs + rhs
      sort(arr)  # this forks compilation, creating two instances
                 # of the surrounding function
  ```

* 📢 Mojo now requires that types implement the `__clone__` special member in order to copy them. This allows the safe definition of non-copyable types like Atomic. Note that Mojo still doesn't implement destructors, and (due to the absence of non-mutable references) it doesn't actually invoke the `__clone__` member when copying a let value. As such, this forces you as a Mojo user to write maximal boilerplate without getting much value out of it. In the future, we will reduce the boilerplate with decorators, and we will actually start using it. This will take some time to build out, though.
* 📢 A special `__mlir_region` statement was added to provide stronger invariants around defining MLIR operation regions in Mojo. It has similar syntax to function declarations, except that there are no results and no input conventions.
* 📚 Implement the log math function.
* 📚 Improve the DType struct to enable compile-time equality checks.
* 📚 Add the Complex struct class.

### Week of 2023-02-06

* 📢 The `if` statement now supports a `@parameter` decorator, which requires its condition to be a parameter expression, but which only emits the 'True' side of the condition to the binary, providing a "static if" functionality. This should eliminate many uses of `@interface` that are just used to provide different constraints on the implementations.
* 📢 `fn main():` is now automatically exported and directly runnable by the command-line `mojo` tool. This is a stop-gap solution to enable script-like use cases until we have more of the language built out.
* 🪦 The `@nodebug_inline` feature has been removed; please use `@alwaysinline("nodebug")` for methods that must be inlined and that we don't want to step into.
* 📢 Python chained comparisons, e.g. `a < b < c`, are now supported in Mojo.
* 📢 Functions can now be defined with default argument values, such as `def f(x: Int, y: Int = 5):`. The default argument value is used when callers do not provide a value for that argument: `f(3)`, for example, uses the default argument value of `y = 5`.
* Unused coroutine results are now nicely diagnosed as "missing await" warnings.
* 📚 Introduce vectorized reduction operations to the SIMD type.

## January 2023

### Week of 2023-01-30

* A basic Mojo language server has been added to the VS Code extension, which parses your code as you write it, and provides warnings, errors, and fix-it suggestions!
* 💯 The Mojo standard library is now implicitly imported by default.
* The coroutine lowering support was reworked and a new `Coroutine[T]` type was implemented. Now, the result of a call to an async function MUST be wrapped in a `Coroutine[T]`, or else memory will leak. In the future, when Mojo supports destructors and library types as literal types, the results of async function calls will automatically be wrapped in a `Coroutine[T]`.
But today, it must be done manually. This type implements all the expected hooks, such as `__await__`, and `get()` to retrieve the result. Typical usage:

```mojo
async fn add_three(a: Int, b: Int, c: Int) -> Int:
    return a + b + c

async fn call_it():
    let task: Coroutine[Int] = add_three(1, 2, 3)
    print(await task)
```

* ⭐️ We now diagnose unused expression values in statement context in `fn` declarations (but not in `def`s). This catches bugs with unused values, e.g. when you forget the parens to call a function.
* 📢 An `@always_inline("nodebug")` function decorator can be used on functions that need to be force-inlined but should not have debug info in the result. This should be used on methods like `Int.__add__`, which should be treated as builtin.
* 📢 The `@export` decorator now supports an explicit symbol name to export to, for example:

  ```mojo
  @export("baz")  # exported as 'baz'
  fn some_mojo_fn_name():
  ```

* 📢 🚧 Subscript syntax is now wired up to the `__getitem__` dunder method. This allows type authors to implement the `__getitem__` method to enable values to be subscripted. This is an extended version of the Python semantics (given we support overloading) that allows you to define N indices instead of a single version that takes a tuple (also convenient because we don't have tuples yet). Note that this has a very, very important limitation: subscripts are NOT wired up to `__setitem__` yet. This means that you can read values with `.. = v[i]` but you cannot store to them with `v[i] = ..`. For this, please continue to call `__setitem__` directly.
* 📢 Function calls support parameter inference. For calls to functions that have an insufficient number of parameters specified at the callsite, we can now infer them from the argument list. We do this by matching up the parallel type structure to infer what the parameters must be.
Note that this works left to right in the parameter list, applying explicitly specified parameters before trying to infer new ones. This is similar to how C++ does things, which means that you may want to reorder the list of parameters with this in mind. For example, a `dyn_cast`-like function is more elegant when implemented as `fn dyn_cast[DstType: type, SrcType: type](src: SrcType) -> DstType:` than with the `SrcType`/`DstType` parameters flipped around.
* 📚 Add the growable Dynamic vector struct.

### Week of 2023-01-23

* Inplace operations like `+=`/`__iadd__` may now take `self` by-val if they want to, instead of requiring it to be by-ref.
* ⭐️ Inplace operations are no longer allowed to return a non-None value. The corresponding syntax is a statement, not an expression.
* A new `TaskGroup` type was added to the standard library. This type can be used to schedule multiple tasks on a multi-threaded workqueue to be executed in parallel. An async function can `await` all the tasks at once with the taskgroup.
* 📢 We now support for loops! A type that defines an `__iter__` method that returns a type that defines `__next__` and `__len__` methods is eligible to be used in the statement `for el in X()`. Control flow exits the loop when the length is zero. This means things like this now work:

  ```mojo
  for item in range(start, end, step):
      print(item)
  ```

* Result parameters now have names. This is useful for referring to result parameters in the return types of a function:

  ```mojo
  fn return_simd[() -> nelts: Int]() -> SIMD[f32, nelts]:
  ```

* 📢 We now support homogeneous variadics in value argument lists, using the standard Python `fn thing(*args: Int):` syntax! Variadics also have support in parameter lists:

  ```mojo
  fn variadic_params_and_args[*a: Int](*b: Int):
      print(a[0])
      print(b[1])
  ```

* 📚 Add the range struct to enable `for ... range(...)` loops.
* 📚 Introduce the unroll generator to allow one to unroll loops via a library function.
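The iteration protocol described above (`__iter__` returning a type with `__next__` and `__len__`, looping until the length is zero) can be mimicked in plain Python. This is only a sketch for illustration: the `Countdown` names are hypothetical, and real Python `for` loops use `StopIteration` rather than `__len__`:

```python
# Sketch of Mojo's historical for-loop protocol, expressed in Python.
class CountdownIter:
    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # Remaining length; the loop stops when this reaches zero.
        return self.n

    def __next__(self) -> int:
        self.n -= 1
        return self.n + 1

class Countdown:
    def __init__(self, n: int):
        self.n = n

    def __iter__(self) -> CountdownIter:
        return CountdownIter(self.n)

def mojo_style_for(iterable):
    # Desugaring of `for el in X()` under this protocol.
    it = iterable.__iter__()
    out = []
    while len(it) != 0:
        out.append(it.__next__())
    return out

print(mojo_style_for(Countdown(3)))  # [3, 2, 1]
```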
### Week of 2023-01-16

* 📢 Struct field references are now supported in parameter context, so you can use `someInt.value` to get the underlying MLIR thing out of it. This should allow using first-class types in parameters more widely.
* 📢 We now support "pretty" initialization syntax for structs, e.g.:

  ```mojo
  struct Int:
      var value: __mlir_type.index
      fn __new__(value: __mlir_type.index) -> Int:
          return Int {value: value}
  ```

  This eliminates the need to directly use the MLIR `lit.struct.create` op in struct initializers. This syntax may change in the future when ownership comes in, because we will be able to support the standard `__init__` model then.

* 📢 It is now possible to attach regions to `__mlir_op` operations. This is done with a hack that allows an optional `_region` attribute that lists references to the region bodies (max 1 region right now due to lack of a list `[]` literal).
* Nested functions now parse, e.g.:

  ```mojo
  fn foo():
      fn bar():
          pass
      bar()
  ```

* Python-style `async` functions should now work and the `await` expression prefix is now supported. This provides the joy of async/await syntactic sugar when working with asynchronous functions. This is still somewhat dangerous to use because we don't have proper memory ownership support yet.
* String literals are now supported.
* Return processing is now handled by a dataflow pass inside the compiler, so it is possible to return early out of if statements.
* The parser now supports generating 'fixit' hints on diagnostics, and uses them when a dictionary literal uses a colon instead of an equal sign, e.g.:

  ```log
  x.mojo:8:48: error: expected ':' in subscript slice, not '='
      return __mlir_op.`lit.struct.create`[value = 42]()
                                                 ^
                                                 :
  ```

* 📚 Add reduction methods which operate on buffers.
* 📚 Add more math functions like sigmoid, sqrt, rsqrt, etc.
* 📚 Add partial load / store, which enable loads and stores that are predicated on a condition.
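The async/await sugar added above mirrors Python's surface syntax. As a rough Python analogue (only a sketch: Mojo's historical `Coroutine[T]` wrapping and ownership caveats have no Python equivalent):

```python
import asyncio

async def add_three(a: int, b: int, c: int) -> int:
    return a + b + c

async def call_it() -> int:
    # Calling an async function yields a coroutine object
    # (cf. Mojo's Coroutine[T]); awaiting it runs it to completion.
    task = add_three(1, 2, 3)
    return await task

print(asyncio.run(call_it()))  # prints 6
```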
### Week of 2023-01-09

* The `/` and `*` markers in function signatures are now parsed and their invariants are checked. We do not support keyword arguments yet, though, so they aren't very useful.
* Functions now support a new `@nodebug_inline` decorator. (Historical note: this was later replaced with `@alwaysinline("nodebug")`.)

  Many of the things at the bottom level of the Mojo stack are trivial zero-abstraction wrappers around MLIR things, for example, the `+` operator on Int or the `__bool__` method on Bool itself. These operators need to be force-inlined even at -O0, but they have some additional things that we need to wrestle with:

  1. In no case would a user actually want to step into the `__bool__` method on Bool or the `+` method on Int. This would be terrible debugger QoI unless you're debugging Int itself. We need something like the `__always_inline__, __nodebug__` attributes that clang uses in headers like xmmintrin.h.
  2. Similarly, these "operators" should be treated by users as primitives: they don't want to know about MLIR or internal implementation details of Int.
  3. These trivial zero-abstraction things should be eliminated early in the compiler pipeline so they don't slow down the compiler, bloating out the call graph with trivial leaves. Such things slow down the elaborator, interfere with basic MLIR things like fold(), bloat out the IR, and bloat out generated debug info.
  4. In a parameter context, we want some of these things to get inlined so they can be simplified by the attribute logic and play more nicely with canonical types. This is just a nice-to-have for those of us who have to stare at generated IR.

  The solution to this is a new `@nodebug_inline` decorator. This decorator causes the parser to force-inline the callee instead of generating a call to it. While doing so, it gives the operations the location of the call itself (that's the "nodebug" part) and strips out let decls that were part of the internal implementation details.
This is a super-power-user feature intended for those building the standard library itself, so it is intentionally limited in power and scope: it can only be used on small functions, and it doesn't support regions, by-ref, throws, async, etc.
* Separately, we now support an `@alwaysInline` decorator on functions. This is a general decorator that works on any function, and indicates that the function must be inlined. Unlike `@nodebug_inline`, this kind of inlining is performed later in the compilation pipeline.
* The `__include` hack has been removed now that we have proper import support.
* `__mlir_op` can now get the address of an lvalue: you can use the magic `(((x)))` syntax in `__mlir_op` that forces the `x` expression to be an lvalue, and yields its address. This provides an escape hatch (isolated off in `__mlir_op` land) that allows unsafe access to lvalue addresses.
* We now support `__rlshift__` and `__rtruediv__`.
* 📢 The parser now resolves scoped alias references. This allows us to support things like `SomeType.someAlias`, forward-substituting the value. This unblocks use of aliases in types like `DType`. We'd like to eventually preserve the reference in the AST, but this unblocks library development.
* 📚 Add a `now` function and `Benchmark` struct to enable timing and benchmarking.
* 📚 Move more of the computation in `NDBuffer` from runtime to compile time if possible (e.g. when the dimensions are known at compile time).

### Week of 2023-01-02

* 📚 Added the `print` function, which works on Integers and SIMD values.
* The frontend now has a new diagnostic subsystem used by the `kgen` tool (but not by `kgen-translate` for tests) that supports source ranges on diagnostics.
Before, we'd emit an error like:

```log
x.mojo:13:3: error: invalid call to 'callee': in argument #0, value of type '$F32::F32' cannot be converted to expected type '$int::Int'
  callee(1.0+F32(2.0))
  ^
x.lit:4:1: note: function declared here
fn callee(a: Int):
^
```

Now we produce:

```log
x.mojo:13:3: error: invalid call to 'callee': in argument #0, value of type '$F32::F32' cannot be converted to expected type '$int::Int'
  callee(1.0+F32(2.0))
  ^      ~~~~~~~~~~~~
x.lit:4:1: note: function declared here
fn callee(a: Int):
^
```

* 📢 Parameter results are now supported in a proper way. They are now forward-declared with an alias declaration and then bound in a call with an arrow, e.g.:

  ```mojo
  alias a: __mlir_type.index
  alias b: __mlir_type.index
  idx_result_params[xyz * 2 -> a, b]()
  ```

* Various minor issues with implicit conversions are fixed. For instance, implicit conversions are now supported in parameter binding contexts and `alias` declarations with explicit types.
* Doc strings are allowed on functions and structs, but they are currently discarded by the parser.
* 📚 Add a `print` method!!!
* 📚 Demonstrate a naive matmul in Mojo.
* 📚 Initial work on functions that depend on types (e.g. FPUtils, nan, inf, etc.)
* 📚 Allow one to query hardware properties such as simd\_width, os, etc. via TargetInfo at compile time.

## December 2022

### Week of 2022-12-26

* 📢 You can now call functions in a parameter context! Calling a function in a parameter context will evaluate the function at compile time. The result can then be used as a parameter value. For example,

  ```mojo
  fn fma(x: Int, y: Int, z: Int) -> Int:
      return x + y * z

  fn parameter_call():
      alias nelts = fma(32, 2, 16)
      var x: SIMD[f32, nelts]
  ```

* You can now disable printing of types in an `__mlir_attr` substitution by using a unary `+` expression.
* 📢 `let` declarations are now supported in functions. `let` declarations are local run-time constant values, which are always rvalues.
They complement 'var' decls (which are mutable lvalues) and are the normal thing to use in most cases. They also generate less IR and are always in SSA form when initialized. We will want to extend this to support 'let' decls in structs at some point and support lazily initialized 'let' declarations (using dataflow analysis), but that isn't supported yet.
* 📚 Add the `NDBuffer` struct.
* Happy new year.

### Week of 2022-12-19

* 📚 Start of the Standard library:
  1. Added Integer and SIMD structs to bootstrap the standard library.
  2. Added a very basic buffer data structure.
* We have basic support for parsing parameter results in function calls! Result parameters are an important Mojo metaprogramming feature. They allow functions to return compile-time constants.

  ```mojo
  fn get_preferred_simdwidthof[() -> nelts: Int]():
      return[2]

  fn vectorized_function():
      get_preferred_simdwidthof[() -> nelts]()
      var x: SIMD[f32, nelts]
  ```

* Types can now be used as parameters of `!kgen.mlirtype` in many more cases.
* MLIR operations with zero results don't need to specify `_type: []` anymore.
* We support parsing triple-quoted strings, for writing docstrings for your functions and structs!
* A new `__mlir_type[a,b,c]` syntax for substituting into MLIR types and attributes is available, and the old placeholder approach is removed. This approach has a few advantages beyond what placeholders do:
  1. It's simpler.
  2. It doesn't form the intermediate result with placeholders, which gets rejected by MLIR's semantic analysis, e.g. the complex case couldn't be expressed before.
  3. It provides a simple way to break long attrs/types across multiple lines.
* We now support an `@evaluator` decorator on functions for KGEN evaluators. This enables specifying user-defined interface evaluators when performing search during compilation.
* 📢 `import` syntax is now supported! This handles packaging imported modules into file ops, and enables effective isolation from the other decls.
"import" into the desired context is just aliasing decls, with the proper symbol references handled automatically during IR generation. As a starting point, this doesn't handle any notion of packages (as those haven't been sketched out enough).
* 📢 Reversed binary operators (like `__radd__`) are now looked up and used if the forward version (like `__add__`) doesn't work for some reason.
* 📢 Implicit conversions are now generally available, e.g. in assign statements, variable initializers, etc. There are probably a few more places they should work, but we can start eliminating all the extraneous explicit casts from literals now.
* Happy Holidays

### Week of 2022-12-12

* 📢 Function overloading now works. Call resolution filters the candidate list according to the actual parameters and value arguments specified at the site of the call, diagnosing an error if none of the candidates are viable or if multiple are viable and ambiguous. We also consider implicit conversions in overload lookup:

  ```mojo
  fn foo(x: Int): pass
  fn foo(x: F64): pass

  foo(Int(1))  # resolves to the first overload
  foo(1.0)     # resolves to the second overload
  foo(1)       # error: both candidates viable with 1 implicit conversion!
  ```

* The short-circuiting binary `and` and `or` expressions are now supported.
* Unary operator processing is a lot more robust, now handling the `not` expression and `~x` on Bool.
* 📢 The compiler now generates debug information for use with GDB/LLDB that describes variables and functions.
* The first version of the Mojo Visual Studio Code extension has been released! It supports syntax highlighting for Mojo files.
* The first version of the `Bool` type has landed in the new Mojo standard library!
* 📢 Implicit conversions are now supported in return statements.

### Week of 2022-12-05

* "Discard" patterns are now supported, e.g. `_ = foo()`
* We now support implicit conversions in function call arguments, e.g. converting an `index` value to `Int` automatically.
This eliminates a bunch of casts, e.g. the need to say F32(1.0) everywhere. This is limited for a few reasons that will be improved later:
1. We don't support overloading, so lots of types aren't convertible from all the things they should be, e.g. you can't pass "1" to something that expects F32, because F32 can't be created from index.
2. This doesn't "check to see if we can invoke `__new__`"; it force-applies it on a mismatch, which leads to poor QoI.
3. This doesn't fix things that need radd.

## November 2022

### Week of 2022-11-28

* 📢 We support the `True` and `False` keywords as expressions.
* 📢 A new `alias` declaration is supported which allows defining local parameter values. This will eventually subsume type aliases and other things as it gets built out.
* 📢 We now have end-to-end execution of Mojo files using the `kgen` tool! Functions exported with `@export` can be executed.
* 📢 We have try-except-else and `raise` statements and implicit error propagation! The error semantics are that `def` can raise by default, but `fn` must explicitly declare raising with a `@raises` decorator. Stub out basic `Error` type.
* The `&` sigil for by-ref arguments is now specified after the identifier. Postfix works better for ref and move operators on the expression side because it chains and mentally associates correctly: `thing.method().result^`. We don't do that yet, but align param decl syntax to it so that things won't be odd-looking when we do. In practice this looks like:

  ```mojo
  def mutate_argument(a&: index):
      a = 25
  ```

### Week of 2022-11-21

* 📢 The magic `index` type is gone. Long live `__mlir_type.index`.
* Implement parameter substitution into parametric `__mlir_type` decls. This allows us to define parametric opaque MLIR types with exposed parameters using a new "placeholder" attribute. This allows us to expose the power of the KGEN type parametric system directly into Mojo.
* 📢 Fully-parametric custom types can now be defined and work in Mojo, bringing together a lot of the recent work. We can write the SIMD type directly as a wrapper around the KGEN type, for example:

  ```mojo
  struct SIMD[dt: __mlir_type.`!kgen.dtype`, nelts: __mlir_type.index]:
      var value: __mlir_type.`!pop.simd<#lit, #lit>`[nelts, dt]

      fn __add__(self, rhs: SIMD[dt, nelts]) -> SIMD[dt, nelts]:
          return __mlir_op.`pop.add`(self.value, rhs.value)
  ```

### Week of 2022-11-14

* 📢 Implement a magic `__mlir_type` declaration that can be used to access any MLIR type, e.g. `__mlir_type.f64`.
* 📢 Add an `fn` declaration. These are like `def` declarations, but are more strict in a few ways: they require type annotations on arguments, don't allow implicit variable declarations in their body, and make their arguments rvalues instead of lvalues.
* Implemented Swift-style backtick identifiers, which are useful for code migration where names may collide with new keywords.
* 📢 A new `__include` directive has been added that performs source-level textual includes. This is temporary until we have an `import` model.
* Implement IR generation for arithmetic operators like `+` and `*` in terms of the `__add__` and `__mul__` methods.
* 📢 Added support for `break` and `continue` statements, as well as early returns inside loops and conditionals!
* 📢 Implemented augmented assignment operators, like `+=` and `@=`.
* 📢 Mojo now has access to generating any MLIR operations (without regions) with a new `__mlir_op` magic declaration. We can start to build out the language's builtin types with this:

  ```mojo
  struct Int:
      var value: __mlir_type.index

      fn __add__(self, rhs: Int) -> Int:
          return __mlir_op.`index.add`(self.value, rhs.value)
  ```

  Attributes can be attached to the declaration with subscript `[]` syntax, and an explicit result type can be specified with a special `_type` attribute if it cannot be inferred.
Attributes can be accessed via the `__mlir_attr` magic decl:

```mojo
__mlir_op.`index.cmp`[
    _type: __mlir_type.i1,
    pred: __mlir_attr.`#index`
](lhs, rhs)
```

* Improved diagnostic emission with ranges! Now errors highlight the whole section of code and not just the first character.

### Week of 2022-11-07

* Implemented the `@interface` and `@implements` decorators, which provide access to KGEN generator interfaces. A function marked as an `@interface` has no body, but it can be implemented by multiple other functions.

  ```mojo
  @interface
  def add(lhs: index, rhs: index):

  @implements(add)
  def normal_add(lhs: index, rhs: index) -> index:
      return lhs + rhs

  @implements(add)
  def slow_add(lhs: index, rhs: index) -> index:
      wait(1000)
      return normal_add(lhs, rhs)
  ```

* 📢 Support for static struct methods and initializer syntax has been added. Initializing a struct with `Foo()` calls an implicitly static `__new__` method. This method should be used instead of `__init__` inside structs.

  ```mojo
  struct Foo:
      var value: index

      def __new__() -> Foo:
          var result: Foo
          result.value = Foo.return_a_number()  # static method!
          return result

      @staticmethod
      def return_a_number() -> index:
          return 42
  ```

* 📢 Full by-ref argument support. It's now possible to define in-place operators like `__iadd__` and functions like `swap(x, y)` correctly.
* 📢 Implemented support for field extraction from rvalues, like `x.value` where `x` is not an lvalue (a `var` declaration or by-ref function argument).

## October 2022

### Week of 2022-10-31

* Revised `return` handling so that a return statement with no expression is syntax sugar for `return None`. This enables early exits in functions that implicitly return `None` to be cleaner:

  ```mojo
  def just_return():
      return
  ```

* Added support for parsing more expressions: if-else, bitwise operators, shift operators, comparisons, floor division, remainder, and matmul.
* 📢 The type of the `self` argument can now be omitted on member methods.
### Week of 2022-10-24

* Added parser support for right-associativity and unary ops, like the power operator `a ** b ** c` and negation operator `-a`.
* Added support for `&expr` in Mojo, which allows denoting a by-ref argument in functions. This is required because the `self` type of a struct method is implicitly a pointer.
* Implemented support for parametric function declarations, such as:

  ```mojo
  struct SIMD[dt: DType, width: index]:
      fn struct_method(self: &SIMD[dt, width]):
          pass

  def fancy_add[dt: DType, width: index](
      lhs: SIMD[dt, width], rhs: SIMD[dt, width]) -> index:
      return width
  ```

### Week of 2022-10-17

* Added explicit variable declarations with `var`, for declaring variables both inside functions and structs, with support for type references. Added `index` as a temporary built-in type.

  ```mojo
  def foo(lhs: index, rhs: index) -> index:
      var result: index = lhs + rhs
      return result
  ```

* Implemented support for parsing struct declarations and references to type declarations in functions! In `def`, the type can be omitted to signal an object type.

  ```mojo
  struct Foo:
      var member: index

  def bar(x: Foo, obj) -> index:
      return x.member
  ```

* Implemented parser support for `if` statements and `while` loops!

  ```mojo
  def if_stmt(c: index, a: index, b: index) -> index:
      var result: index = 0
      if c:
          result = a
      else:
          result = b
      return result

  def while_stmt(init: index):
      while init > 1:
          init = init - 1
  ```

* Significantly improved error emission and handling, allowing the parser to emit multiple errors while parsing a file.

### Week of 2022-10-10

* Added support for parsing integer, float, and string literals.
* Implemented parser support for function input parameters and results. You can now write parametric functions like:

  ```mojo
  def foo[param: Int](arg: Int) -> Int:
      result = param + arg
      return result
  ```

### Week of 2022-10-03

* Added some basic parser scaffolding and initial parser productions, including trivial expressions and assignment parser productions.
* Implemented basic scope handling and function IR generation, with support for forward declarations. Simple functions like:

  ```mojo
  def foo(x: Int):
  ```

  now parse! But all argument types are hard-coded to the MLIR `index` type.

* Added IR emission for simple arithmetic expressions on builtin types, like `x + y`.

## September 2022

### Week of 2022-09-26

* Mojo's first patch to add a lexer was Sep 27, 2022.
* Settled on `[]` for Mojo generics instead of `<>`. Square brackets are consistent with Python generics and don't have the less-than ambiguity other languages have.

---

## mojo build

Builds an executable from a Mojo file.

## Synopsis

```
mojo build [options] <path>
```

## Description

Compiles the Mojo file at the given path into an executable. By default, the executable is saved to the current directory and named the same as the input file, but without a file extension.

Beware that any Python libraries used in your Mojo project are not included in the executable binary, so they must be provided by the environment where you run the executable.

## Options

### Output options

#### `-o <path>`

Sets the path and filename for the executable output. By default, it outputs the executable to the current directory, with the same name and no extension.

#### `--emit <value>`

The type of output file to generate.

* `exe` (default): emit an executable binary file.
* `shared-lib`: emit a shared (dynamic) library.
* `object`: (EXPERIMENTAL) emit a single object file.
* `llvm`: emit unoptimized LLVM IR.
* `llvm-bitcode`: emit bitcode of unoptimized LLVM IR.
* `asm`: emit target assembly.

### Compilation options

#### `--optimization-level <level>`, `-O`, `--no-optimization (LEVEL=0)`

Sets the level of optimization to use at compilation. The value must be a number between 0 and 3. The default is 3.

#### `-I <path>`

Appends the given path to the list of directories to search for imported Mojo files.

#### `-D <key=value>`

Defines a named value that can be used from within the Mojo source file being executed.
For example, `-Dfoo=42` defines a name `foo` that, when queried with the `sys.param_env` module from within the Mojo program, would yield the compile-time value `42`.

#### `--debug-level <level>`, `-g (LEVEL=full)`, `-g0 (LEVEL=none)`, `-g1 (LEVEL=line-tables)`, `-g2 (LEVEL=full)`

Sets the level of debug info to use at compilation. The value must be one of: `none` (the default value), `line-tables`, or `full`. Please note that there are issues when generating debug info for some Mojo programs that have yet to be addressed.

#### `--num-threads <n>`, `-j`

Sets the maximum number of threads to use for compilation. The default is 0 (use all available threads).

#### `--elaboration-error-include-prelude`

Show elaboration errors with locations in Mojo startup modules (the prelude).

### Target options

#### `--target-triple <triple>`

Sets the compilation target triple. Defaults to the host target.

#### `--target-cpu <cpu>`

Sets the compilation target CPU. Defaults to the host CPU.

#### `--target-features <features>`

Sets the compilation target CPU features. Defaults to the host features.

#### `--march <arch>`

Sets the architecture for which to generate code.

#### `--mcpu <cpu>`

Sets the CPU for which to generate code.

#### `--mtune <cpu>`

Sets the CPU for which to tune code.

#### `--target-accelerator <accelerator>`

Sets the GPU or accelerator architecture for heterogeneous computing (e.g., sm\_90 for NVIDIA H100, gfx942 for AMD MI300).

#### `--print-effective-target`

Print the effective target configuration after absorbing all command-line flags and exit.

#### `--print-supported-targets`

Print all available target names and exit.

#### `--print-supported-cpus`

Print valid CPU names for the specified target and exit. Requires `--target-triple`.

#### `--print-supported-accelerators`

Print all supported GPU and accelerator architectures and exit.

### Compilation diagnostic options

Controls how the Mojo compiler outputs diagnostics related to compiling and running Mojo source code.
#### `--diagnose-missing-doc-strings` Emits diagnostics for missing or partial doc strings. #### `--max-notes-per-diagnostic ` When the Mojo compiler emits diagnostics, it sometimes also prints notes with additional information. This option sets an upper threshold on the number of notes that can be printed with a diagnostic. If not specified, the default maximum is 10. #### `--disable-builtins` Do not use builtins when creating a package. #### `--disable-warnings` Do not print warning messages. #### `--experimental-fixit` Automatically apply fix-its to the code, and rerun the command after the fix-its are applied. WARNING: this feature is highly experimental and may result in irreversible data loss. #### `--experimental-export-fixit ` Export fix-its to a YAML file in clang-tidy format instead of applying them directly. The file can be applied using 'clang-apply-replacements'. WARNING: this feature is highly experimental. #### `--Werror` Treat warnings as errors. #### `--Wno-error` Do not treat warnings as errors. #### `--warn-on-unstable-apis` Warn when using unstable APIs from the standard library. ### Linker options #### `-Xlinker ` Pass ARG to the linker. ### Experimental compilation options #### `--sanitize ` Turns on runtime checks. The following values are supported: `address` (detects memory issues), and `thread` (detects multi-threading issues). #### `--shared-libasan` Dynamically link the address sanitizer runtime. Requires address sanitization turned on with `--sanitize` option. #### `--debug-info-language ` Sets the language to emit as part of the debug info. The supported languages are: `Mojo`, and `C`. `Mojo` is the default. `C` is useful to enable rudimentary debugging and binary introspection in tools that don't understand Mojo, but is not required for `mojo debug`. ### Common options #### `--diagnostic-format ` The format in which diagnostics and error messages are printed. Must be one of "text" or "json" ("text" is the default).
#### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo debug Launches the Mojo debugger using the command-line interface or an external editor. ## Synopsis ``` mojo debug [debug-options] ``` ## Description This command, which uses the LLDB debugger or cuda-gdb under the hood, offers four basic debug session modes: * Build and debug a Mojo file. ``` mojo debug [options] [runtime args] ``` Builds the Mojo file at the given path and launches it under the debugger. Options, which come before the Mojo file, can include any compilation options expected by `mojo run`, as well as regular debugging commands. Runtime args, which come after the Mojo file, are passed directly to the debuggee upon launch. By default, this mode uses `-O0` and `--debug-level=full` as compilation options. * Debug a precompiled program. ``` mojo debug [options] [runtime args] ``` Launches the program at the given path in the debugger. Options, which come before the program path, cannot include compilation commands. Runtime args, which come after the program path, are passed directly to the debuggee upon launch. * Attach to a running process. ``` mojo debug [options] [--pid | --process-name ] ``` Attaches to the process specified by pid or name, which can be the full path of the process' executable. Options other than the process identifier cannot include compilation options. * Start the debugger command-line interface. ``` mojo debug [options] ``` Launches the debugger CLI with support for debugging Mojo programs. This command only supports LLDB or cuda-gdb options via the `--X` option. You can also select one of two interfaces for the debug session: * CLI: By default, all debug session modes are launched using the regular debugger command-line interface. * VS Code Debug Server: If you add the `--vscode` option, the debug session is launched in VS Code via the Mojo extension. VS Code must be running and the Mojo extension must be enabled.
Besides that, the environment variables and the current working directory of this invocation are preserved when launching programs in the debugger on VS Code. Finally, it is worth mentioning that this debugger can debug programs written in other standard native languages like Rust, C, and C++, as it is based on LLDB or cuda-gdb. Debugger capabilities: * LLDB: this is the default debugger and has great support for CPU Mojo code, but has no support at all for Mojo GPU code. * cuda-gdb: this is invoked via the `--cuda-gdb` option and has minimal support for CPU Mojo code, but it does support GPU Mojo code. ## Options ### Attach options #### `--pid ` Instructs the debugger to attach to the process with the given PID. #### `--process-name ` Instructs the debugger to attach to the process with the given name or path. ### cuda-gdb options #### `--cuda-gdb` Uses cuda-gdb instead of LLDB for debugging. In this mode, it's possible to step into GPU code, but the CPU debugging experience is degraded. #### `--cuda-gdb-path ` Uses the given CUDA\_GDB\_PATH instead of looking for cuda-gdb in the PATH environment variable. #### `--break-on-launch` Sets the breakOnLaunch option for cuda-gdb. This makes the debugger break on the first instruction of every launched kernel. ### Compilation options #### `--optimization-level `, `-O`, `--no-optimization (LEVEL=0)` Sets the level of optimization to use at compilation. The value must be a number between 0 and 3. The default is 3. #### `-I ` Appends the given path to the list of directories to search for imported Mojo files. #### `-D ` Defines a named value that can be used from within the Mojo source file being executed. For example, `-Dfoo=42` defines a name `foo` that, when queried with the `sys.param_env` module from within the Mojo program, would yield the compile-time value `42`.
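To sketch that `-D` flow end to end (the file name `main.mojo` and the fallback branch are illustrative), the `sys.param_env` module exposes compile-time queries such as `is_defined` and `env_get_int`:

```mojo
from sys.param_env import env_get_int, is_defined

fn main():
    # Compiled with, e.g.: mojo debug -Dfoo=42 main.mojo
    @parameter
    if is_defined["foo"]():
        # The value is baked in at compile time as an alias, not read at runtime.
        alias foo = env_get_int["foo"]()
        print("foo =", foo)
    else:
        print("foo is not defined")
```

Compiled with `-Dfoo=42`, the first branch is elaborated and the program prints the compile-time value; without the flag, only the `else` branch is compiled in.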
#### `--debug-level `, `-g (LEVEL=full)`, `-g0 (LEVEL=none)`, `-g1 (LEVEL=line-tables)`, `-g2 (LEVEL=full)` Sets the level of debug info to use at compilation. The value must be one of: `none` (the default value), `line-tables`, or `full`. Please note that there are issues when generating debug info for some Mojo programs that have yet to be addressed. #### `--num-threads `, `-j` Sets the maximum number of threads to use for compilation. The default is 0 (use all available threads). #### `--elaboration-error-include-prelude` Show elaboration error with locations in mojo startup modules (prelude). ### Target options #### `--target-triple ` Sets the compilation target triple. Defaults to the host target. #### `--target-cpu ` Sets the compilation target CPU. Defaults to the host CPU. #### `--target-features ` Sets the compilation target CPU features. Defaults to the host features. #### `--march ` Sets the architecture for which to generate code. #### `--mcpu ` Sets the CPU for which to generate code. #### `--mtune ` Sets the CPU for which to tune code. #### `--target-accelerator ` Sets the GPU or accelerator architecture for heterogeneous computing (e.g., sm\_90 for NVIDIA H100, gfx942 for AMD MI300). #### `--print-effective-target` Print the effective target configuration after absorbing all command-line flags and exit. #### `--print-supported-targets` Print all available target names and exit. #### `--print-supported-cpus` Print valid CPU names for the specified target and exit. Requires --target-triple. #### `--print-supported-accelerators` Print all supported GPU and accelerator architectures and exit. ### Compilation diagnostic options Controls how the Mojo compiler outputs diagnostics related to compiling and running Mojo source code. #### `--diagnose-missing-doc-strings` Emits diagnostics for missing or partial doc strings. #### `--max-notes-per-diagnostic ` When the Mojo compiler emits diagnostics, it sometimes also prints notes with additional information. 
This option sets an upper threshold on the number of notes that can be printed with a diagnostic. If not specified, the default maximum is 10. #### `--disable-builtins` Do not use builtins when creating a package. #### `--disable-warnings` Do not print warning messages. #### `--experimental-fixit` Automatically apply fix-its to the code, and rerun the command after the fix-its are applied. WARNING: this feature is highly experimental and may result in irreversible data loss. #### `--experimental-export-fixit ` Export fix-its to a YAML file in clang-tidy format instead of applying them directly. The file can be applied using 'clang-apply-replacements'. WARNING: this feature is highly experimental. #### `--Werror` Treat warnings as errors. #### `--Wno-error` Do not treat warnings as errors. #### `--warn-on-unstable-apis` Warn when using unstable APIs from the standard library. ### Debugger options #### `--X ` Passes ARG as an argument to the debugger when the debug session is launched using the debugger command-line interface. This option can be specified multiple times. It is ignored when using the RPC mode. ### Debug server options #### `--vscode` Launches the debug session on VS Code via the Mojo extension. #### `--rpc` Alias for --vscode. #### `--terminal ` The type of terminal to use when starting a launch debug session. * `console` (default): the debuggee will be launched in the default environment for the editor. If using VS Code, this will be the Debug Console. * `dedicated`: the debuggee will be launched in a dedicated terminal within the editor. #### `--port ` Uses the given PORT to communicate with the RPC debug server. Defaults to trying all ports from 12355 to 12364 inclusive. #### `--stop-on-entry` Automatically stop after launch. #### `--init-command ` Initialization command executed upon debugger startup. Can be specified multiple times. ### Linker options #### `-Xlinker ` Pass ARG to the linker.
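Combining the session modes and options above, a few illustrative invocations (the file paths and the PID are placeholders):

```
# Build and debug in VS Code (defaults to -O0 and --debug-level=full),
# stopping at the first instruction
mojo debug --vscode --stop-on-entry main.mojo

# Attach to a running process by PID using the debugger CLI
mojo debug --pid 12345

# Step into GPU kernels with cuda-gdb, breaking on every kernel launch
mojo debug --cuda-gdb --break-on-launch main.mojo
```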
### Experimental compilation options #### `--sanitize ` Turns on runtime checks. The following values are supported: `address` (detects memory issues), and `thread` (detects multi-threading issues). #### `--shared-libasan` Dynamically link the address sanitizer runtime. Requires address sanitization turned on with `--sanitize` option. #### `--debug-info-language ` Sets the language to emit as part of the debug info. The supported languages are: `Mojo`, and `C`. `Mojo` is the default. `C` is useful to enable rudimentary debugging and binary introspection in tools that don't understand Mojo, but is not required for `mojo debug`. ### Common options #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo demangle Demangles the given name. ## Synopsis ``` mojo demangle [options] ``` ## Description If the given name is a mangled Mojo symbol name, prints the demangled name. If no name is provided, one is read from standard input. ## Options ### Common options #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo doc Compiles docstrings from a Mojo file. ## Synopsis ``` mojo doc [options] ``` ## Description This is an early version of a documentation tool that generates an API reference from Mojo code comments. Currently, it generates a structured output of all docstrings into a JSON file, and it does not generate HTML. This output format is subject to change. The input may be a single file or a directory. If you specify a directory, it will generate a single JSON output with documentation for all modules found in that path, recursively. ## Options ### Output options #### `-o ` Sets the path and filename for the JSON output. If not provided, output is written to stdout. ### Compilation options #### `-I ` Appends the given path to the list of directories that Mojo will search for any package/module dependencies. 
That is, if the file you pass to `mojo doc` imports any packages that do not reside in the local path and are not part of the Mojo standard library, use this to specify the path where Mojo can find those packages. ### Validation options The following validation options help ensure that your docstrings use valid structure and meet other style criteria. By default, warnings are emitted only if the docstrings contain errors that prevent translation to the output format. (More options coming later.) #### `--diagnose-missing-doc-strings` Emits diagnostic warnings for missing or partial doc strings. #### `--docs-base-path ` Sets the path prefix for generated documentation links. ### Compilation diagnostic options Controls how the Mojo compiler outputs diagnostics related to compiling and running Mojo source code. #### `--max-notes-per-diagnostic ` When the Mojo compiler emits diagnostics, it sometimes also prints notes with additional information. This option sets an upper threshold on the number of notes that can be printed with a diagnostic. If not specified, the default maximum is 10. #### `--Werror` Treat warnings as errors. #### `--Wno-error` Do not treat warnings as errors (overrides -Werror). ### Common options #### `--diagnostic-format ` The format in which diagnostics and error messages are printed. Must be one of "text" or "json" ("text" is the default). #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo format Formats Mojo source files. ## Synopsis ``` mojo format [options] ``` ## Description Formats the given set of Mojo sources using a Mojo-specific lint tool. ## Options ### Format options #### `--line-length `, `-l ` Sets the max character line length. Default is 80. ### Diagnostic options #### `--quiet`, `-q` Disables non-error messages. ### Common options #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. 
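The docstring-related options above (`mojo doc` and `--diagnose-missing-doc-strings`) operate on structured docstrings; here is a hypothetical function showing the shape they parse:

```mojo
fn add(x: Int, y: Int) -> Int:
    """Adds two integers.

    Args:
        x: The first addend.
        y: The second addend.

    Returns:
        The sum of the two arguments.
    """
    return x + y
```

Running `mojo doc` on a file containing this function emits a JSON record capturing the summary line and the `Args`/`Returns` sections as structured fields (the exact schema is subject to change, as noted above).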
--- ## mojo The Mojo🔥 command line interface. ## Synopsis ``` mojo mojo [run-options] mojo [options] mojo ``` ## Description The `mojo` CLI provides all the tools you need for Mojo development, such as commands to run, compile, and package Mojo code. All commands are listed below, and you can learn more about each one by adding the `--help` option to the command (for example, `mojo package --help`). However, you may omit the `run` and `repl` commands. That is, you can run a Mojo file by simply passing the filename to `mojo`: ``` mojo hello.mojo ``` And you can start a REPL session by running `mojo` with no commands. You can check your current version with `mojo --version`. For version information, see the [Mojo changelog](/mojo/changelog). ## Commands [`run`](run.md) — Builds and executes a Mojo file. [`build`](build.md) — Builds an executable from a Mojo file. [`repl`](repl.md) — Launches the Mojo REPL. [`debug`](debug.md) — Launches the Mojo debugger using the command-line interface or an external editor. [`package`](package.md) — Compiles a Mojo package. [`format`](format.md) — Formats Mojo source files. [`doc`](doc.md) — Compiles docstrings from a Mojo file. [`demangle`](demangle.md) — Demangles the given name. ## Options ### Diagnostic options #### `--version`, `-v` Prints the Mojo version and exits. ### Common options #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo package Compiles a Mojo package. ## Synopsis ``` mojo package [options] ``` ## Description Compiles a directory of Mojo source files into a binary package suitable to share and import into other Mojo programs and modules. A Mojo package is portable across different systems because it includes only non-elaborated code (it's not an arch-specific package). The code becomes an arch-specific executable only after it's imported into a Mojo program that's then compiled with `mojo build`.
To create a Mojo package, first add an `__init__.mojo` file to your package directory. Then pass that directory name to this command, and specify the output path and filename with `-o`. For more information, see [Mojo modules and packages](/mojo/manual/packages). ## Options ### Output options #### `-o ` Sets the path and filename for the output package. The filename must end with either `.mojopkg` or `.📦`. The filename given here defines the package name you can then use to import the code (minus the file extension). If you don't specify this option, a `.mojopkg` file is generated in the current working directory, with a name based on the name of the input directory. ### Compilation options #### `-I ` Appends the given path to the list of directories to search for imported Mojo files. #### `-kgenModule` Export as a KGEN module. ### Compilation diagnostic options Controls how the Mojo compiler outputs diagnostics related to compiling and running Mojo source code. #### `--diagnose-missing-doc-strings` Emits diagnostics for missing or partial doc strings. #### `--max-notes-per-diagnostic ` When the Mojo compiler emits diagnostics, it sometimes also prints notes with additional information. This option sets an upper threshold on the number of notes that can be printed with a diagnostic. If not specified, the default maximum is 10. #### `--disable-builtins` Do not use builtins when creating a package. #### `--disable-warnings` Do not print warning messages. #### `--experimental-fixit` Automatically apply fix-its to the code, and rerun the command after the fix-its are applied. WARNING: this feature is highly experimental and may result in irreversible data loss. #### `--experimental-export-fixit ` Export fix-its to a YAML file in clang-tidy format instead of applying them directly. The file can be applied using 'clang-apply-replacements'. WARNING: this feature is highly experimental. #### `--Werror` Treat warnings as errors.
#### `--Wno-error` Do not treat warnings as errors. #### `--warn-on-unstable-apis` Warn when using unstable APIs from the standard library. ### Common options #### `--diagnostic-format ` The format in which diagnostics and error messages are printed. Must be one of "text" or "json" ("text" is the default). #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo repl Launches the Mojo REPL. ## Synopsis ``` mojo repl [lldb-options] ``` ## Description Launches a Mojo read-evaluate-print loop (REPL) environment, which provides interactive development in the terminal. You can also start the REPL by simply running `mojo`. Any number of options and arguments may be specified on the command line. These are then forwarded to the underlying lldb tool, which runs the REPL. ## Options ### Common options #### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## mojo run Builds and executes a Mojo file. ## Synopsis ``` mojo run [options] [path-arguments...] ``` ## Description Compiles the Mojo file at the given path and immediately executes it. Another way to execute this command is to simply pass a file to `mojo`. For example: ``` mojo hello.mojo ``` Options for this command itself, such as the ones listed below, must appear before the input file `path` argument. Any command line arguments that appear after the Mojo source file `path` are interpreted as arguments for that Mojo program. ## Options ### Compilation options #### `--optimization-level `, `-O`, `--no-optimization (LEVEL=0)` Sets the level of optimization to use at compilation. The value must be a number between 0 and 3. The default is 3. #### `-I ` Appends the given path to the list of directories to search for imported Mojo files. #### `-D ` Defines a named value that can be used from within the Mojo source file being executed. 
For example, `-Dfoo=42` defines a name `foo` that, when queried with the `sys.param_env` module from within the Mojo program, would yield the compile-time value `42`. #### `--debug-level `, `-g (LEVEL=full)`, `-g0 (LEVEL=none)`, `-g1 (LEVEL=line-tables)`, `-g2 (LEVEL=full)` Sets the level of debug info to use at compilation. The value must be one of: `none` (the default value), `line-tables`, or `full`. Please note that there are issues when generating debug info for some Mojo programs that have yet to be addressed. #### `--num-threads `, `-j` Sets the maximum number of threads to use for compilation. The default is 0 (use all available threads). #### `--elaboration-error-include-prelude` Show elaboration error with locations in mojo startup modules (prelude). ### Target options #### `--target-triple ` Sets the compilation target triple. Defaults to the host target. #### `--target-cpu ` Sets the compilation target CPU. Defaults to the host CPU. #### `--target-features ` Sets the compilation target CPU features. Defaults to the host features. #### `--march ` Sets the architecture for which to generate code. #### `--mcpu ` Sets the CPU for which to generate code. #### `--mtune ` Sets the CPU for which to tune code. #### `--target-accelerator ` Sets the GPU or accelerator architecture for heterogeneous computing (e.g., sm\_90 for NVIDIA H100, gfx942 for AMD MI300). #### `--print-effective-target` Print the effective target configuration after absorbing all command-line flags and exit. #### `--print-supported-targets` Print all available target names and exit. #### `--print-supported-cpus` Print valid CPU names for the specified target and exit. Requires --target-triple. #### `--print-supported-accelerators` Print all supported GPU and accelerator architectures and exit. ### Compilation diagnostic options Controls how the Mojo compiler outputs diagnostics related to compiling and running Mojo source code. 
#### `--diagnose-missing-doc-strings` Emits diagnostics for missing or partial doc strings. #### `--max-notes-per-diagnostic ` When the Mojo compiler emits diagnostics, it sometimes also prints notes with additional information. This option sets an upper threshold on the number of notes that can be printed with a diagnostic. If not specified, the default maximum is 10. #### `--disable-builtins` Do not use builtins when creating a package. #### `--disable-warnings` Do not print warning messages. #### `--experimental-fixit` Automatically apply fix-its to the code, and rerun the command after the fix-its are applied. WARNING: this feature is highly experimental and may result in irreversible data loss. #### `--experimental-export-fixit ` Export fix-its to a YAML file in clang-tidy format instead of applying them directly. The file can be applied using 'clang-apply-replacements'. WARNING: this feature is highly experimental. #### `--Werror` Treat warnings as errors. #### `--Wno-error` Do not treat warnings as errors. #### `--warn-on-unstable-apis` Warn when using unstable APIs from the standard library. ### Linker options #### `-Xlinker ` Pass ARG to the linker. ### Experimental compilation options #### `--sanitize ` Turns on runtime checks. The following values are supported: `address` (detects memory issues), and `thread` (detects multi-threading issues). #### `--shared-libasan` Dynamically link the address sanitizer runtime. Requires address sanitization turned on with `--sanitize` option. #### `--debug-info-language ` Sets the language to emit as part of the debug info. The supported languages are: `Mojo`, and `C`. `Mojo` is the default. `C` is useful to enable rudimentary debugging and binary introspection in tools that don't understand Mojo, but is not required for `mojo debug`. ### Common options #### `--diagnostic-format ` The format in which diagnostics and error messages are printed. Must be one of "text" or "json" ("text" is the default).
#### `--help`, `-h` Displays help information. #### `--help-hidden` Displays help for hidden options. --- ## Mojo FAQ We tried to anticipate your questions about Mojo on this page. If this page doesn't answer all your questions, also check out our [community channels](https://www.modular.com/community). ## Motivation ### Why did you build Mojo? We built Mojo to solve an internal challenge when building the [Modular Platform](https://www.modular.com)—programming across the entire stack was too complicated. We wanted a flexible and scalable programming model that could target CPUs, GPUs, AI accelerators, and other heterogeneous systems that are pervasive in the AI field. This meant a programming language with powerful compile-time metaprogramming, integration of adaptive compilation techniques, caching throughout the compilation flow, and other features that are not supported by existing languages. As a result, we're extremely committed to Mojo's long term success and are investing heavily in it. Our overall mission is to unify AI software and we can’t do that without a unified language that can scale across the whole AI infrastructure stack. Our current focus is to unify CPU and GPU programming with blazing-fast execution for the Modular Platform. That said, the north star is for Mojo to support the whole gamut of general-purpose programming over time. For more detail, see the [Mojo vision](/mojo/vision). ### Why is it called Mojo? Mojo means "a magical charm" or "magical powers." We thought this was a fitting name for a language that brings magical powers to Python, including unlocking an innovative programming model for accelerators and other heterogeneous systems pervasive in AI today. ### Why does Mojo have the 🔥 file extension? We paired Mojo with fire emoji 🔥 as a fun visual way to impart onto users that Mojo empowers them to get their Mojo on—to develop faster and more efficiently than ever before. 
We also believe that the world can handle a unicode extension at this point, but you can also just use the `.mojo` extension. :) ### What problems does Mojo solve that no other language can? Mojo combines the usability of Python with the systems programming features it’s missing. We are guided more by pragmatism than novelty, but Mojo’s use of [MLIR](https://mlir.llvm.org/) allows it to scale to new exotic hardware types and domains in a way that other languages haven’t demonstrated. It also has caching and distributed compilation built into its core. We also believe Mojo has a good chance of unifying hybrid packages in the broader Python community. ### What kind of developers will benefit the most from Mojo? Mojo’s initial focus is to bring programmability back to AI, enabling AI developers to customize and get the most out of their hardware. As such, Mojo will primarily benefit researchers and other engineers looking to write high-performance AI operations. Over time, Mojo will become much more interesting to the general Python community as it grows to be a superset of Python. We hope this will help lift the vast Python library ecosystem and empower more traditional systems developers that use C, C++, Rust, etc. ### Why build upon Python? Effectively, all AI research and model development happens in Python today, and there’s a good reason for this! Python is a powerful high-level language with clean, simple syntax and a massive ecosystem of libraries. At Modular, one of our core principles is meeting customers where they are—our goal is not to further fragment the AI landscape but to unify and simplify AI development workflows. Our focus is to innovate in the programmability for AI workloads on heterogeneous hardware, and we don't see any need to innovate in language *syntax* or *community*. So we chose to embrace the Python ecosystem because it's so widely used, it's loved by the AI ecosystem, and because we believe it is a really nice language. 
### Why not enhance CPython (the major Python implementation) instead? For a variety of reasons, Python isn't suitable for systems programming. Python has amazing strengths as a glue layer—it offers low-level bindings that allow developers to build libraries in C, C++ and many other languages that have better performance characteristics. This enables things like NumPy and PyTorch, and a vast number of other libraries in the AI ecosystem, but it comes with a cost. Building these hybrid libraries is very complicated. It requires a deep understanding of CPython and strong C/C++ (or other) programming abilities (undermining one of the original goals of using Python in the first place). These hybrid-language libraries also create problems for the library users, because debuggers generally can't step between Python and C/C++ code. We’re thrilled to see a big push to improve the performance of [CPython](https://en.wikipedia.org/wiki/CPython), but our goals for Mojo (such as to deploy onto GPUs and other accelerators) requires a fundamentally different architecture and compiler approach. That said, CPython is still a critical part of our compatibility approach and powers [Mojo's Python interoperability](/mojo/manual/python). ### Why not enhance another Python implementation (like Codon, PyPy, etc)? Codon and PyPy aim to improve performance compared to CPython, but Mojo’s goals are much deeper than this. Our objective isn’t just to create "a faster Python," but to enable a whole new layer of systems programming that includes direct access to accelerated hardware. Many hardware accelerators support very limited dynamic features, or do so with terrible performance. Furthermore, systems programmers don't seek only "performance," but also demand a lot of predictability and control over how a computation happens, so in some cases we cannot accept dynamic features at all. 
Furthermore, solving big challenges for the computing industry is hard and requires a fundamental rethinking of the compiler and runtime infrastructure. This drove us to build an entirely new approach and we’re willing to put in the time required to do it properly, rather than tweaking an existing system that would only solve a small part of the problem. For more detail, see our blog post about [How Modular is Democratizing AI Compute](https://www.modular.com/blog/how-is-modular-democratizing-ai-compute). ### Why not make Julia better? We think [Julia](https://julialang.org/) is a great language and it has a wonderful community, but Mojo is completely different. While Julia and Mojo might share some goals and look similar as an easy-to-use and high-performance alternative to Python, we’re taking a completely different approach to building Mojo. Notably, Mojo is Python-first and doesn't require existing Python developers to learn a new syntax. Mojo also has a bunch of technical advancements compared to Julia, simply because Mojo is newer and we’ve been able to learn from Julia (and from Swift, Rust, C++ and many others that came before us). For example, Mojo takes a different approach to memory ownership and memory management, it scales down to smaller envelopes, and is designed with AI and MLIR-first principles (though Mojo is not only for AI). That said, we also believe there’s plenty of room for many languages and this isn’t an OR proposition. If you use and love Julia, that's great! We’d love for you to try Mojo and if you find it useful, then that's great too. ## Functionality ### Where can I learn more about Mojo’s features? The best place to start is the [Mojo Manual](/mojo/manual). And if you want to see what features are coming in the future, take a look at [the roadmap](/mojo/roadmap). ### What are the benefits of building Mojo with MLIR? 
When we realized that no existing language could solve the challenges in AI compute, we embarked on a first-principles rethinking of how a programming language should be designed and implemented to solve our problems. Because we require high-performance support for a wide variety of accelerators, traditional compiler technologies like LLVM and GCC were not suitable (and any languages and tools based on them would not suffice). Although they support a wide range of CPUs and some commonly used GPUs, these compiler technologies were designed decades ago and are unable to fully support modern chip architectures. Nowadays, the standard technology for specialized machine learning accelerators is MLIR. [MLIR](https://mlir.llvm.org/) provides a flexible infrastructure for building compilers. It’s based upon layers of intermediate representations (IRs) that allow for progressive lowering of any code for any hardware, and it has been widely adopted by the hardware accelerator industry since [its first release](https://blog.google/technology/ai/mlir-accelerating-ai-open-source-infrastructure/). Its greatest strength is its ability to build *domain specific* compilers, particularly for weird domains that aren’t traditional CPUs and GPUs, such as AI ASICS, [quantum computing systems](https://github.com/PennyLaneAI/catalyst), FPGAs, and [custom silicon](https://circt.llvm.org/). Although you can use MLIR to create a flexible and powerful compiler for any programming language, Mojo is the world’s first language to be built from the ground up with MLIR design principles. This means that Mojo not only offers high-performance compilation for heterogeneous hardware, but it also provides direct programming support for the MLIR intermediate representations, which currently isn't possible with any other language. ### Is Mojo only for AI or can it be used for other stuff? Mojo's initial focus is to solve AI programmability challenges. 
However, our goal is to grow Mojo into a general-purpose programming language. We use Mojo at Modular to develop AI algorithms and [GPU kernels](/max/tutorials/custom-ops-matmul), but you can use it for other things like HPC, data transformations, writing pre/post processing operations, and much more. ### Is Mojo interpreted or compiled? Mojo is a compiled language. [`mojo build`](/mojo/cli/build) performs ahead-of-time (AOT) compilation to save an executable program. [`mojo run`](/mojo/cli/run) performs just-in-time (JIT) compilation to execute a Mojo source file without saving the compiled result. ### How does Mojo compare to Triton Lang? [Triton Lang](https://triton-lang.org/main/index.html) is a specialized programming model for one type of accelerator, whereas Mojo is a more general language that will support more architectures over time and includes a debugger, a full tool suite, etc. For more about our thoughts on embedded domain-specific languages (EDSLs) like Triton, read [Democratizing AI Compute, Part 7](https://www.modular.com/blog/democratizing-ai-compute-part-7-what-about-triton-and-python-edsls). ### Does Mojo support distributed execution? Not alone. Mojo is one component of the Modular Platform, which makes it easier for you to author highly performant, portable CPU and GPU graph operations, but you’ll also need a runtime (or "OS") that supports graph-level transformations and heterogeneous compute, which is provided by [MAX](/max/intro#components). ### How do I convert Python programs or libraries to Mojo? You can migrate parts of a Python project to Mojo by building Mojo bindings for Python. See the documentation about how to [call Mojo from Python](/mojo/manual/python/mojo-from-python). ### What about interoperability with other languages like C/C++? Yes, we want to enable developers to port code from languages other than Python to Mojo as well. 
We expect that due to Mojo’s similarity to the C/C++ type systems, migrating code from C/C++ should work well and it’s in [our roadmap](/mojo/roadmap#cc-interop). ### How does Mojo support hardware lowering? Mojo leverages LLVM-level dialects for the hardware targets it supports, and it uses other MLIR-based code-generation backends where applicable. This also means that Mojo is easily extensible to any hardware backend. ### Who writes the software to add more hardware support for Mojo? Mojo provides all the language functionality necessary for anyone to extend hardware support. As such, we expect hardware vendors and community members will contribute additional hardware support in the future. ## Performance ### Are there any AI-related performance benchmarks for Mojo? It’s important to remember that Mojo is designed to be a general-purpose programming language, and any AI-related benchmarks will rely heavily upon other framework components. For example, our in-house CPU and GPU graph operations that power the Modular Platform are all written in Mojo and you can learn more about performance in our [matrix multiplication blog post](https://www.modular.com/blog/the-worlds-fastest-unified-matrix-multiplication). For details about our end-to-end model performance, read about [how we measure performance at Modular](https://www.modular.com/blog/max-gpu-state-of-the-art-throughput-on-a-new-genai-platform). ## Mojo SDK ### How can I get the Mojo SDK? You can get Mojo and all the developer tools by installing `mojo` with any Python or Conda package manager. For details, see the [Mojo installation guide](/mojo/manual/install). ### Is the Mojo Playground still available? No. We shut it down with the v25.6 release. Here's the story: When we announced Mojo in May 2023, Mojo wasn't available in an SDK; it was available only in a web-hosted JupyterLab environment. 
After we made Mojo available for local development, we shut down the JupyterLab environment and launched a new Mojo Playground for people to try Mojo on the web. But ever since we made the Mojo SDK available for Linux and Mac, Mojo Playground usage steadily declined. The trickle of users we get now no longer justifies the maintenance and hosting costs. See how to [install Mojo](/mojo/manual/install). ### What are the license terms for the SDK? Please read the [Terms of use](https://www.modular.com/legal/terms). ### What operating systems are supported? Mac and Linux. For details, see the [Mojo system requirements](/mojo/manual/install#system-requirements). ### Is there IDE Integration? Yes, we've published an official Mojo language extension for [Visual Studio Code](https://code.visualstudio.com/) and other editors that support VS Code extensions (such as [Cursor](https://cursor.com/home)). The extension supports various features including syntax highlighting, code completion, formatting, hover, etc. It works seamlessly with remote-ssh and dev containers to enable remote development in Mojo. You can obtain the extension from either the [Visual Studio Code Marketplace](https://marketplace.visualstudio.com/items?itemName=modular-mojotools.vscode-mojo) or the [Open VSX Registry](https://open-vsx.org/extension/modular-mojotools/vscode-mojo). See [Add the VS Code extension](/mojo/manual/install/#add-the-vs-code-extension) for more information. ### Does the Mojo SDK collect telemetry? Yes, the Mojo SDK collects some basic system information, crash reports, and some LSP events that enable us to identify, analyze, and prioritize Mojo issues. v25.6 and earlier versions also collected compiler/runtime events, but we've since removed them. Specifically, we collect: * **Crash reports**: When the Mojo compiler crashes with a stack trace, the only information used in the report is the OS version and MAX/Mojo version. 
* **LSP performance metrics**: The Mojo LSP reports aggregate data on how long it takes to respond to user input (parsing latency). The only information used in the report is the milliseconds between user keystrokes and when the Mojo LSP is able to show appropriate error or warning messages. No user information, such as source code, keystrokes, or any other user data, is ever collected or transmitted. This telemetry is crucial to help us quickly identify problems and improve our products. Without this telemetry, we would have to rely on user-submitted bug reports, and in our decades of experience building developer products, we know that most people don't do that. The telemetry provides us the insights we need to build better products for you. ## Versioning & compatibility ### What’s the Mojo versioning strategy? Mojo is still in early development and not at a 1.0 version yet. It’s still missing many foundational features, but please take a look at our [roadmap](/mojo/roadmap) to understand where things are headed. As such, the language is evolving rapidly and source stability is not guaranteed. ### How often will you be releasing new versions of Mojo? Mojo development is moving fast and we are regularly releasing updates, including nightly builds almost every day. Join the [Mojo Discord channel](http://discord.gg/modular) for notifications and [sign up for our newsletter](https://www.modular.com/modverse#signup) for more coarse-grain updates. ## Open Source ### Will Mojo be open-sourced? We have committed to open-sourcing Mojo in 2026. Mojo is still young, so we will continue to incubate it within Modular until more of its internal architecture is fleshed out. ### Why not develop Mojo in the open from the beginning? Mojo is a big project and has several architectural differences from previous languages. We believe a tight-knit group of engineers with a common vision can move faster than a community effort. 
This development approach is also well-established from other projects that are now open source (such as LLVM, Clang, Swift, MLIR, etc.). ## Community ### Where can I ask more questions or share feedback? If you have questions about upcoming features or have suggestions for the language, be sure you first read the [Mojo roadmap](/mojo/roadmap), which provides important information about our current priorities. To get in touch with the Mojo team and developer community, use the resources on our [community page](https://www.modular.com/community). --- ## allgather
`allgather[dtype: DType, rank: Int, ngpus: Int](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], ngpus], output_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], (ngpus * ngpus)], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctxs: List[DeviceContext], _max_num_blocks: Optional[Int] = None)` Performs all-gather across GPUs with variadic output. Each device receives individual copies of all input buffers. The implementation automatically selects between P2P and non-P2P paths based on hardware capabilities. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): Number of dimensions in input tensors. * ​ngpus ([`Int`](/mojo/std/builtin/int/Int)): Number of GPUs participating in all-gather. **Args:** * ​input\_buffers ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Input buffers from each GPU. * ​output\_buffers ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Flat array of ngpus \* ngpus output buffers. Layout: output\_buffers\[device\_idx \* ngpus + input\_idx] contains device\_idx's copy of input\_idx's data. * ​rank\_sigs ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Signal pointers for P2P synchronization. * ​ctxs ([`List`](/mojo/std/collections/list/List)): List of device contexts for participating GPUs. * ​\_max\_num\_blocks ([`Optional`](/mojo/std/collections/optional/Optional)): Maximum number of blocks for kernel launch (optional). `allgather[dtype: DType, rank: Int, ngpus: Int](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], ngpus], output_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], (ngpus * ngpus)], ctxs: List[DeviceContext])` Backward compatible version without rank\_sigs parameter. This version uses the naive implementation since we can't allocate signal buffers with proper lifetime in this function. 
**Deprecated:** Use the `signal_buffers` overload of `allgather` instead.
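The flat `output_buffers` layout above is the subtle part of this API. As a plain-Python sketch (not the Mojo implementation, and with no real device buffers), the indexing rule `output_buffers[device_idx * ngpus + input_idx]` can be simulated like this:

```python
# Illustrative simulation of the all-gather output layout documented
# above. Lists stand in for device buffers; the Mojo kernel operates on
# NDBuffers with P2P or staged copies instead.

def allgather_sim(inputs):
    """Each device ends up with an independent copy of every input."""
    ngpus = len(inputs)
    outputs = [None] * (ngpus * ngpus)  # flat ngpus * ngpus layout
    for device_idx in range(ngpus):
        for input_idx in range(ngpus):
            # output_buffers[device_idx * ngpus + input_idx] holds
            # device_idx's copy of input_idx's data.
            outputs[device_idx * ngpus + input_idx] = list(inputs[input_idx])
    return outputs

inputs = [[0, 0], [1, 1], [2, 2]]  # one buffer per GPU (ngpus = 3)
outputs = allgather_sim(inputs)
# Device 1's copy of GPU 2's data lives at index 1 * 3 + 2 == 5.
print(outputs[1 * 3 + 2])  # → [2, 2]
```

The copies matter: each device owns an independent buffer, so mutating one device's copy never aliases another device's view.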
--- ## allgather (Allgather)
Multi-GPU allgather implementation that gathers values from multiple GPUs into an output buffer. This module provides an optimized implementation of allgather operations across multiple GPUs, supporting both peer-to-peer (P2P) and non-P2P communication patterns. The implementation automatically selects between approaches based on hardware capabilities: 1. P2P-based implementation (when P2P access is available): * Uses direct GPU-to-GPU memory access for better performance. * Optimized for NVLink and xGMI bandwidth utilization. * Uses vectorized memory access. 2. Non-P2P fallback implementation: * Copies data through device memory when direct GPU access isn't possible. * Simple but functional approach for systems without P2P support. ## Functions * [​`allgather`](./allgather): Performs all-gather across GPUs with variadic output.
--- ## allreduce
`allreduce[dtype: DType, rank: Int, ngpus: Int, output_lambda: Optional[elementwise_epilogue_type] = None, pdl_level: PDLLevel = PDLLevel(), *, use_multimem: Bool = False, use_quickreduce: Bool = False](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], 1 if use_multimem else ngpus], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, _max_num_blocks: Optional[Int] = None, iteration: Int = 0)` Per-device allreduce: one instance per GPU builds its own output. High-level model * Each GPU runs one instance of this function in parallel with the others. * Every instance reads all inputs but writes only its own output buffer. * A Python-level fence is inserted across the outputs to prevent reordering. Two execution paths 1. P2P fast path (when peer access is available) * 1-stage kernel (latency-bound): each thread vector-loads from all GPUs, accumulates in higher precision, and writes directly to the result. * 2-stage kernel (bandwidth-bound): reduce-scatter then all-gather. Uses each GPU's `rank_sigs[*]` payload as a staging area for partitions. Diagram (per GPU r, 2-stage): * Stage 1: write reduced partition r into payload of `rank_sigs[r]`. * Stage 2: gather partitions from all peers' payloads into `out_r`. 2. Naive fallback (no P2P) * For GPU r: create local accumulator A\_r, allocate a temporary buffer S\_r, copy each peer input into S\_r and accumulate into A\_r, then apply the epilogue into `out_r`. Diagram (per GPU r, naive): in\_r → A\_r += in\_r; for i≠r: in\_i → tmp\_r → A\_r += tmp\_r; A\_r → out\_r Notes: * Inputs must have identical shape/dtype across GPUs. * Signal buffers must be sized at least `size_of(Signal) + payload_bytes` for the P2P 2-stage path, where `payload_bytes` equals the input tensor bytecount. * The naive path is automatically selected if P2P cannot be enabled. * The `use_multimem` parameter requires P2P access between GPUs to be enabled. 
**Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): Number of dimensions in the tensors. * ​ngpus ([`Int`](/mojo/std/builtin/int/Int)): Number of GPUs participating in the allreduce. * ​output\_lambda ([`Optional`](/mojo/std/collections/optional/Optional)): Elementwise epilogue applied on the device result. * ​pdl\_level ([`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel)): Controls PDL behavior for P2P kernels. * ​use\_multimem ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to use multimem mode for improved performance. * ​use\_quickreduce ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, prefer the quickreduce 2-stage path when eligible. **Args:** * ​input\_buffers ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Inputs from ALL GPUs (for P2P, these must be peer accessible). * ​output\_buffer ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output for THIS GPU. * ​rank\_sigs ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Per-GPU Signal; header plus payload. Payload is used as scratch for the P2P 2-stage path. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for THIS GPU (device id → rank). * ​\_max\_num\_blocks ([`Optional`](/mojo/std/collections/optional/Optional)): Optional grid limit (dispatch selects a default otherwise). * ​iteration ([`Int`](/mojo/std/builtin/int/Int)): Monotonic per-call counter used to color quickreduce flags. Increment each launch; ensures barrier flags are unique across iterations to prevent reuse hazards when reusing the same signal buffers.
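To make the 2-stage P2P path above concrete, here is a plain-Python sketch of the reduce-scatter/all-gather dataflow (a sequential simulation, not the Mojo kernel: lists stand in for the `rank_sigs` payload staging areas, and the even-partition assumption is for illustration):

```python
# Simulates the bandwidth-bound 2-stage allreduce described above.
# Stage 1 (reduce-scatter): rank r reduces partition r across all
# inputs into its staging payload. Stage 2 (all-gather): every rank
# assembles the full result from all per-rank payloads.

def allreduce_2stage_sim(inputs):
    ngpus = len(inputs)
    n = len(inputs[0])
    assert n % ngpus == 0, "sketch assumes evenly divisible length"
    part = n // ngpus

    # Stage 1: rank r owns elements [r*part, (r+1)*part) and sums that
    # slice from every GPU's input (written to rank_sigs[r]'s payload
    # in the real kernel).
    payloads = []
    for r in range(ngpus):
        lo = r * part
        payloads.append([sum(buf[lo + i] for buf in inputs) for i in range(part)])

    # Stage 2: each rank gathers all partitions into its own output.
    return [[x for p in payloads for x in p] for _ in range(ngpus)]

inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 GPUs, 4 elements each
outs = allreduce_2stage_sim(inputs)
print(outs[0])  # → [11, 22, 33, 44]
```

Every rank produces the same full sum, but each rank only reduced 1/ngpus of the elements itself, which is why this path wins for large, bandwidth-bound tensors.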
--- ## allreduce_2stage_quickreduce
`allreduce_2stage_quickreduce[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, output_lambda: elementwise_epilogue_type, atom_size: Int](result: NDBuffer[dtype, rank, MutAnyOrigin], local_src: UnsafePointer[Scalar[dtype], MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], num_elements: Int, my_rank: Int, iteration: Int, num_tiles_total: Int)`
--- ## allreduce_2stage_quickreduce_tile
`allreduce_2stage_quickreduce_tile[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, output_lambda: elementwise_epilogue_type, atom_size: Int, use_bufferio: Bool](result: NDBuffer[dtype, rank, MutAnyOrigin], local_src: UnsafePointer[Scalar[dtype], ImmutAnyOrigin, address_space=AddressSpace.GLOBAL if is_amd_gpu() else AddressSpace.GENERIC], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], num_elements: Int, my_rank: Int, tile: Int, num_tiles: Int, iteration: Int)`
--- ## allreduce (Allreduce)
Multi-GPU allreduce implementation for efficient tensor reduction across GPUs. This module provides an optimized implementation of allreduce operations across multiple GPUs, supporting both peer-to-peer (P2P) and non-P2P communication patterns. The implementation automatically selects between two approaches based on hardware capabilities: 1. P2P-based implementation (when P2P access is available): * Uses direct GPU-to-GPU memory access for better performance * Implements both single-stage and two-stage algorithms: * Single-stage for latency-bound transfers (small tensors) * Two-stage (reduce-scatter + all-gather) for bandwidth-bound transfers (large tensors) * Optimized for NVLink bandwidth utilization * Uses vectorized memory access and higher precision accumulation 2. Non-P2P fallback implementation: * Copies data through host memory when direct GPU access isn't possible * Simple but functional approach for systems without P2P support The implementation is tuned for common GPU architectures (A100, H100) and includes parameters that can be adjusted for different hardware configurations. ## Per-Device Architecture The allreduce operation follows a per-device execution model: 1. **Single-Device Instances**: Each GPU runs its own instance of the allreduce operation. 2. **Parallel Execution**: The Python/Graph API layer is responsible for: * Creating one allreduce op instance per participating GPU. * Ensuring all instances execute in parallel. * Ensuring correctness by staging mo.fence. 3. **Device Affinity**: Each allreduce instance: * Executes on its assigned GPU (specified via device context). * Reads from all GPUs' input buffers (requires P2P access). * Writes only to its own output buffer. * Uses the same synchronization signals as other instances. 4. **Requirements**: * Peer-to-peer access must be enabled between all participating GPUs. * All instances must launch before any can complete (for synchronization). 
* The device context determines which GPU executes each instance. Limitations: * Number of elements must be a multiple of SIMD width. * Maximum of 8 GPUs supported. * All input/output buffers must have identical shapes. ## Visual Overview 1. 1-Stage P2P (latency-bound) Each GPU r reads its portion from every peer buffer directly (via P2P), accumulates, then writes to its result using the epilogue: ``` GPU r (result_r) src_ptrs[0] ─┐ src_ptrs[1] ─┼──► Σ (high-precision accum) ──► output_lambda ──► result_r ... ─┘ ``` Notes: * Vectorized loads from global memory on each GPU. * Good for small/latency-bound tensors. 2. 2-Stage P2P (bandwidth-bound) Stage 1 (reduce-scatter): Each GPU r reduces its assigned partition and writes into its own signal payload (the bytes after the Signal header). ``` src_ptrs[*] ──► reduce(partition r) ──► rank_sigs[r].payload (per-GPU) ``` Stage 2 (all-gather): Each GPU r gathers all partitions from peers' payloads and writes them to its result using the epilogue. ``` [payload_0], [payload_1], ..., [payload_{ngpus-1}] ──► result_r (via output_lambda) ``` For the naive allreduce (no P2P) per-device flow and staging details, see the `_allreduce_naive_single` docstring in this file. ## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None` ## Functions * [​`allreduce`](./allreduce): Per-device allreduce: one instance per GPU builds its own output. * [​`allreduce_2stage_quickreduce`](./allreduce_2stage_quickreduce): * [​`allreduce_2stage_quickreduce_tile`](./allreduce_2stage_quickreduce_tile):
--- ## broadcast
`broadcast[dtype: DType, rank: Int, //, ngpus: Int, pdl_level: PDLLevel = PDLLevel(), use_multimem: Bool = False](input_buffer: NDBuffer[dtype, rank, ImmutAnyOrigin], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, root: Int, _max_num_blocks: Optional[Int] = None)`
--- ## broadcast_2stage
`broadcast_2stage[dtype: DType, rank: Int, //, ngpus: Int, pdl_level: PDLLevel = PDLLevel()](input_buffer: NDBuffer[dtype, rank, ImmutAnyOrigin], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, root: Int, _max_num_blocks: Optional[Int] = None)` Two-stage broadcast: scatter from root, then allgather among all GPUs. Note: This path is only used with 3+ GPUs. With 2 GPUs, broadcast uses the simpler 1-stage path for better performance. This algorithm achieves better bandwidth than simple pull broadcast by: 1. Stage 1 (Scatter): Each GPU reads 1/ngpus of the data from root and writes to its payload buffer, utilizing root's outbound NVLink bandwidth. 2. Stage 2 (Allgather): All GPUs gather from each other in parallel, with each GPU reading (ngpus-1) chunks from other GPUs' payloads. All GPUs (including root) participate uniformly in both stages, which better utilizes root's NVLink bandwidth and simplifies partitioning. IMPORTANT: Signal buffers must be sized to hold at least: size\_of(Signal) + (num\_elements / ngpus) \* size\_of(dtype) This is the payload space needed for each GPU's chunk. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data dtype of tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): Number of dimensions in tensors. * ​ngpus ([`Int`](/mojo/std/builtin/int/Int)): Number of GPUs participating. * ​pdl\_level ([`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel)): Control PDL behavior for the kernel. **Args:** * ​input\_buffer ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Input buffer (only root's is read, but all must be valid). * ​output\_buffer ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output buffer for THIS GPU. * ​rank\_sigs ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Signal pointers with payload space for staging. 
* ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for THIS GPU. * ​root ([`Int`](/mojo/std/builtin/int/Int)): Root GPU rank (source of broadcast data). * ​\_max\_num\_blocks ([`Optional`](/mojo/std/collections/optional/Optional)): Optional maximum number of thread blocks.
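The signal-buffer sizing rule above (`size_of(Signal) + (num_elements / ngpus) * size_of(dtype)`) can be sanity-checked with a little arithmetic. The header size used below is an assumption derived from the counter shapes documented for `Signal` (self\_counter: 512 × 8 UInt32, peer\_counter: 2 × 512 × 8 UInt32), not from the Mojo source, so the real `size_of(Signal)` may differ (e.g. due to padding or extra fields):

```python
# Hypothetical Signal header size, reconstructed from the documented
# counter array shapes (UInt32 = 4 bytes each):
SELF_COUNTER_BYTES = 512 * 8 * 4      # (MAX_NUM_BLOCKS_UPPER_BOUND, MAX_GPUS)
PEER_COUNTER_BYTES = 2 * 512 * 8 * 4  # two sync-point sets
SIGNAL_HEADER_BYTES = SELF_COUNTER_BYTES + PEER_COUNTER_BYTES

def min_signal_buffer_bytes(num_elements, ngpus, dtype_bytes):
    """Minimum per-GPU signal buffer for the 2-stage broadcast:
    header plus one 1/ngpus chunk of payload staging space."""
    assert num_elements % ngpus == 0, "sketch assumes even chunks"
    chunk_bytes = (num_elements // ngpus) * dtype_bytes
    return SIGNAL_HEADER_BYTES + chunk_bytes

# Example: 8M bfloat16 (2-byte) elements broadcast across 4 GPUs.
print(min_signal_buffer_bytes(8 * 1024 * 1024, 2, 2))
print(min_signal_buffer_bytes(8 * 1024 * 1024, 4, 2))  # → 4243456
```

Note the payload term shrinks as `ngpus` grows, since each GPU only stages its own 1/ngpus chunk.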
--- ## broadcast_multimem_kernel
`broadcast_multimem_kernel[dtype: DType, rank: Int, BLOCK_SIZE: Int, ngpus: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), pdl_level: PDLLevel = PDLLevel()](output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], input_buffer: NDBuffer[dtype, rank, ImmutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], my_rank: Int, root: Int)` Broadcast kernel using multimem.st for multicast writes. Root GPU writes to multicast address, data appears on all GPUs. Only root performs the stores; other GPUs just participate in barriers.
--- ## broadcast_pull_1stage_kernel
`broadcast_pull_1stage_kernel[dtype: DType, rank: Int, BLOCK_SIZE: Int, ngpus: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), pdl_level: PDLLevel = PDLLevel()](output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], input_buffer: NDBuffer[dtype, rank, ImmutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], my_rank: Int)`
--- ## broadcast_pull_2stage_kernel
`broadcast_pull_2stage_kernel[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, pdl_level: PDLLevel = PDLLevel()](result: NDBuffer[dtype, rank, MutAnyOrigin], root_input_ptr: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], num_elements: Int, my_rank: Int, root: Int)` Two-stage broadcast: scatter from root, then allgather among all GPUs. Stage 1 (Scatter): Root's data is split into ngpus chunks. Each GPU reads its assigned chunk directly from root's input buffer and writes it to its signal payload. Non-root GPUs also write to their result buffer. Root copies all N elements from source to dest (local operation). Stage 2 (Allgather): Non-root GPUs gather the remaining chunks from all other GPUs' signal payloads (including root's). Root skips this stage since it already has all data. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data dtype of tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): Number of dimensions in tensors. * ​ngpus ([`Int`](/mojo/std/builtin/int/Int)): Number of GPUs participating. * ​BLOCK\_SIZE ([`Int`](/mojo/std/builtin/int/Int)): Number of threads per block. * ​pdl\_level ([`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel)): Control PDL behavior for the kernel. **Args:** * ​result ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output buffer for broadcast result. * ​root\_input\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to root's input data (all GPUs read from this). * ​rank\_sigs ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Signal pointers for synchronization. IMPORTANT: Signal pointers have trailing buffers for communication. * ​num\_elements ([`Int`](/mojo/std/builtin/int/Int)): Number of elements to broadcast. * ​my\_rank ([`Int`](/mojo/std/builtin/int/Int)): Current GPU rank. * ​root ([`Int`](/mojo/std/builtin/int/Int)): Root GPU rank (source of broadcast).
--- ## broadcast (Broadcast)
Multi-GPU broadcast kernel implementation. ## Functions * [​`broadcast`](./broadcast): * [​`broadcast_2stage`](./broadcast_2stage): Two-stage broadcast: scatter from root, then allgather among all GPUs. * [​`broadcast_multimem_kernel`](./broadcast_multimem_kernel): Broadcast kernel using multimem.st for multicast writes. * [​`broadcast_pull_1stage_kernel`](./broadcast_pull_1stage_kernel): * [​`broadcast_pull_2stage_kernel`](./broadcast_pull_2stage_kernel): Two-stage broadcast: scatter from root, then allgather among all GPUs.
--- ## TuningConfigAllreduce
`@register_passable(trivial)` `struct TuningConfigAllreduce` Parameters: ngpus: Number of GPUs for running allreduce. num\_bytes: Total number of input bytes supported by the config. sm\_version: SM version (as string). num\_blocks: Number of thread blocks for running allreduce. ## Fields * ​ngpus (`Int`): * ​num\_bytes (`Int`): * ​sm\_version (`StaticString`): * ​num\_blocks (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`TuningConfig`](/mojo/kernels/internal_utils/dispatch_utils/TuningConfig) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__str__` `__str__(self) -> String` **Returns:** `String`
--- ## get_sm_version
`get_sm_version() -> StaticString` **Returns:** `StaticString`
--- ## device_query
Provides device query utilities for communication primitives. ## `comptime` values ### `allreduce_table` `comptime allreduce_table = Table[TuningConfigAllreduce](List[TuningConfigAllreduce](TuningConfigAllreduce(-1, -1, "sm_90a", 216), TuningConfigAllreduce(4, 134217728, "sm_90a", 232), TuningConfigAllreduce(-1, -1, "sm_100a", 512), TuningConfigAllreduce(2, 8388608, "sm_100a", 512), TuningConfigAllreduce(2, 16777216, "sm_100a", 512), TuningConfigAllreduce(2, 33554432, "sm_100a", 512), TuningConfigAllreduce(2, 67108864, "sm_100a", 512), TuningConfigAllreduce(2, 134217728, "sm_100a", 512), TuningConfigAllreduce(4, 8388608, "sm_100a", 512), TuningConfigAllreduce(4, 16777216, "sm_100a", 512), TuningConfigAllreduce(4, 33554432, "sm_100a", 512), TuningConfigAllreduce(4, 67108864, "sm_100a", 512), TuningConfigAllreduce(4, 134217728, "sm_100a", 512), Tuple[]()), "allreduce_table")` ## Structs * [​`TuningConfigAllreduce`](./TuningConfigAllreduce): Parameters: ngpus: Number of GPUs for running allreduce. num\_bytes: Total number of input bytes supported by the config. sm\_version: SM version (as string). num\_blocks: Number of thread blocks for running allreduce. ## Functions * [​`get_sm_version`](./get_sm_version):
--- ## comm (Comm)
Provides communication primitives for GPUs. This package includes functions for sending and receiving data between GPUs, as well as for synchronizing threads across GPUs. ## Packages * [​`vendor`](./vendor/): ## Modules * [​`allgather`](./allgather/): Multi-GPU allgather implementation that gathers values from multiple GPUs into an output buffer. * [​`allreduce`](./allreduce/): Multi-GPU allreduce implementation for efficient tensor reduction across GPUs. * [​`broadcast`](./broadcast/): Multi-GPU broadcast kernel implementation. * [​`device_query`](./device_query/): Provides device query utilities for communication primitives. * [​`reducescatter`](./reducescatter/): Multi-GPU reducescatter implementation for distributed tensor reduction across GPUs. * [​`sync`](./sync/):
--- ## ReduceScatterConfig
`@register_passable(trivial)` `struct ReduceScatterConfig[dtype: DType, ngpus: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), alignment: Int = align_of[SIMD[dtype, simd_width]](), accum_type: DType = get_accum_type[dtype]()]` ## Fields * ​stride (`Int`): * ​part (`Int`): * ​remainder (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_elements: Int, threads_per_gpu: Int) -> Self` ### `rank_start` `rank_start(self, rank: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `rank_end` `rank_end(self, rank: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `rank_part` `rank_part(self, rank: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `thr_local_start` `thr_local_start(self, thread_idx: Scalar[DType.uint]) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## reducescatter
Multi-GPU reducescatter implementation for distributed tensor reduction across GPUs. ## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None` ## Structs * [​`ReduceScatterConfig`](./ReduceScatterConfig): ## Functions * [​`reducescatter`](./reducescatter): Per-device reducescatter operation.
--- ## reducescatter (Reducescatter)
`reducescatter[dtype: DType, rank: Int, ngpus: Int, output_lambda: Optional[elementwise_epilogue_type] = None, pdl_level: PDLLevel = PDLLevel(), *, use_multimem: Bool = False](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], 1 if use_multimem else ngpus], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, _max_num_blocks: Optional[Int] = None)` Per-device reducescatter operation. Performs a reduce-scatter across multiple GPUs: each GPU reduces its assigned partition from all input buffers and writes the result to its output buffer. This is equivalent to the reduce-scatter phase of the 2-stage allreduce algorithm. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data dtype of tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): Number of dimensions in tensors. * ​ngpus ([`Int`](/mojo/std/builtin/int/Int)): Number of GPUs participating. * ​output\_lambda ([`Optional`](/mojo/std/collections/optional/Optional)): Optional elementwise epilogue function. If not provided, reduced values are stored directly to output\_buffer. * ​pdl\_level ([`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel)): Control PDL behavior for the kernel. * ​use\_multimem ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, use multimem optimization (reserved for future use). **Args:** * ​input\_buffers ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Input buffers from all GPUs (peer access required). When use\_multimem is False (default), expects ngpus buffers. When use\_multimem is True, expects a single buffer. * ​output\_buffer ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output buffer for THIS GPU's partition of reduced data. Size should be approximately 1/ngpus of the input size. * ​rank\_sigs ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Signal pointers for synchronization between GPUs. 
* ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for THIS GPU. * ​\_max\_num\_blocks ([`Optional`](/mojo/std/collections/optional/Optional)): Optional maximum number of thread blocks to launch. If not specified, uses MAX\_NUM\_BLOCKS\_UPPER\_BOUND. **Raises:** Error: If P2P access is not available between GPUs. Error: If input buffer size is not a multiple of SIMD width.
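The partitioning contract described above can be illustrated in plain Python. This is a semantic sketch of reduce-scatter, where each "GPU" is just a list, not the Mojo kernel; the function name and flat-list layout are illustrative assumptions:

```python
# Reduce-scatter semantics: ngpus inputs of equal length; GPU g
# receives the elementwise sum of partition g of every input.
def reducescatter(inputs):
    ngpus = len(inputs)
    n = len(inputs[0])
    assert n % ngpus == 0, "input size must divide evenly across GPUs"
    part = n // ngpus
    outputs = []
    for g in range(ngpus):          # one output per GPU
        lo, hi = g * part, (g + 1) * part
        # sum this GPU's partition across all peers' inputs
        outputs.append([sum(buf[i] for buf in inputs) for i in range(lo, hi)])
    return outputs

inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 "GPUs", 4 elements each
print(reducescatter(inputs))                # [[11, 22], [33, 44]]
```

Note that each output is roughly 1/ngpus of an input, matching the size requirement on `output_buffer` above.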
--- ## Signal
`struct Signal` A synchronization primitive for coordinating GPU thread blocks across multiple devices. This struct provides counter-based synchronization between thread blocks on different GPUs. It maintains two sets of counters: 1. self\_counter: Used by blocks on the current GPU to signal their progress 2. peer\_counter: Used to track progress of blocks on other GPUs Note: The counters use unsigned integers that may overflow, but this is safe since unsigned integer overflow has well-defined behavior. ## Fields * ​self\_counter (`StaticTuple[StaticTuple[UInt32, 8], 512]`): A 2D array of counters with shape (MAX\_NUM\_BLOCKS\_UPPER\_BOUND, MAX\_GPUS). Each counter tracks the progress of a specific thread block on the current GPU. Thread blocks increment their corresponding counter to signal completion of a phase, allowing other GPUs to detect when synchronization points are reached. The counters use atomic operations to ensure proper synchronization across devices. * ​peer\_counter (`StaticTuple[StaticTuple[StaticTuple[UInt32, 8], 512], 2]`): A 3D array of counters with shape (2, MAX\_NUM\_BLOCKS\_UPPER\_BOUND, MAX\_GPUS). Contains two sets of counters to handle two synchronization points safely. The dual counter design prevents race conditions where a peer block arrives at the second sync point before the current block passes the first sync point. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `flag_t` `comptime flag_t = DType.uint32`
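The dual counter-set design can be sketched in a few lines of Python (illustrative only; the names and phase-parity indexing are assumptions, not the Mojo implementation). Sync points alternate between the two peer-counter sets, so a peer that races ahead to sync point k+1 writes into the *other* set and cannot disturb counters that a slower block is still waiting on at sync point k:

```python
# Two sets of peer counters, selected by sync-point parity.
NGPUS = 2
self_counter = [0] * NGPUS                  # per-GPU progress counters
peer_counter = [[0] * NGPUS, [0] * NGPUS]   # two sets, indexed by parity

def arrive(gpu, sync_point):
    """Record arrival at `sync_point`; return (set, target) that peers
    must observe. A real kernel would then spin until every peer's
    counter in that set reaches the target."""
    self_counter[gpu] += 1
    s = sync_point % 2                      # alternate counter sets
    peer_counter[s][gpu] = self_counter[gpu]
    return s, self_counter[gpu]

for sync_point in range(4):                 # four consecutive sync points
    for gpu in range(NGPUS):
        arrive(gpu, sync_point)

print(self_counter)   # [4, 4]
print(peer_counter)   # [[3, 3], [4, 4]] -- sets hold alternating phases
```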
--- ## can_enable_p2p
`can_enable_p2p() -> Bool` If peer-to-peer access is supported, enables it between all GPU pairs. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if P2P access is possible between all GPU pairs, False otherwise.
--- ## group_end
`group_end()`
--- ## group_start
`group_start()`
--- ## sync
## `comptime` values ### `MAX_GPUS` `comptime MAX_GPUS = 8` Maximum number of GPUs supported in the allreduce implementation. This constant sets the upper bound for the number of GPUs supported in this algorithm. ### `MAX_NUM_BLOCKS_UPPER_BOUND` `comptime MAX_NUM_BLOCKS_UPPER_BOUND = 512` Maximum number of thread blocks to use for reduction kernels. This value has been empirically optimized through grid search across different GPU architectures. While this value is optimal for A100 GPUs, H100 GPUs may benefit from more blocks to fully saturate NVLink bandwidth. ## Structs * [​`Signal`](./Signal): A synchronization primitive for coordinating GPU thread blocks across multiple devices. ## Functions * [​`can_enable_p2p`](./can_enable_p2p): If peer-to-peer access is supported, enables it between all GPU pairs. * [​`group_end`](./group_end): * [​`group_start`](./group_start):
--- ## Communicators
`struct Communicators` ## Fields * ​ngpus (`Int`): * ​comms (`InlineArray[ncclComm_t, 8]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__copyinit__` `__copyinit__(out self, rhs: Self)`
--- ## allgather (Ccl)
`allgather[dtype: DType, rank: Int, ngpus: Int](inputs: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], ngpus], outputs: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], (ngpus * ngpus)], list_of_ctx: List[DeviceContext])`
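The signature takes `ngpus` inputs and `ngpus * ngpus` outputs, i.e. one copy of every input per GPU. A plain-Python sketch of the all-gather contract (the destination-major flattening of the output array is an assumption for illustration, not a statement about the actual buffer ordering):

```python
# All-gather semantics: every GPU ends up with a copy of every input.
def allgather(inputs):
    ngpus = len(inputs)
    outputs = [None] * (ngpus * ngpus)
    for dst in range(ngpus):
        for src in range(ngpus):
            # copy input `src` into destination GPU `dst`'s slot for it
            outputs[dst * ngpus + src] = list(inputs[src])
    return outputs

print(allgather([[1, 2], [3, 4]]))
# [[1, 2], [3, 4], [1, 2], [3, 4]]
```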
--- ## allreduce (Ccl)
`allreduce[dtype: DType, rank: Int, ngpus: Int, output_lambda: Optional[elementwise_epilogue_type] = None, pdl_level: PDLLevel = PDLLevel(), *, use_multimem: Bool = False, use_quickreduce: Bool = False](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], 1 if use_multimem else ngpus], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, _max_num_blocks: Optional[Int] = None, iteration: Int = 0)` Per-GPU allreduce for use in multi-threaded contexts. Currently requires a prior single-threaded call to `init_comms`, as a thread-safe version has not yet been implemented.
--- ## broadcast (Ccl)
`broadcast[dtype: DType, rank: Int, //, ngpus: Int, pdl_level: PDLLevel = PDLLevel(), use_multimem: Bool = False](input_buffer: NDBuffer[dtype, rank, ImmutAnyOrigin], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, root: Int, _max_num_blocks: Optional[Int] = None)` Per-GPU broadcast for use in multi-threaded contexts. Currently requires a prior single-threaded call to `init_comms`, as a thread-safe version has not yet been implemented.
--- ## group
`group() -> _Group` **Returns:** `_Group`
--- ## ccl
## `comptime` values ### `CCL_LIBRARY` `comptime CCL_LIBRARY = _Global["CCL_LIBRARY", _init_ccl_dylib]` ### `CCLAllGatherFn` `comptime CCLAllGatherFn = fn(LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType], Int, ncclDataType_t, LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType]) -> ncclResult_t` ### `CCLAllReduceFn` `comptime CCLAllReduceFn = fn(LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType], Int, ncclDataType_t, ncclRedOp_t, LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType]) -> ncclResult_t` ### `CCLBroadcastFn` `comptime CCLBroadcastFn = fn(LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType], Int, ncclDataType_t, Int, LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType]) -> ncclResult_t` ### `NCCL_LIBRARY_PATHS` `comptime NCCL_LIBRARY_PATHS = List[Path]("libnccl.so", "libnccl.so.2", "/usr/lib/x86_64-linux-gnu/libnccl.so", "/usr/lib/x86_64-linux-gnu/libnccl.so.2", Tuple[]())` ### `ncclComm_t` `comptime ncclComm_t = OpaquePointer` ### `OpaquePointer` `comptime OpaquePointer = LegacyUnsafePointer[NoneType]` ### `RCCL_LIBRARY_PATHS` `comptime RCCL_LIBRARY_PATHS = List[Path]("librccl.so", "librccl.so.1", "/opt/rocm/lib/librccl.so", "/opt/rocm/lib/librccl.so.1", Tuple[]())` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`Communicators`](./Communicators): * [​`ncclDataType_t`](./ncclDataType_t): * [​`ncclRedOp_t`](./ncclRedOp_t): * [​`ncclResult_t`](./ncclResult_t): ## Functions * [​`allgather`](./allgather): * [​`allreduce`](./allreduce): Per-GPU allreduce for use in multi-threaded contexts. * [​`broadcast`](./broadcast): Per-GPU broadcast for use in multi-threaded contexts. * [​`group`](./group): * [​`init_comms`](./init_comms): Pre-initialize NCCL/RCCL communicators. 
* [​`is_allgather_available`](./is_allgather_available): * [​`is_allreduce_available`](./is_allreduce_available): * [​`is_broadcast_available`](./is_broadcast_available): * [​`ncclCommInitAll`](./ncclCommInitAll):
--- ## init_comms
`init_comms(ngpus: Int)` Pre-initialize NCCL/RCCL communicators. Must be called from a single thread before using allreduce from multiple threads. This ensures thread-safe initialization since ncclCommInitAll is not designed for concurrent calls.
--- ## is_allgather_available
`is_allgather_available() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## is_allreduce_available
`is_allreduce_available() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## is_broadcast_available
`is_broadcast_available() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## ncclCommInitAll
`ncclCommInitAll(comms: LegacyUnsafePointer[ncclComm_t], ndev: Int, devlist: LegacyUnsafePointer[Int32]) -> ncclResult_t` **Returns:** `ncclResult_t`
--- ## ncclDataType_t
`@register_passable(trivial)` `struct ncclDataType_t` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ncclBfloat16` `comptime ncclBfloat16 = ncclDataType_t(9)` ### `ncclFloat16` `comptime ncclFloat16 = ncclDataType_t(6)` ### `ncclFloat32` `comptime ncclFloat32 = ncclDataType_t(7)` ## Methods ### `__init__` `__init__(value: Int) -> Self`
--- ## ncclRedOp_t
`@register_passable(trivial)` `struct ncclRedOp_t` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ncclSum` `comptime ncclSum = ncclRedOp_t(0)` ## Methods ### `__init__` `__init__(value: Int) -> Self`
--- ## ncclResult_t
`@register_passable(trivial)` `struct ncclResult_t` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ncclSuccess` `comptime ncclSuccess = ncclResult_t(0)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `write_to` `write_to(self, mut writer: T)`
--- ## vendor
## Modules * [​`ccl`](./ccl/):
--- ## extensibility
Includes the tensor package. ## Packages * [​`tensor`](./tensor/): APIs to create and manage tensors in a graph.
--- ## tensor (Tensor)
APIs to create and manage tensors in a graph. ## Modules * [​`io_spec`](./io_spec/): * [​`managed_tensor_slice`](./managed_tensor_slice/): Implements the `ManagedTensorSlice` type - a view of a tensor that doesn't own the underlying data. This type is used to build custom graph operations. * [​`operation_traits`](./operation_traits/): * [​`tensor_spec`](./tensor_spec/): You can import these APIs from the `max.tensor` package. * [​`transitional`](./transitional/): Utilities for the transitional period during NDBuffer deprecation.
--- ## IO
`@register_passable(trivial)` `struct IO` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `FusedInput` `comptime FusedInput = IO(2)` ### `FusedOutput` `comptime FusedOutput = IO(3)` ### `Input` `comptime Input = IO(1)` ### `Output` `comptime Output = IO(0)` ### `Unknown` `comptime Unknown = IO(-1)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## IOSpec
`@register_passable(trivial)` `struct IOSpec[mut: Bool, input: IO]` Parameter used to encode whether a particular tensor argument to a DPS kernel is an output, input, or mutable input. ```mojo Input == IOSpec[False, IO.Input]() Output == IOSpec[True, IO.Output]() MutableInput == IOSpec[True, IO.Input]() FusedInput == IOSpec[False, IO.FusedInput]() FusedOutput == IOSpec[True, IO.FusedOutput]() ``` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True`
--- ## io_spec
## `comptime` values ### `FusedInput` `comptime FusedInput = IOSpec[False, IO.FusedInput]()` ### `FusedOutput` `comptime FusedOutput = IOSpec[True, IO.FusedOutput]()` ### `Input` `comptime Input = IOSpec[False, IO.Input]()` ### `IOUnknown` `comptime IOUnknown = IOSpec[True, IO.Unknown]()` ### `MutableInput` `comptime MutableInput = IOSpec[True, IO.Input]()` ### `Output` `comptime Output = IOSpec[True, IO.Output]()` ## Structs * [​`IO`](./IO): * [​`IOSpec`](./IOSpec): Parameter used to encode whether a particular tensor argument to a DPS kernel is an output, input, or mutable input.
--- ## ManagedTensorSlice
`@register_passable(trivial)` `struct ManagedTensorSlice[mut: Bool, input: IO, dtype: DType, rank: Int, //, io_spec: IOSpec[mut, input], *, static_spec: StaticTensorSpec[dtype, rank]]` A view of a tensor that does not own the underlying allocated pointer. When the object lifetime ends it does not free the underlying pointer. Conversely, if a `ManagedTensorSlice` is created, it will not extend the life of the underlying pointer. Therefore, the user must take care to keep the pointer alive until the last use of a `ManagedTensorSlice` instance. This class is useful for writing custom operations where memory is managed by an external runtime like in MAX's inference stack. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `address_space` `comptime address_space = static_spec.address_space` ### `alignment` `comptime alignment = static_spec.alignment` ### `device_type` `comptime device_type = LayoutTensor[dtype, static_spec.to_layout[dtype, rank](), MutAnyOrigin]` ### `exclusive` `comptime exclusive = static_spec.exclusive` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[Scalar[dtype]], slices: InlineArray[Slice, rank], slicer_spec: RuntimeTensorSpec[dtype, rank]) -> Self` Initializes a 
ManagedTensorSlice from a pointer, array of slices and tensor spec. In general, custom operations should not create `ManagedTensorSlice` instances, but instead use the ones provided by the MAX inference engine. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype]], shape: IndexList[rank]) -> Self` Initializes a ManagedTensorSlice from a pointer and shape. In general, custom operations should not create `ManagedTensorSlice` instances, but instead use the ones provided by the MAX inference engine. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype]], shape: IndexList[rank], strides: IndexList[rank]) -> Self` Initializes a ManagedTensorSlice from a pointer, shape, and strides. In general, custom operations should not create `ManagedTensorSlice` instances, but instead use the ones provided by the MAX inference engine. ### `__getitem__` `__getitem__(self, indices: IndexList[rank]) -> Scalar[dtype]` Gets the value at the specified indices. **Args:** * ​indices ([`IndexList`](/mojo/std/utils/index_/IndexList)): The indices of the value to retrieve. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The value at the specified indices. `__getitem__(self, *indices: Int) -> Scalar[dtype]` Gets the value at the specified indices. **Args:** * ​\*indices ([`Int`](/mojo/std/builtin/int/Int)): The indices of the value to retrieve. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The value at the specified indices. ### `__setitem__` `__setitem__(self, *indices: Int, *, val: Scalar[dtype])` Stores the value at the specified indices. **Args:** * ​\*indices ([`Int`](/mojo/std/builtin/int/Int)): The indices of the value to store. * ​val ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The value to store. `__setitem__(self, indices: IndexList[rank], val: Scalar[dtype])` Stores the value at the specified indices. **Args:** * ​indices ([`IndexList`](/mojo/std/utils/index_/IndexList)): The indices of the value to store. 
* ​val ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The value to store. ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `spec` `spec(self) -> RuntimeTensorSpec[dtype, rank]` Gets the `TensorSpec` of this tensor slice, which provides metadata about the tensor slice. **Returns:** [`RuntimeTensorSpec`](/mojo/tensor/tensor_spec/RuntimeTensorSpec): The static `TensorSpec` for this tensor slice. ### `shape` `shape(self) -> IndexList[rank]` Gets the shape of this tensor slice, as an `IndexList`. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The shape of this tensor slice. ### `dim_size` `dim_size(self, index: Int) -> Int` Gets the size of a given dimension of this tensor slice using a run time value. **Args:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the tensor slice in the given dimension. `dim_size[index: Int](self) -> Int` Gets the size of a given dimension of this tensor slice using a compile time value. **Parameters:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the tensor slice in the given dimension. ### `strides` `strides(self) -> IndexList[rank]` Gets the strides of this tensor slice, as an `IndexList`. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The strides of this tensor slice. ### `stride_length` `stride_length(self, index: Int) -> Int` Gets the length of the stride of a given dimension of this tensor slice using a run time value. **Args:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The stride length of this tensor slice in the given dimension. 
`stride_length[index: Int](self) -> Int` Gets the length of the stride of a given dimension of this tensor slice using a compile time value. **Parameters:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The stride length of this tensor slice in the given dimension. ### `size` `size(self) -> Int` Computes the tensor slice's number of elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total number of elements in the tensor slice. ### `unsafe_ptr` `unsafe_ptr[_dtype: DType = dtype](self) -> LegacyUnsafePointer[Scalar[_dtype]]` Gets the pointer stored in this tensor slice. Because this method exposes the raw pointer stored in this tensor slice, callers can violate the slice's invariants and cause unexpected behavior. Use it with caution. **Parameters:** * ​\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The type of the `UnsafePointer` in this tensor slice. **Returns:** [`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer): The `UnsafePointer` which contains the data for this tensor slice. ### `load` `load[width: Int, _rank: Int, element_alignment: Int = 1](self, index: IndexList[_rank]) -> SIMD[dtype, width]` Gets data from this tensor slice as a `SIMD`. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The width of the `SIMD` value. This must be large enough to contain the data from this tensor slice. * ​\_rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the tensor slice. * ​element\_alignment ([`Int`](/mojo/std/builtin/int/Int)): Indicates the alignment of the pointer stored to memory. This is needed to issue a vector load for GPUs with strict alignment requirements. **Args:** * ​index ([`IndexList`](/mojo/std/utils/index_/IndexList)): An `IndexList` of size `_rank` to indicate the dimension of the tensor slice to obtain data from. 
**Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): Data from this tensor slice at dimension `index`. ### `store` `store[width: Int, _rank: Int, element_alignment: Int = 1](self: ManagedTensorSlice[io_spec, static_spec=static_spec], index: IndexList[_rank], val: SIMD[dtype, width])` Sets data in this tensor slice from a `SIMD`. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The width of the `SIMD` value. * ​\_rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the tensor slice. * ​element\_alignment ([`Int`](/mojo/std/builtin/int/Int)): Indicates the alignment of the pointer stored to memory. This is needed to issue a vector store for GPUs with strict alignment requirements. **Args:** * ​index ([`IndexList`](/mojo/std/utils/index_/IndexList)): An `IndexList` of size `_rank` to indicate the dimension of the tensor slice to set data in. * ​val ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The data to set into this tensor slice. ### `with_layout` `with_layout[new_rank: Int, //, new_static_shape: DimList, new_static_strides: DimList](self, new_runtime_shape: IndexList[new_rank], new_runtime_strides: IndexList[new_rank], offset_ptr: Optional[LegacyUnsafePointer[Scalar[dtype]]] = None) -> ManagedTensorSlice[io_spec, static_spec=static_spec.with_layout[dtype, rank, new_rank](new_static_shape, new_static_strides)]` **Returns:** [`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice) ### `to_layout_tensor` `to_layout_tensor(self) -> LayoutTensor[dtype, static_spec.to_layout[dtype, rank](), MutAnyOrigin]` **Returns:** `LayoutTensor` ### `to_tile_tensor` `to_tile_tensor[coord_dtype: DType](self) -> TileTensor` Converts this tensor slice to a `TileTensor` whose layout is derived from the slice's static shape and strides. (The fully elaborated return type is a lengthy compiler-internal `Layout` expression and is omitted here for readability.) **Returns:** `TileTensor` ### `write_to` `write_to(self, mut writer: T)` Formats this buffer to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__repr__` `__repr__(self) -> String` Gets the buffer as a string. **Returns:** [`String`](/mojo/std/collections/string/string/String): A compact string representation of the buffer. ### `__str__` `__str__(self) -> String` Gets the buffer as a string. **Returns:** [`String`](/mojo/std/collections/string/string/String): A compact string of the buffer.
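`ManagedTensorSlice` addresses elements through a shape/strides pair, so the flat offset of an index is the dot product of the index with the strides. A small Python sketch of this strided-view arithmetic (illustrative only, not the Mojo code):

```python
# Strided indexing as used by shape/strides-based tensor views:
# flat_offset = sum(index[d] * strides[d]) over dimensions d.
def flat_offset(index, strides):
    return sum(i * s for i, s in zip(index, strides))

# A contiguous, row-major 2x3 view over flat storage.
data = [0, 1, 2, 3, 4, 5]
shape, strides = (2, 3), (3, 1)
assert data[flat_offset((1, 2), strides)] == 5

# The same storage viewed transposed (3x2) just by swapping strides,
# with no data movement -- the point of a non-owning slice/view.
t_strides = (1, 3)
print([data[flat_offset((r, c), t_strides)] for r in range(3) for c in range(2)])
# [0, 3, 1, 4, 2, 5]
```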
--- ## VariadicTensors
`@register_passable(trivial)` `struct VariadicTensors[mut: Bool, input: IO, //, dtype: DType, rank: Int, size: Int, io_spec: IOSpec[mut, input], *, static_specs: StaticTuple[StaticTensorSpec[dtype, rank], size]]` A tuple-like container of tensors representing variadic arguments from the graph compiler. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Sized`](/mojo/std/builtin/len/Sized), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(ptrs: StaticTuple[LegacyUnsafePointer[Scalar[dtype]], size], shapes: StaticTuple[IndexList[rank], size]) -> Self` Initialize the variadic tensor from tuples of pointers and shapes. This is a bulk initialization of the VariadicTensors value from an array of pointers and an array of runtime shapes. This allows the graph compiler to avoid generating code to construct DynamicTensor values directly. ### `__getitem__` `__getitem__[index: Int](self) -> ManagedTensorSlice[io_spec, static_spec=static_specs.__getitem__[StaticTensorSpec[dtype, rank], size](index)]` Returns the tensor at the given position in the variadic argument pack. **Parameters:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The index into the variadic tensor arguments. **Returns:** [`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice): The tensor at the specified index. 
### `__len__` `__len__(self) -> Int` Returns the number of variadic arguments in the pack. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of variadic arguments.
--- ## foreach
`foreach[dtype: DType, rank: Int, //, func: fn[width: Int, element_alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width], *, target: StringSlice[StaticConstantOrigin] = "cpu", simd_width: Int = get_kernel_simd_width[dtype, target](), _trace_name: StringSlice[StaticConstantOrigin] = "mogg.for_each", use_blocking_impl: Bool = False](tensor: ManagedTensorSlice[io_spec, static_spec=static_spec], ctx: DeviceContextPtr = DeviceContextPtr())` Apply the function `func` to each element of the tensor slice. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements in the tensor slice. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the tensor slice. * ​func (`fn[width: Int, element_alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to apply to each element of the tensor slice. * ​target (`StringSlice`): Indicates the type of the target device (e.g. "cpu", "gpu"). * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): The SIMD width for the target (usually leave this as its default value). * ​\_trace\_name (`StringSlice`): Name of the executed operation displayed in the trace\_description. * ​use\_blocking\_impl ([`Bool`](/mojo/std/builtin/bool/Bool)): If the impl should use this thread for doing the work. **Args:** * ​tensor ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The output tensor slice which receives the return values from `func`. * ​ctx ([`DeviceContextPtr`](/mojo/std/runtime/asyncrt/DeviceContextPtr)): The call context (forward this from the custom operation). 
`foreach[dtype: DType, rank: Int, //, func: fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width], out_func: fn[width: Int](IndexList[rank]) capturing -> None, *, target: StringSlice[StaticConstantOrigin] = "cpu", simd_width: Int = get_kernel_simd_width[dtype, target](), _trace_name: StringSlice[StaticConstantOrigin] = "mogg.for_each", use_blocking_impl: Bool = False](tensor: ManagedTensorSlice[io_spec, static_spec=static_spec], ctx: DeviceContextPtr = DeviceContextPtr())` Apply the function `func` to each element of the tensor slice. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements in the tensor slice. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the tensor slice. * ​func (`fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to apply to each element of the tensor slice. * ​out\_func (`fn[width: Int](IndexList[rank]) capturing -> None`): The function to apply on each output element. * ​target (`StringSlice`): Indicates the type of the target device (e.g. "cpu", "gpu"). * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): The SIMD width for the target (usually leave this as its default value). * ​\_trace\_name (`StringSlice`): Name of the executed operation displayed in the trace\_description. * ​use\_blocking\_impl ([`Bool`](/mojo/std/builtin/bool/Bool)): If the impl should use this thread for doing the work. **Args:** * ​tensor ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The input tensor slice whose values are consumed. * ​ctx ([`DeviceContextPtr`](/mojo/std/runtime/asyncrt/DeviceContextPtr)): The call context (forward this from the custom operation). 
`foreach[dtype: DType, rank: Int, //, func: fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width], *, target: StringSlice[StaticConstantOrigin] = "cpu", simd_width: Int = get_kernel_simd_width[dtype, target](), _trace_name: StringSlice[StaticConstantOrigin] = "mogg.for_each", use_blocking_impl: Bool = False](tensor: ManagedTensorSlice[io_spec, static_spec=static_spec], ctx: DeviceContextPtr = DeviceContextPtr())` Apply the function `func` to each element of the tensor slice. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements in the tensor slice. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the tensor slice. * ​func (`fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to apply to each element of the tensor slice. * ​target (`StringSlice`): Indicates the type of the target device (e.g. "cpu", "gpu"). * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): The SIMD width for the target (usually leave this as its default value). * ​\_trace\_name (`StringSlice`): Name of the executed operation displayed in the trace\_description. * ​use\_blocking\_impl ([`Bool`](/mojo/std/builtin/bool/Bool)): If the impl should use this thread for doing the work. **Args:** * ​tensor ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The output tensor slice which receives the return values from `func`. * ​ctx ([`DeviceContextPtr`](/mojo/std/runtime/asyncrt/DeviceContextPtr)): The call context (forward this from the custom operation).
--- ## managed_tensor_slice
Implements the `ManagedTensorSlice` type - a view of a tensor that doesn't own the underlying data. This type is used to build custom graph operations. ## `comptime` values ### `DynamicTensor` `comptime DynamicTensor[dtype: DType, rank: Int] = ManagedTensorSlice[IOUnknown, static_spec=StaticTensorSpec.create_unknown[dtype, rank]()]` #### Parameters * ​dtype (`DType`): * ​rank ([`Int`](/std/builtin/int/Int)): ### `InputTensor` `comptime InputTensor = ManagedTensorSlice[Input, static_spec=?]` ### `InputVariadicTensors` `comptime InputVariadicTensors = VariadicTensors[?, ?, ?, Input, static_specs=?]` ### `OpaquePointer` `comptime OpaquePointer = LegacyUnsafePointer[NoneType]` ### `OutputTensor` `comptime OutputTensor = ManagedTensorSlice[Output, static_spec=?]` ### `OutputVariadicTensors` `comptime OutputVariadicTensors = VariadicTensors[?, ?, ?, Output, static_specs=?]` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`ManagedTensorSlice`](./ManagedTensorSlice): A view of a tensor that does not own the underlying allocated pointer. When the object lifetime ends it does not free the underlying pointer. Conversely, if a `ManagedTensorSlice` is created, it will not extend the life of the underlying pointer. * [​`VariadicTensors`](./VariadicTensors): A tuple-like container of tensors representing variadic arguments from the graph compiler. ## Functions * [​`foreach`](./foreach): Apply the function `func` to each element of the tensor slice. 
* [​`rebuild_mix_precision_static_tensor_specs_with_input_lambda`](./rebuild_mix_precision_static_tensor_specs_with_input_lambda): * [​`rebuild_static_tensor_specs_with_compute_output_lambda`](./rebuild_static_tensor_specs_with_compute_output_lambda): * [​`rebuild_static_tensor_specs_with_input_lambda`](./rebuild_static_tensor_specs_with_input_lambda): * [​`rebuild_static_tensor_specs_with_output_lambda`](./rebuild_static_tensor_specs_with_output_lambda): * [​`trace_slice_arg`](./trace_slice_arg): Helper to stringify the type and shape of a kernel argument for tracing.
--- ## rebuild_mix_precision_static_tensor_specs_with_input_lambda
`rebuild_mix_precision_static_tensor_specs_with_input_lambda[func_type: __TypeOfAllTypes, //, src_dtype: DType, dst_dtype: DType, rank: Int](spec: StaticTensorSpec[src_dtype, rank], in_lambda: func_type) -> StaticTensorSpec[dst_dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## rebuild_static_tensor_specs_with_compute_output_lambda
`rebuild_static_tensor_specs_with_compute_output_lambda[func_type: __TypeOfAllTypes, //, dtype: DType, rank: Int](spec: StaticTensorSpec[dtype, rank], out_compute_lambda: func_type) -> StaticTensorSpec[dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## rebuild_static_tensor_specs_with_input_lambda
`rebuild_static_tensor_specs_with_input_lambda[func_type: __TypeOfAllTypes, //, dtype: DType, rank: Int](spec: StaticTensorSpec[dtype, rank], in_lambda: func_type) -> StaticTensorSpec[dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## rebuild_static_tensor_specs_with_output_lambda
`rebuild_static_tensor_specs_with_output_lambda[func_type: __TypeOfAllTypes, //, dtype: DType, rank: Int](spec: StaticTensorSpec[dtype, rank], out_lambda: func_type) -> StaticTensorSpec[dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## trace_slice_arg
`trace_slice_arg(name: String, buf: ManagedTensorSlice[io_spec, static_spec=static_spec]) -> String` Helper to stringify the type and shape of a kernel argument for tracing. **Args:** * ​name ([`String`](/mojo/std/collections/string/string/String)): The name of the argument. * ​buf ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The tensor to trace. **Returns:** [`String`](/mojo/std/collections/string/string/String): A string representation of the buffer with its shape and data type.
--- ## ElementwiseBinaryComparisonOp
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## Required methods ### `elementwise` `static elementwise[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[DType.bool, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## ElementwiseBinaryOp
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## Required methods ### `elementwise` `static elementwise[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## ElementwiseUnaryMixedOp
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## Required methods ### `elementwise` `static elementwise[dtype: DType, out_dtype: DType, width: Int](x: SIMD[dtype, width]) -> SIMD[out_dtype, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## ElementwiseUnaryOp
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## Required methods ### `elementwise` `static elementwise[dtype: DType, width: Int](x: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## operation_traits
## Traits * [​`ElementwiseBinaryComparisonOp`](./ElementwiseBinaryComparisonOp): * [​`ElementwiseBinaryOp`](./ElementwiseBinaryOp): * [​`ElementwiseUnaryMixedOp`](./ElementwiseUnaryMixedOp): * [​`ElementwiseUnaryOp`](./ElementwiseUnaryOp):
--- ## RuntimeTensorSpec
`@register_passable(trivial)` `struct RuntimeTensorSpec[dtype: DType, rank: Int]` ## Fields * ​shape (`IndexList[rank]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__getitem__` `__getitem__(self, idx: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `bytecount` `bytecount(self) -> Int` Gets the total byte count. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total byte count.
--- ## tensor_spec
You can import these APIs from the `max.tensor` package. For example: ```mojo from max.tensor import RuntimeTensorSpec ``` ## Structs * [​`RuntimeTensorSpec`](./RuntimeTensorSpec):
--- ## transitional
Utilities for transitional period during NDBuffer deprecation. ## Functions * [​`managed_tensor_slice_to_ndbuffer`](./managed_tensor_slice_to_ndbuffer):
--- ## managed_tensor_slice_to_ndbuffer
`managed_tensor_slice_to_ndbuffer[spec: StaticTensorSpec[dtype, rank], //](tensor: ManagedTensorSlice[io_spec, static_spec=spec]) -> NDBuffer[dtype, rank, MutAnyOrigin, spec.shape, spec.strides, address_space=spec.address_space]` **Returns:** `NDBuffer`
--- ## kv_cache (3)
Contains implementations for several types of key-value caches. [KV caches](/glossary/ai/kv-cache) are used in transformer models to store key-value tensors output from self-attention layers. These APIs are used in the higher-level functions in the [`nn`](/mojo/kernels/nn) package. ## Modules * [​`types`](./types/): This module contains the types for the key-value cache APIs.
--- ## ContinuousBatchingKVCache
`@register_passable(trivial)` `struct ContinuousBatchingKVCache[dtype_: DType, kv_params_: KVCacheStaticParams]` Wrapper for the ContinuousKVCache of a given layer in the transformer model. This abstracts the Pointer indirection for accessing the ContinuousKVCache for a given batch entry. THIS IS THE TYPE THAT IS PASSED TO KV PROJECTION AND FLASH ATTENTION KERNELS. ## Parameters * ​dtype\_ ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of the kv-cache. * ​kv\_params\_ ([`KVCacheStaticParams`](/mojo/kernels/kv_cache/types/KVCacheStaticParams)): The kv-cache static parameters. ## Fields * ​blocks (`ContinuousBatchingKVCache[dtype_, kv_params_].blocks_type`): * ​cache\_lengths (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​lookup\_table (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVCacheT`](/mojo/kernels/kv_cache/types/KVCacheT), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, -1, Int.__init__[UInt](ContinuousBatchingKVCache[dtype_, kv_params_].kv_params.num_heads), 
Int.__init__[UInt](ContinuousBatchingKVCache[dtype_, kv_params_].kv_params.head_size))` ### `blocks_type` `comptime blocks_type = LayoutTensor[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin]` ### `device_type` `comptime device_type = ContinuousBatchingKVCache[dtype_, kv_params_]` ### `dtype` `comptime dtype = dtype_` ### `kv_params` `comptime kv_params = kv_params_` ### `page_size_` `comptime page_size_ = 0` ### `quantization_enabled` `comptime quantization_enabled = False` ### `scale_dtype` `comptime scale_dtype = DType.float32` ## Methods ### `__init__` `__init__(blocks: LayoutTensor[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** `String` ### `max_tile_size` `static max_tile_size() -> Int` Returns the maximum tile size for the KVCache. 
**Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `cache_lengths_nd` `cache_lengths_nd(self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `cache_length` `cache_length(self, batch_idx: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `load` `load[width: Int, output_dtype: DType = ContinuousBatchingKVCache[dtype_, kv_params_].dtype](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `store` `store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, size])` ### `load_scale` `load_scale[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[DType.float32, width]` Loads a quantization scale from the given index. Note: ContinuousBatchingKVCache does not support KVCache quantization. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `store_scale` `store_scale(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[DType.float32, size])` Stores the quantization scales at the given index. Note: ContinuousBatchingKVCache does not support KVCache quantization. ### `load_quantized` `load_quantized[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, width]` Loads a quantized element from the given index. Note: ContinuousBatchingKVCache does not support KVCache quantization. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `empty_cache` `empty_cache(self) -> Bool` Returns true if the cache\_lengths for all requests is 0, false otherwise. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `max_prompt_length` `max_prompt_length(self) -> UInt32` Returns the maximum sequence length across all batches of the current request. 
**Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `max_context_length` `max_context_length(self) -> UInt32` Returns the maximum cache length used across all batches of the current request. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `row_idx` `row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, swizzle_mode, Int.__init__[UInt](ContinuousBatchingKVCache[dtype_, kv_params_].kv_params.head_size)]()](self, ctx: DeviceContext) -> TMATensorTile[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, _split_last_layout[ContinuousBatchingKVCache[dtype_, kv_params_].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[ContinuousBatchingKVCache[dtype_, kv_params_].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)]` Creates a TMA tile for this KV cache. 
**Returns:** `TMATensorTile` ### `create_ragged_tma_tile` `create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, swizzle_mode, Int.__init__[UInt](ContinuousBatchingKVCache[dtype_, kv_params_].kv_params.head_size)]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, swizzle_mode, BN, BK])` **Returns:** [`RaggedTMA3DTile`](/mojo/kernels/layout/tma_async/RaggedTMA3DTile) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[ContinuousBatchingKVCache[dtype_, kv_params_].dtype], MutAnyOrigin]` **Returns:** `UnsafePointer` ### `scales_block_paged_ptr` `scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Float32, MutAnyOrigin]` Returns a pointer to the scales block at the requested indices. Note: ContinuousBatchingKVCache does not support KVCache quantization. This function returns a NULL pointer. **Returns:** `UnsafePointer`
--- ## ContinuousBatchingKVCacheCollection
`struct ContinuousBatchingKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams]` This is a "view" of the cache for the given sequences in the batch. This object does not own the underlying buffers in k\_cache and v\_cache, it's borrowing them from the BlockWrappers in our KVCacheManager. It does own the Pointer\[LayoutTensor\[dtype, Layout.row\_major[3]()]] and valid\_lengths buffer ## Parameters * ​dtype\_ ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of the kv-cache. * ​kv\_params\_ ([`KVCacheStaticParams`](/mojo/kernels/kv_cache/types/KVCacheStaticParams)): The kv-cache static parameters. ## Fields * ​cache\_lengths (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​lookup\_table (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​blocks (`ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_type`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): * ​kv\_cache\_dynamic\_shape (`IndexList[4]`): * ​kv\_cache\_dynamic\_strides (`IndexList[4]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVCollectionT`](/mojo/kernels/kv_cache/types/KVCollectionT), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, -1, -1, -1, Int.__init__[UInt](ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params.num_heads), 
Int.__init__[UInt](ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params.head_size))` ### `blocks_type` `comptime blocks_type = LayoutTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout, MutAnyOrigin]` ### `CacheType` `comptime CacheType = ContinuousBatchingKVCache[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params]` ### `dtype` `comptime dtype = dtype_` ### `kv_params` `comptime kv_params = kv_params_` ### `name_str` `comptime name_str = "continuous_batching"` ### `scale_dtype` `comptime scale_dtype = DType.invalid` ## Methods ### `__init__` `__init__(out self, blocks: LayoutTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[DType.invalid, Layout.row_major[6](), MutAnyOrigin]] = None)` ### `get_key_cache` `get_key_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType` **Returns:** `ContinuousBatchingKVCacheCollection` ### `get_value_cache` `get_value_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType` **Returns:** `ContinuousBatchingKVCacheCollection` ### `cache_length` `cache_length(self, bs_idx: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## KVCacheStaticParams
`@register_passable(trivial)` `struct KVCacheStaticParams` ## Fields * ​num\_heads (`UInt`): * ​head\_size (`UInt`): * ​is\_mla (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_heads: Scalar[DType.uint], head_size: Scalar[DType.uint], is_mla: Bool = False) -> Self` Initialize KVCacheStaticParams. Args: num\_heads (UInt): Number of attention heads. head\_size (UInt): Size of each attention head. is\_mla (Bool, optional): Whether to use Multi-Linear Attention (MLA) mode. If true, we only store k cache. If False, we store k and v cache. Defaults to False.
--- ## KVCacheT
Trait for different KVCache types and implementations. Represents a single (key or value) cache. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. ### `dtype` `comptime dtype` ### `kv_params` `comptime kv_params` ### `page_size_` `comptime page_size_` ### `quantization_enabled` `comptime quantization_enabled = False` ### `scale_dtype` `comptime scale_dtype = DType.invalid` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `cache_lengths_nd` `cache_lengths_nd(self: _Self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` Returns the cache lengths as a LayoutTensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `cache_length` `cache_length(self: _Self, batch_idx: Int) -> Int` Returns the length of the cache for a given batch index. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `load` `load[width: Int, output_dtype: DType = _Self.dtype](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]` Loads an element from the given index. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `store` `store(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[_Self.dtype, size])` Stores an element at the given index. ### `store_scale` `store_scale(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[_Self.scale_dtype, size])` Stores the quantization scales at the given index. ### `load_scale` `load_scale[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]` Loads the quantization scales from the given index. 
**Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `load_quantized` `load_quantized[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.dtype, width]` Loads a quantized element from the given index. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `empty_cache` `empty_cache(self: _Self) -> Bool` Returns true if the cache\_lengths for all requests is 0, false otherwise. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `max_prompt_length` `max_prompt_length(self: _Self) -> UInt32` Returns the maximum sequence length across all batches of the current request. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `max_context_length` `max_context_length(self: _Self) -> UInt32` Returns the maximum cache length used across all batches of the current request. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.dtype], MutAnyOrigin]` Returns a LayoutTensor pointing to the KVCache block at the given index. Paged KVCache implementations must have a block\_size which is a multiple of the and greater than the layout's first dimension. **Returns:** `UnsafePointer` ### `scales_block_paged_ptr` `scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]` Returns a pointer to the scales block at the requested indices. **Returns:** `UnsafePointer` ### `max_tile_size` `static max_tile_size() -> Int` Returns the maximum tile size for the KVCache. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `row_idx` `row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. 
**Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, Int.__init__[UInt](_Self.kv_params.head_size)]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, _split_last_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)]` Creates a TMA tile for this KV cache. This is useful for `k-major` MMA operations where we don't need to mask any extra rows. **Returns:** `TMATensorTile` ### `create_ragged_tma_tile` `create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, Int.__init__[UInt](_Self.kv_params.head_size)]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]` Creates a TMA tile for this KV cache. This is useful for `mn-major` MMA operations where we need to mask extra rows to avoid adding `NaN` to the output through the MMA reduction. **Returns:** [`RaggedTMA3DTile`](/mojo/kernels/layout/tma_async/RaggedTMA3DTile) ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** `String`: The host type's name. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## KVCollectionT
Trait for a pair of caches (keys and values). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `CacheType` `comptime CacheType` ### `dtype` `comptime dtype` ### `kv_params` `comptime kv_params` ### `name_str` `comptime name_str` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. 
**Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `get_key_cache` `get_key_cache(self: _Self, layer_idx: Int) -> _Self.CacheType` **Returns:** `_Self.CacheType` ### `get_value_cache` `get_value_cache(self: _Self, layer_idx: Int) -> _Self.CacheType` **Returns:** `_Self.CacheType` ### `cache_length` `cache_length(self: _Self, bs_idx: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## PagedKVCache
`@register_passable(trivial)` `struct PagedKVCache[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int, scale_dtype_: DType = DType.invalid, quantization_granularity: Int = 1]` The PagedKVCache is a wrapper around the KVCache blocks for a given layer. It is used to access the KVCache blocks for PagedAttention. Note: This struct represents a 4D view of a 6D `PagedKVCacheCollection` tensor. The compile-time layout has `UNKNOWN_VALUE` for stride\[0] because the actual stride depends on `num_layers` from the parent tensor, which is only known at runtime. This ensures offset calculations use the correct runtime strides rather than incorrect compile-time values. ## Parameters * ​dtype\_ ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of the kv-cache. * ​kv\_params\_ ([`KVCacheStaticParams`](/mojo/kernels/kv_cache/types/KVCacheStaticParams)): The kv-cache static parameters. * ​page\_size ([`Int`](/mojo/std/builtin/int/Int)): The size of the page. * ​scale\_dtype\_ ([`DType`](/mojo/std/builtin/dtype/DType)): Dtype of the quantization scales (if quantization enabled). * ​quantization\_granularity ([`Int`](/mojo/std/builtin/int/Int)): Block size used for quantization (e.g. 128). 
## Fields * ​blocks (`PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].blocks_type`): * ​cache\_lengths (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​lookup\_table (`LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin]`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): * ​scales (`OptionalReg[LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scale_dtype, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scales_layout, MutAnyOrigin]]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVCacheT`](/mojo/kernels/kv_cache/types/KVCacheT), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout(PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].blocks_shape, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].blocks_strides)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, page_size, Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.num_heads), Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.head_size))` ### 
`blocks_strides` `comptime blocks_strides = IntTuple(-1, (Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.num_heads) * Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.head_size)), Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.head_size), 1)` ### `blocks_type` `comptime blocks_type = LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].blocks_layout, MutAnyOrigin]` ### `device_type` `comptime device_type = PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity]` ### `dtype` `comptime dtype = dtype_` ### `head_dim_granularity` `comptime head_dim_granularity = ceildiv(Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.head_size), quantization_granularity)` ### `kv_params` `comptime kv_params = kv_params_` ### `page_size_` `comptime page_size_ = page_size` ### `quantization_enabled` `comptime quantization_enabled = (scale_dtype_ != DType.invalid)` ### `scale_dtype` `comptime scale_dtype = scale_dtype_` ### `scales_block_type` `comptime scales_block_type = LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scale_dtype, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scales_layout, MutAnyOrigin]` ### `scales_layout` `comptime scales_layout = Layout.row_major(PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scales_shape)` ### `scales_shape` `comptime scales_shape = IntTuple(-1, page_size, Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.num_heads), PagedKVCache[dtype_, kv_params_, 
page_size, scale_dtype_, quantization_granularity].head_dim_granularity)` ## Methods ### `__init__` `__init__(blocks: LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].blocks_layout, MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scale_dtype, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scales_layout, MutAnyOrigin]] = None) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** `String` ### `max_tile_size` `static max_tile_size() -> Int` Returns the maximum tile size for the KVCache. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `cache_lengths_nd` `cache_lengths_nd(self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `cache_length` `cache_length(self, batch_idx: Int) -> Int` Returns the length of the cache for a given batch index. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `row_idx` `row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. 
**Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, swizzle_mode, Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.head_size)]()](self, ctx: DeviceContext) -> TMATensorTile[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, _split_last_layout[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)]` Creates a TMA tile for this KV cache. **Returns:** `TMATensorTile` ### `create_ragged_tma_tile` `create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, swizzle_mode, Int.__init__[UInt](PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].kv_params.head_size)]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, swizzle_mode, BN, BK])` **Returns:** [`RaggedTMA3DTile`](/mojo/kernels/layout/tma_async/RaggedTMA3DTile) ### `load` `load[width: Int, output_dtype: DType = PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]` Loads an element from the given index. 
**Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `store` `store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, size])` Stores an element at the given index. ### `load_scale` `load_scale[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scale_dtype, width]` Loads a quantization scale from the given index. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `store_scale` `store_scale(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scale_dtype, size])` Stores the quantization scales at the given index. ### `load_quantized` `load_quantized[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype, width]` Loads a quantized element from the given index. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `empty_cache` `empty_cache(self) -> Bool` Returns true if the cache\_lengths for all requests are 0, false otherwise. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `max_prompt_length` `max_prompt_length(self) -> UInt32` Returns the maximum sequence length across all batches of the current request. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `max_context_length` `max_context_length(self) -> UInt32` Returns the maximum cache length used across all batches of the current request. 
**Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].dtype], MutAnyOrigin]` **Returns:** `UnsafePointer` ### `scales_block_paged_ptr` `scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity].scale_dtype], MutAnyOrigin]` Returns a pointer to the scales block at the requested indices. **Returns:** `UnsafePointer`
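The lookup table is what turns a logical `(batch, token)` position into a physical block offset. A rough sketch of that address calculation, in plain Python rather than Mojo (the name `paged_lookup` is illustrative and not part of this API; the block-local `[page_size, num_heads, head_size]` ordering mirrors the `blocks_shape`/`blocks_strides` members above):

```python
def paged_lookup(lookup_table, batch_idx, tok_idx, page_size,
                 num_heads, head_size, head_idx, head_dim_idx):
    """Map a logical (batch, token) position to a flat offset in the block pool."""
    logical_page = tok_idx // page_size            # which page of this sequence
    physical_block = lookup_table[batch_idx][logical_page]
    tok_in_page = tok_idx % page_size
    # Within a block the layout is [page_size, num_heads, head_size], row-major,
    # matching the blocks_shape / blocks_strides comptime members above.
    return (physical_block * page_size * num_heads * head_size
            + tok_in_page * num_heads * head_size
            + head_idx * head_size
            + head_dim_idx)

# Two sequences sharing one block pool; page_size=4, 2 heads, head_size=8.
table = [[5, 2], [7, 0]]
off = paged_lookup(table, batch_idx=0, tok_idx=6, page_size=4,
                   num_heads=2, head_size=8, head_idx=1, head_dim_idx=3)
# token 6 of sequence 0 lives on logical page 1, i.e. physical block 2
```

Because consecutive logical pages of one sequence can map to arbitrary physical blocks, methods such as `block_paged_ptr` can only hand out pointers that are contiguous within a single page.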
--- ## PagedKVCacheCollection
`struct PagedKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int, scale_dtype_: DType = DType.invalid]` ## Fields * ​scales (`OptionalReg[LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scale_dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scales_layout, MutAnyOrigin]]`): * ​kv\_cache\_scales\_dynamic\_shape (`IndexList[4]`): * ​kv\_cache\_scales\_dynamic\_strides (`IndexList[4]`): * ​blocks (`PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].blocks_type`): * ​cache\_lengths (`PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].cache_lengths_type`): * ​lookup\_table (`PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].lookup_table_type`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): * ​kv\_cache\_dynamic\_shape (`IndexList[4]`): * ​kv\_cache\_dynamic\_strides (`IndexList[4]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVCollectionT`](/mojo/kernels/kv_cache/types/KVCollectionT), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, 2 if PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].kv_params.is_mla.__bool__().__invert__()._mlir_value else 1, -1, page_size, Int.__init__[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, 
scale_dtype_].kv_params.num_heads), Int.__init__[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].kv_params.head_size))` ### `blocks_type` `comptime blocks_type = LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].blocks_layout, MutAnyOrigin]` ### `cache_lengths_type` `comptime cache_lengths_type = LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` ### `CacheType` `comptime CacheType = PagedKVCache[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].kv_params, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scale_dtype]` ### `dtype` `comptime dtype = dtype_` ### `head_dim_granularity` `comptime head_dim_granularity = ceildiv(Int.__init__[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].kv_params.head_size), 1)` ### `kv_params` `comptime kv_params = kv_params_` ### `lookup_table_type` `comptime lookup_table_type = LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin]` ### `name_str` `comptime name_str = "paged"` ### `scale_dtype` `comptime scale_dtype = scale_dtype_` ### `scales_layout` `comptime scales_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scales_shape)` ### `scales_shape` `comptime scales_shape = IntTuple(-1, 2 if PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].kv_params.is_mla.__bool__().__invert__()._mlir_value else 1, -1, page_size, Int.__init__[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].kv_params.num_heads), PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].head_dim_granularity)` ### `scales_type` `comptime scales_type = LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scale_dtype, 
PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scales_layout, MutAnyOrigin]` ## Methods ### `__init__` `__init__(out self, blocks: LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].scale_dtype, Layout.row_major[6](), MutAnyOrigin]] = None)` ### `get_key_cache` `get_key_cache(self, layer_idx: Int) -> PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].CacheType` **Returns:** [`PagedKVCache`](/mojo/kernels/kv_cache/types/PagedKVCache) ### `get_value_cache` `get_value_cache(self, layer_idx: Int) -> PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_].CacheType` **Returns:** [`PagedKVCache`](/mojo/kernels/kv_cache/types/PagedKVCache) ### `cache_length` `cache_length(self, bs_idx: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
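One way to picture `get_key_cache`/`get_value_cache`: the collection holds a 6D blocks tensor indexed as `[block, kv, layer, token, head, dim]` (per `blocks_shape`), and each getter fixes the `kv` index (0 for keys, 1 for values) and the `layer` index, yielding a 4D per-layer cache. A hedged Python sketch using nested lists; the real methods return a `PagedKVCache` view with runtime strides, not copies, and the helper names here are illustrative:

```python
def make_collection(num_blocks, num_layers, page_size, num_heads, head_size):
    # 6D blocks tensor indexed as [block][kv][layer][token][head][dim].
    return [[[[[[0.0] * head_size
                for _ in range(num_heads)]
               for _ in range(page_size)]
              for _ in range(num_layers)]
             for _ in range(2)]
            for _ in range(num_blocks)]

def get_key_cache(blocks, layer_idx):
    # 4D per-layer view: [block][token][head][dim], kv index fixed to 0.
    return [blk[0][layer_idx] for blk in blocks]

def get_value_cache(blocks, layer_idx):
    # Same slice with kv index fixed to 1.
    return [blk[1][layer_idx] for blk in blocks]

coll = make_collection(num_blocks=2, num_layers=3, page_size=4,
                       num_heads=2, head_size=8)
coll[1][0][2][3][1][5] = 7.0          # write a key element in layer 2
k2 = get_key_cache(coll, 2)           # the per-layer view sees the write
```

Because Python lists are shared by reference, the slices behave like views, which is the property the note on `PagedKVCache` describes: the 4D view's stride over blocks depends on `num_layers` of this parent tensor, so it must be computed at runtime.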
--- ## types (Types)
This module contains the types for the key-value cache APIs. The module includes structs implementing several different types of [KV caches](/glossary/ai/kv-cache). This module defines two traits that specify the roles of the different structs: * `KVCacheT`: Defines the interface for a single (key or value) cache. * `KVCollectionT`: Defines the interface for a pair of caches (keys and values). ## Structs * [​`ContinuousBatchingKVCache`](./ContinuousBatchingKVCache): Wrapper for the ContinuousKVCache of a given layer in the transformer model. * [​`ContinuousBatchingKVCacheCollection`](./ContinuousBatchingKVCacheCollection): This is a "view" of the cache for the given sequences in the batch. * [​`KVCacheStaticParams`](./KVCacheStaticParams): * [​`PagedKVCache`](./PagedKVCache): The PagedKVCache is a wrapper around the KVCache blocks for a given layer. It is used to access the KVCache blocks for PagedAttention. * [​`PagedKVCacheCollection`](./PagedKVCacheCollection): ## Traits * [​`KVCacheT`](./KVCacheT): Trait for different KVCache types and implementations. * [​`KVCollectionT`](./KVCollectionT): Trait for a pair of caches (keys and values). ## Functions * [​`padded_depth`](./padded_depth): * [​`swizzle_granularity`](./swizzle_granularity):
--- ## padded_depth
`padded_depth[dtype: DType, swizzle_mode: TensorMapSwizzle, depth: Int]() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## swizzle_granularity
`swizzle_granularity[dtype: DType, swizzle_mode: TensorMapSwizzle]() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## CopyPolicy
The CopyPolicy trait defines requirements needed for a tensor to be copied. These requirements check the compatibility of the source and destination tensors. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `verify_source_tensor` `static verify_source_tensor(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` A static function that verifies the source tensor is compatible with the copy operation. If the tensor is not valid, compilation will fail. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor that will be copied from. ### `verify_destination_tensor` `static verify_destination_tensor(dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` A static function that verifies the destination tensor is compatible with the copy operation. If the tensor is not valid, compilation will fail. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor that will be copied to. ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** `String`: The host type's name. 
## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
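The contract above can be pictured as a pair of static checks that run before any data moves. A conceptual Python sketch (the policy name `RankAndShapePolicy` and helper `checked_copy` are hypothetical; in Mojo these checks run at compile time, so a failed check aborts compilation rather than raising):

```python
class RankAndShapePolicy:
    """Hypothetical policy: rank <= 2 tensors with matching shapes."""

    @staticmethod
    def verify_source_tensor(src_shape):
        # Mirrors verify_source_tensor: reject incompatible sources up front.
        assert len(src_shape) <= 2, "source rank must be <= 2"

    @staticmethod
    def verify_destination_tensor(dst_shape, src_shape):
        # Mirrors verify_destination_tensor: destination must match the source.
        assert dst_shape == src_shape, "source/destination shape mismatch"

def checked_copy(policy, dst, src):
    """Run the policy's static checks, then perform the copy."""
    policy.verify_source_tensor((len(src),))
    policy.verify_destination_tensor((len(dst),), (len(src),))
    dst[:] = src

dst = [0, 0, 0]
checked_copy(RankAndShapePolicy, dst, [1, 2, 3])   # checks pass, copy proceeds
```

The value of the trait is that the copy routine itself stays generic: compatibility rules live in the policy, and swapping policies changes what "compatible" means without touching the copy loop.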
--- ## copy
## Traits * [​`CopyPolicy`](./CopyPolicy): The CopyPolicy trait defines requirements needed for a tensor to be copied.
--- ## Element
`struct Element[dtype: DType, layout: Layout, /, index_type: DType = _get_index_type(layout)]` A wrapper around SIMD types that provides layout-driven vectorized operations. The `Element` struct extends SIMD types with layout-aware load and store operations, enabling efficient vectorized access to multi-dimensional data. It maps between logical tensor coordinates and physical memory locations according to the specified layout. ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout describing how elements are organized. * ​index\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of the index pointing to each element. ## Fields * ​element\_data (`Element[dtype, layout, index_type].element_data_type`): The actual SIMD data stored in this element. This field contains the vectorized data values that can be processed efficiently using SIMD operations. * ​runtime\_layout (`RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type]`): The runtime layout information for memory access patterns. This field stores the layout information needed to map between logical tensor coordinates and physical memory locations, supporting both compile-time and runtime-determined access patterns. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Stringable`](/mojo/std/builtin/str/Stringable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `element_data_type` `comptime element_data_type = SIMD[dtype, layout.size()]` The SIMD type used to store and process the element data. This type alias defines a SIMD vector with the specified data type and size matching the layout's total element count, enabling efficient vectorized operations. 
## Methods ### `__init__` `__init__(out self, element_data: SIMD[dtype, layout.size()])` Initializes an Element with the given SIMD data. **Args:** * ​element\_data ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The SIMD data to initialize the element with. `__init__(out self, element_data: SIMD[dtype, layout.size()], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type])` Initializes an Element with the given SIMD data and runtime layout. **Args:** * ​element\_data ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The SIMD data to initialize the element with. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. ### `load` `static load(ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type] = RuntimeLayout[layout, DType.int32, index_type]()) -> Self` Loads data from memory according to the specified layout. This method loads data from memory using the layout information to determine the memory access pattern. It supports both rank-1 and rank-2 layouts with various stride patterns, optimizing for contiguous memory access when possible. **Args:** * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the memory location to load from. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. **Returns:** `Self`: A new `Element` containing the loaded data. ### `masked_load` `static masked_load(ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type] = RuntimeLayout[layout, DType.int32, index_type]()) -> Self` Loads data from memory with masking for partial loads. 
This method loads data from memory using the layout information, but also handles cases where the runtime dimensions are smaller than the static layout dimensions. It ensures that only valid memory locations are accessed. **Args:** * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the memory location to load from. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. **Returns:** `Self`: A new `Element` containing the loaded data, with zeros in positions beyond the runtime dimensions. ### `store` `store(self, ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space])` Stores element data to memory according to the specified layout. This method performs a layout-aware store operation, writing data to memory following the access patterns defined by the layout. It optimizes memory writes based on the layout's stride patterns to maximize performance. The method handles different memory layout patterns: * For rank-1 tensors with contiguous memory (stride=1), it uses vectorized stores * For rank-2 tensors with contiguous rows or columns, it uses optimized slice-based stores * For non-contiguous memory layouts, it performs element-by-element stores Unlike `masked_store()`, this method assumes the full static dimensions will be written and does not perform runtime dimension boundary checking. Note: This method is constrained to layouts with rank <= 2. For higher-rank tensors, consider decomposing the operation. **Args:** * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Mutable pointer to the memory location where data will be stored. ### `masked_store` `masked_store(self, ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space])` Stores element data to memory with masking for partial stores. This method performs a layout-aware store operation with boundary checking. 
It ensures that only valid memory locations are written to when the runtime dimensions are smaller than the static layout dimensions, preventing out-of-bounds memory access. The method optimizes for different memory layouts: * For contiguous memory (stride=1), it uses vectorized stores when possible * For non-contiguous memory, it performs element-by-element stores * For all patterns, it respects runtime dimension bounds Note: This method is constrained to layouts with rank <= 2. For higher-rank tensors, consider decomposing the operation. **Args:** * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the memory location where data will be stored. ### `__str__` `__str__(self) -> String` Returns a string representation of the element. **Returns:** `String`: A string representation of the element's data. ### `write_to` `write_to(self, mut writer: T)` Writes the element to the specified writer. **Args:** * ​writer (`T`): The writer to output the element representation to.
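The difference between `load` and `masked_load` is easiest to see in a small model. The sketch below is plain Python, not the Mojo API: `layout_load` and `masked_layout_load` are illustrative names for a rank-2 gather driven by a `(shape, strides)` pair, with `masked_layout_load` zero-filling positions beyond the runtime bounds, as the `masked_load` description above states:

```python
def layout_load(buf, base, shape, strides):
    """Gather a shape[0] x shape[1] element using the given strides (rank 2)."""
    out = []
    for i in range(shape[0]):
        for j in range(shape[1]):
            out.append(buf[base + i * strides[0] + j * strides[1]])
    return out

def masked_layout_load(buf, base, shape, strides, runtime_shape):
    """Like layout_load, but zero-fill positions beyond the runtime bounds."""
    out = []
    for i in range(shape[0]):
        for j in range(shape[1]):
            if i < runtime_shape[0] and j < runtime_shape[1]:
                out.append(buf[base + i * strides[0] + j * strides[1]])
            else:
                out.append(0)   # never touches memory past the runtime extent
    return out

buf = list(range(100))
# A 2x4 element of a row-major 10-column matrix, starting at row 3, col 2:
full = layout_load(buf, base=32, shape=(2, 4), strides=(10, 1))
# Same element when only 3 of the 4 columns are valid at runtime:
part = masked_layout_load(buf, base=32, shape=(2, 4), strides=(10, 1),
                          runtime_shape=(2, 3))
```

When `strides[1] == 1` each row is contiguous, which is the case where the real implementation can replace the inner loop with a single vectorized load.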
--- ## MemoryElement
`struct MemoryElement[mut: Bool, //, dtype: DType, layout: Layout, origin: Origin[mut=mut], /, address_space: AddressSpace, *, index_type: DType = _get_index_type(layout, address_space)]` Represents data in memory organized according to a specific layout. The `MemoryElement` struct provides a high-level interface for accessing data in memory with a specific layout. It encapsulates a pointer to the memory location and the runtime layout information needed to access the data correctly. This abstraction enables efficient memory operations that respect the underlying memory organization, supporting vectorized loads and stores while handling different memory layouts transparently. ## Parameters * ​mut ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the memory element is mutable. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout describing how elements are organized. * ​origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): The origin of the memory element. * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The memory address space where the data is located. * ​index\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of the index pointing to each memory element. ## Fields * ​ptr (`UnsafePointer[Scalar[dtype], origin, address_space=address_space]`): Pointer to the memory location where the data is stored. This pointer provides access to the underlying memory with the specified address space and alignment requirements. It points to the first element of the data structure in memory. * ​runtime\_layout (`RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type]`): Runtime layout information used for memory access calculations. This field stores the runtime layout information needed to compute memory offsets for accessing elements according to the specified layout pattern. 
It handles both compile-time known dimensions and runtime-determined dimensions. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type])` Initializes a `MemoryElement` with the given pointer and runtime layout. **Args:** * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the memory location of the element. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. ### `load` `load(self, out result: Element[dtype, layout, index_type])` Loads data from memory according to the specified layout. This method performs a layout-aware load operation, reading data from memory following the access patterns defined by the layout. It optimizes memory reads based on the layout's stride patterns to maximize performance. The method leverages the underlying `Element.load` implementation which handles different memory layout patterns including contiguous and strided access. **Returns:** [`Element`](/mojo/kernels/layout/element/Element): An `Element` containing the loaded data organized according to the layout. ### `store` `store(self: MemoryElement[dtype, layout, mut_origin, address_space, index_type=index_type], src: Element[dtype, layout, index_type])` Stores element data to the memory location of this MemoryElement. This method performs a layout-aware store operation, writing data to memory following the access patterns defined by the layout. It optimizes memory writes based on the layout's stride patterns to maximize performance. 
The method delegates to the `Element.store` implementation which handles different memory layout patterns including vectorized stores for contiguous memory and element-by-element stores for non-contiguous layouts. **Args:** * ​src ([`Element`](/mojo/kernels/layout/element/Element)): The `Element` containing the data to store. ### `transfer` `transfer(self: MemoryElement[dtype, layout, mut_origin, address_space, index_type=index_type], src: MemoryElement[dtype, layout, origin, address_space, index_type=index_type])` Transfers data from another `MemoryElement` to this one. This method efficiently transfers data between memory locations with potentially different layouts and data types. It performs the following operations: 1. Loads data from the source `MemoryElement` using its layout 2. Converts the data to the destination data type if necessary 3. Stores the converted data to the destination memory location using its layout This provides a high-performance way to copy and convert data between different memory representations while respecting both source and destination memory layouts. **Args:** * ​src ([`MemoryElement`](/mojo/kernels/layout/element/MemoryElement)): The source `MemoryElement` to transfer data from.
--- ## element (Element)
Provides element-based access to memory using layout-driven vectorization. This module implements efficient memory access patterns for multi-dimensional data using the layout system. It provides abstractions for loading and storing data with specific memory layouts, enabling vectorized operations that respect the underlying memory organization. Key components: * `Element`: A wrapper around SIMD types that provides layout-driven vectorized operations * `MemoryElement`: Represents data in memory organized according to a specific layout These components enable efficient tensor operations by ensuring memory accesses follow optimal patterns defined by the layout system. ## Structs * [​`Element`](./Element): A wrapper around SIMD types that provides layout-driven vectorized operations. * [​`MemoryElement`](./MemoryElement): Represents data in memory organized according to a specific layout.
--- ## layout
Provides layout and layout tensor types, which abstract memory layout for multidimensional data. * The [`Layout`](/mojo/kernels/layout/layout/Layout) type represents a mapping between a set of logical coordinates and a linear index. It can be used, for example, to map logical tensor coordinates to a memory address, or to map GPU threads to tiles of data. * The [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) type is a high-performance tensor with explicit memory layout via a `Layout`. ## Modules * [​`copy`](./copy/): * [​`element`](./element/): Provides element-based access to memory using layout-driven vectorization. * [​`int_tuple`](./int_tuple/): Hierarchical integer tuple data structures for high-performance tensor operations. * [​`layout`](./layout/): Provides a high-performance tensor layout system for memory mapping and indexing. * [​`layout_tensor`](./layout_tensor/): Provides the `LayoutTensor` type for representing multidimensional data. * [​`math`](./math/): Implements math methods that work on layout tensors. * [​`runtime_layout`](./runtime_layout/): Provides the `RuntimeLayout` type and functions for working with it. You can use `RuntimeLayout` to define a layout where the dimensions are not known at compile time. * [​`runtime_tuple`](./runtime_tuple/): Provides the `RuntimeTuple` data structure and related utility functions for handling tuple-like data with both compile-time and runtime elements. `RuntimeTuple` is designed for high-performance tensor operations, supporting efficient manipulation of multi-dimensional data structures like shapes, indices, and coordinates. * [​`swizzle`](./swizzle/): Defines swizzle layouts for optimizing memory access patterns. * [​`tensor_core`](./tensor_core/): Provides Tensor Core primitives for high-performance matrix operations. * [​`tensor_core_async`](./tensor_core_async/): Provides asynchronous Tensor Core operations. * [​`tma_async`](./tma_async/): Provides asynchronous Tensor Memory Accelerator (TMA) operations.
--- ## IntArray
`@register_passable` `struct IntArray` A memory-efficient, register-passable array of integers. `IntArray` provides a low-level implementation of a dynamically-sized integer array with direct memory management. This struct serves as the underlying storage mechanism for `IntTuple` and related data structures, optimized for high-performance tensor operations. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(size: Int = 0) -> Self` Initialize a new owned `IntArray` with the specified size. **Args:** * ​size ([`Int`](/mojo/std/builtin/int/Int)): Number of integers to allocate space for. Defaults to 0. ### `__copyinit__` `__copyinit__(existing: Self) -> Self` Initialize by copying an existing `IntArray`. For owned arrays, this performs a deep copy of the data. **Args:** * ​existing (`Self`): The source array to copy from. ### `__del__` `__del__(deinit self)` Destroy the `IntArray` and free its memory if owned. Only frees memory for owned arrays (positive \_size) to prevent double-free errors with views. ### `__getitem__` `__getitem__(self, idx: Int) -> Int` Access an element at the specified index. **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): Zero-based index of the element to access. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The integer value at the specified index. ### `__setitem__` `__setitem__(mut self, idx: Int, value: Int)` Set the value at the specified index. 
Note: Bounds checking is performed when assertions are enabled (e.g., `-D ASSERT=all`). **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): Zero-based index of the element to modify. * ​value ([`Int`](/mojo/std/builtin/int/Int)): The integer value to store at the specified index. ### `owning` `owning(self) -> Bool` Check if this `IntArray` owns its memory. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if this array owns its memory (positive \_size), False if it's a view (negative \_size). ### `size` `size(self) -> Int` Get the number of elements in the array. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of elements in the array, regardless of ownership status. ### `copy_from` `copy_from(mut self, offset: Int, source: Self, size: Int)` Copy elements from another `IntArray`. **Args:** * ​offset ([`Int`](/mojo/std/builtin/int/Int)): Destination offset in this array. * ​source (`Self`): Source array to copy from. * ​size ([`Int`](/mojo/std/builtin/int/Int)): Number of elements to copy. `copy_from(mut self, dst_offset: Int, source: Self, src_offset: Int, size: Int)` Copy elements from another `IntArray`, reading from the given offset in the source. **Args:** * ​dst\_offset ([`Int`](/mojo/std/builtin/int/Int)): Destination offset in this array. * ​source (`Self`): Source array to copy from. * ​src\_offset ([`Int`](/mojo/std/builtin/int/Int)): Source offset in the source array. * ​size ([`Int`](/mojo/std/builtin/int/Int)): Number of elements to copy.
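Based on the constructors and accessors documented above, a minimal usage sketch (the import path follows the examples elsewhere in these docs):

```mojo
from layout.int_tuple import IntArray

fn main():
    # Allocate an owned array with room for 4 integers.
    var arr = IntArray(4)
    for i in range(4):
        arr[i] = i * 10
    print(arr.size())    # 4
    print(arr[2])        # 20
    print(arr.owning())  # True: a freshly allocated array owns its memory
```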
--- ## IntTuple
`struct IntTuple` A hierarchical, nested tuple of integers with efficient memory management. IntTuple provides a flexible data structure for representing multi-dimensional shapes, indices, and other nested integer collections. It supports both flat and hierarchical representations with efficient memory sharing. This structure is fundamental for tensor operations, layout specifications, and dimension handling in high-performance computing contexts. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Intable`](/mojo/std/builtin/int/Intable), [`Iterable`](/mojo/std/iter/Iterable), [`Movable`](/mojo/std/builtin/value/Movable), [`Sized`](/mojo/std/builtin/len/Sized), [`Stringable`](/mojo/std/builtin/str/Stringable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = _IntTupleIter[origin_of((muttoimm iterable_origin._mlir_origin))]` The iterator type for IntTuple iteration. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): The origin of the iterable. ### `MinimumValue` `comptime MinimumValue = -65534` Minimum allowed value for integers in an `IntTuple`. This constant defines the lower bound for integer values that can be stored directly in an `IntTuple`. 
Values below this threshold are reserved for internal use to represent structural information like sub-tuple offsets. ## Methods ### `__init__` `__init__(out self)` Initialize an empty IntTuple. Creates an `IntTuple` with zero elements, which can be used as a starting point for building tuples incrementally with `append` or `extend`. Performance: * Minimal allocation (just a single element for length). * Structure validation performed when assertions are enabled. `__init__(out self, *, num_elems: Int)` Initialize an `IntTuple` with a specified number of uninitialized elements. Creates an `IntTuple` with space for the specified number of elements, but does not initialize the elements themselves. Note: Structure validation performed when assertions are enabled. **Args:** * ​num\_elems ([`Int`](/mojo/std/builtin/int/Int)): The number of elements to allocate space for. `__init__(out self, *elements: Int)` Initialize an `IntTuple` with a variadic list of integers. Creates an `IntTuple` containing the provided integer values. **Args:** * ​\*elements ([`Int`](/mojo/std/builtin/int/Int)): Variable number of integer values to store in the tuple. `__init__(out self, elements: VariadicList[Int])` Initialize an `IntTuple` with a list of integers. Creates an `IntTuple` containing the provided integer values. Notes: * Pre-allocates exact memory needed for efficiency. * Validates that all values are above `MinimumValue`. If any value is less than `MinimumValue`, assertion fails with an error message. * Structure validation performed when assertions are enabled. **Args:** * ​elements ([`VariadicList`](/mojo/std/builtin/variadics/VariadicList)): List of integer values to store in the tuple. `@implicit` `__init__(out self, value: Int)` Initialize an `IntTuple` with a single integer value. Creates an `IntTuple` containing a single integer element. **Args:** * ​value ([`Int`](/mojo/std/builtin/int/Int)): The integer value to store in the tuple. 
`__init__(out self, *elements: Self, *, __list_literal__: Tuple[] = Tuple[]())` Initialize an `IntTuple` with nested `IntTuple`s. Creates a hierarchical `IntTuple` containing the provided `IntTuple` elements, preserving their nested structure. **Args:** * ​\*elements (`Self`): Variable number of `IntTuple` values to store in the tuple. * ​\_\_list\_literal\_\_ ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Specifies that this constructor can be used for list literals. `__init__(out self, *, var _owned: IntArray)` Initialize an `IntTuple` taking the values of an `IntArray`. **Args:** * ​\_owned ([`IntArray`](/mojo/kernels/layout/int_tuple/IntArray)): The `IntArray` to use as storage. `__init__(out self, existing: Self, rng: _StridedRange)` Initialize an `IntTuple` as a slice of an existing `IntTuple`. Creates a new `IntTuple` containing only the elements from the existing `IntTuple` that are specified by the range. Notes: * Preserves nested structure of elements in the slice. * Structure validation performed when assertions are enabled. **Args:** * ​existing (`Self`): The source `IntTuple` to slice from. * ​rng ([`_StridedRange`](/mojo/std/builtin/range/_StridedRange)): The range of indices to include in the new `IntTuple`. `__init__(out self, dimlist: DimList)` Initialize an `IntTuple` from a `DimList`. Creates an `IntTuple` containing the dimensions from a `DimList`, handling both defined and undefined dimensions appropriately. Notes: * Converts undefined dimensions to `UNKNOWN_VALUE`. * Validates that all values are above `MinimumValue`. If any value is less than `MinimumValue`, assertion fails with an error message. **Args:** * ​dimlist ([`DimList`](/mojo/kernels/buffer/dimlist/DimList)): The `DimList` containing dimension information. `__init__[IterableType: Iterable](out self, iterable: IterableType)` Initialize an `IntTuple` from a zip iterator. Creates an `IntTuple` by appending each element from the zip iterator. 
Note: This implementation is not optimized and may be improved in future versions. **Parameters:** * ​IterableType ([`Iterable`](/mojo/std/iter/Iterable)): The type of the iterable. **Args:** * ​iterable (`IterableType`): An iterable containing pairs of elements to append. ### `__copyinit__` `__copyinit__(out self, existing: Self)` Initialize by copying an existing `IntTuple`. Creates a deep copy of the provided `IntTuple`, copying all its data into newly allocated memory. Note: There is a Mojo bug where this method unnecessarily propagates the origin of self to the new copy. **Args:** * ​existing (`Self`): The `IntTuple` to copy from. ### `__getitem__` `__getitem__(self, _idx: Int) -> Self` Retrieves an element at the specified index from the `IntTuple`. Supports negative indexing (e.g., `-1` for the last element). **Args:** * ​\_idx ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to retrieve. **Returns:** `Self`: An `IntTuple` containing either a single value or a sub-tuple. `__getitem__(self, span: Slice) -> Self` Retrieves a slice of elements from the `IntTuple`. Creates a new `IntTuple` containing the elements specified by the slice. **Args:** * ​span ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): A slice object specifying the range of elements to retrieve. **Returns:** `Self`: A new `IntTuple` containing the specified elements. ### `__lt__` `__lt__(self, rhs: Self) -> Bool` Compare two `IntTuple`s lexicographically. This function performs element-wise comparison of two `IntTuple`s and determines if the first is lexicographically less than the second. It compares corresponding elements until it finds a pair where the elements differ. Example: ```mojo from layout.int_tuple import IntTuple var tuple1 = IntTuple(1, 2, 3) var tuple2 = IntTuple(1, 2, 4) var result = tuple1 < tuple2 # Returns True because 3 < 4 ``` **Args:** * ​rhs (`Self`): The other `IntTuple` to compare. 
**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if `self` is lexicographically less than `rhs`, False otherwise. ### `__eq__` `__eq__(self, other: Self) -> Bool` Equality operator for `IntTuple`. **Args:** * ​other (`Self`): The `IntTuple` to compare with. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `IntTuple`s are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Inequality operator for `IntTuple`. **Args:** * ​other (`Self`): The `IntTuple` to compare with. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `IntTuple`s are not equal, False otherwise. ### `elements_size` `static elements_size(elements: VariadicListMem[IntTuple, is_owned]) -> Int` Calculate the total storage size needed for a list of IntTuples. Computes the sum of sizes for all elements, accounting for both direct integer values and nested sub-tuples. **Args:** * ​elements ([`VariadicListMem`](/mojo/std/builtin/variadics/VariadicListMem)): List of `IntTuple` elements to measure. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total storage size required for all elements. `static elements_size[_origin: ImmutOrigin, n: Int](elements: InlineArray[Pointer[IntTuple, _origin], n], idx: Int) -> Int` Calculate the total storage size needed for IntTuples at a specific index. Computes the sum of sizes for all elements at the given index in an array of `IntTuple` pointers. **Parameters:** * ​\_origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): Origin tracking for memory safety. * ​n ([`Int`](/mojo/std/builtin/int/Int)): Size of the inline array. **Args:** * ​elements ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Array of pointers to `IntTuple`s. * ​idx ([`Int`](/mojo/std/builtin/int/Int)): Index to access in each `IntTuple`. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total storage size required for all elements at the specified index. 
### `owned_copy` `owned_copy(self) -> Self` Create a deep copy of this `IntTuple` with its own memory ownership. This method creates a completely independent copy of the `IntTuple` with newly allocated memory. Unlike `__copyinit__`, this method can be called on an existing instance to create a separate copy. Example: ```mojo from layout import IntTuple var original = IntTuple(1, 2, 3) var copy = original.owned_copy() # Modifying copy will not affect original ``` **Returns:** `Self`: A new `IntTuple` containing the same data as this one but with independent memory ownership. ### `replace_entry` `replace_entry(self, idx: Int, value: Self) -> Self` Replace an entry in the tuple with another `IntTuple`. Creates a new `IntTuple` with the element at the specified index replaced by the provided `IntTuple`. Note: If the index is out of bounds, assertion fails with an error message. **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to replace. * ​value (`Self`): The `IntTuple` to insert at the specified index. **Returns:** `Self`: A new `IntTuple` with the replacement applied. `replace_entry(mut self, idx: Int, *, int_value: Int)` Replace an integer value at the specified index in-place. Directly modifies the tuple by replacing the integer value at the given index. This is more efficient than creating a new tuple when only a single value needs to be changed. Note: If the index is out of bounds, assertion fails with an error message. **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to replace. * ​int\_value ([`Int`](/mojo/std/builtin/int/Int)): The integer value to insert at the specified index. ### `count_values` `count_values(self) -> Int` Count the total number of integer values in this tuple hierarchy. Recursively traverses the nested tuple structure and counts all integer values. This is useful for determining the size needed for flattened representations. 
Note: For a flat tuple, this will return the same value as `len(self)`. For nested tuples, it counts all leaf integer values. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total count of integer values in this tuple and all nested tuples. ### `flatten` `flatten(self) -> Self` Flatten a nested `IntTuple` into a single-level `IntTuple`. This function converts a hierarchical `IntTuple` structure into a flat sequence of integer values, preserving the order of elements. **Returns:** `Self`: A new `IntTuple` containing all integer values in a flat structure. ### `product_flatten` `product_flatten(self) -> Self` Coalesces a nested `IntTuple` into a single-level `IntTuple` by multiplying the values within each top-level element together. **Returns:** `Self`: A new `IntTuple` containing the product of each top-level element, in a flat structure. ### `all_known` `all_known(self) -> Bool` Check if all values in this tuple hierarchy are known (not `UNKNOWN_VALUE`). Recursively traverses the nested tuple structure and checks if any value is equal to `UNKNOWN_VALUE`. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if all values in this tuple and nested tuples are known, False if any value is `UNKNOWN_VALUE`. `all_known[start: Int, end: Int](self) -> Bool` Check if all values in this tuple hierarchy are known (not `UNKNOWN_VALUE`). Recursively traverses the nested tuple structure and checks if any value is equal to `UNKNOWN_VALUE`. **Parameters:** * ​start ([`Int`](/mojo/std/builtin/int/Int)): The starting index (inclusive) for the range to check. * ​end ([`Int`](/mojo/std/builtin/int/Int)): The ending index (exclusive) for the range to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if all values in this tuple and nested tuples are known, False if any value is `UNKNOWN_VALUE`. ### `append` `append(mut self, *elements: Self)` Append one or more `IntTuple` elements to this tuple. This method modifies the tuple in-place by adding the provided elements to the end of the tuple. 
It handles both value tuples and nested tuples. Notes: * This operation requires reallocating the underlying `IntArray` storage to accommodate the new elements, which may impact performance for large tuples. **Args:** * ​\*elements (`Self`): Variable number of `IntTuple` objects to append to this tuple. ### `extend` `extend(mut self, tuple: Self)` Extends this tuple by appending all elements from another tuple. This method modifies the tuple in-place by adding all elements from the provided tuple to the end of this tuple. It efficiently handles both value elements and nested tuples. Notes: * This operation requires reallocating the underlying `IntArray` storage to accommodate the new elements, which may impact performance for large tuples. * If the input tuple is empty, this method returns without making any changes. **Args:** * ​tuple (`Self`): The `IntTuple` whose elements will be appended to this tuple. ### `size` `size(self) -> Int` Returns the total size of the `IntTuple` in memory. For owning tuples, returns the size of the underlying `IntArray`. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total size in memory units. ### `tuple_size` `static tuple_size(data: IntArray) -> Int` Recursively calculates the size of a tuple represented by an `IntArray`. This method traverses the tuple structure, accounting for both direct values and nested sub-tuples to compute the total memory footprint. **Args:** * ​data ([`IntArray`](/mojo/kernels/layout/int_tuple/IntArray)): `IntArray` containing the tuple data. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total size of the tuple in memory units. ### `validate_structure` `validate_structure(self)` Validates the internal structure of the `IntTuple`. Ensures that the actual size of the underlying data matches the computed size based on the tuple's structure. This helps detect memory corruption or implementation errors. Assertion fails with an error message if validation fails. 
### `__len__` `__len__(self) -> Int` Returns the number of elements in the `IntTuple`. This is the logical length of the tuple, not its memory size. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of elements in the tuple. ### `__iter__` `__iter__(ref self) -> _IntTupleIter[origin_of((muttoimm self_is_origin))]` Returns an iterator over the elements of the `IntTuple`. This enables iteration through the tuple using for-loops. **Returns:** `_IntTupleIter`: An iterator object for this `IntTuple`. ### `is_value` `is_value(self) -> Bool` Determines if this `IntTuple` represents a single value rather than a tuple. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if this `IntTuple` contains exactly one element that is a value, False otherwise. `is_value(self, i: Int) -> Bool` Determines if the element at the specified index is a value rather than a tuple. Notes: If index is out of bounds, assertion fails with an error message. **Args:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the element at index i is a value, False if it's a tuple. ### `is_tuple` `is_tuple(self) -> Bool` Determines if this `IntTuple` represents a tuple rather than a single value. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if this `IntTuple` is a tuple (not a single value), False otherwise. `is_tuple(self, i: Int) -> Bool` Determines if the element at the specified index is a tuple rather than a value. Notes: This is the complement of is\_value(i). **Args:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the element at index i is a tuple, False if it's a value. ### `value` `value(self) -> Int` Retrieves the value of this `IntTuple` if it represents a single value. This method should only be called if `is_value()` returns True. 
**Returns:** [`Int`](/mojo/std/builtin/int/Int): The integer value stored in this `IntTuple`. `value(self, i: Int) -> Int` Retrieves the value of the element at the specified index. This method should only be called if `is_value(i)` returns True. Notes: If the element is not a value, the behavior is undefined. **Args:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to retrieve. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The integer value stored at the specified index. ### `tuple` `tuple(ref self) -> ref[self] Self` Returns a reference to this `IntTuple` as a tuple. Notes: This method is used to access the current `IntTuple` as a tuple without creating a copy of the data. **Returns:** `ref`: A reference to this `IntTuple` to avoid unnecessary copying. ### `write_to` `write_to(self, mut writer: T)` Writes a string representation of this `IntTuple` to the provided writer. Notes: For single values, writes just the value. For tuples, writes a comma-separated list of elements enclosed in parentheses. **Args:** * ​writer (`T`): The writer to output the string representation to. ### `__str__` `__str__(self) -> String` Returns a string representation of this `IntTuple`. **Returns:** `String`: A string representation of the `IntTuple`, using the `write_to` method. ### `is_equal` `static is_equal(a, b: Self) -> Bool` Compares two `IntTuple`s for equality. Notes: Handles nested tuples and special cases where a single-element tuple is equivalent to its contained value. **Args:** * ​a (`Self`): The first `IntTuple` to compare. * ​b (`Self`): The second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `IntTuple`s are equal in structure and values, False otherwise. ### `__repr__` `__repr__(self) -> String` Returns a string representation of this `IntTuple` for debugging. **Returns:** `String`: A string representation of the `IntTuple`, same as `__str__`. 
### `__int__` `__int__(self) -> Int` Converts this `IntTuple` to an integer. This method should only be called if `is_value()` returns True. Aborts: If the `IntTuple` is not a single value. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The integer value stored in this `IntTuple`.
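Tying the methods above together, a short sketch of nested construction and inspection (expected values follow from the method descriptions; printed formatting may differ):

```mojo
from layout import IntTuple

fn main():
    # A nested shape: (2, (3, 4))
    var shape = IntTuple(2, IntTuple(3, 4))
    print(len(shape))            # 2: two top-level elements
    print(shape.count_values())  # 3: three leaf integer values
    print(shape[1].is_tuple())   # True: the second element is a sub-tuple
    print(shape.flatten())       # flat tuple of the leaves: 2, 3, 4
```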
--- ## abs
`abs(t: IntTuple) -> IntTuple` Compute the absolute value of each element in an `IntTuple`. This function applies the absolute value operation to each integer in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to transform. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure but with absolute values.
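For example, applying `abs` to a nested tuple preserves its structure (a sketch based on the signature above):

```mojo
from layout.int_tuple import IntTuple, abs

fn main():
    var t = IntTuple(-2, IntTuple(3, -4))
    # Each leaf becomes its absolute value: (-2, (3, -4)) -> (2, (3, 4))
    print(abs(t))
```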
--- ## apply
`apply[func: fn(Int) capturing -> Int](t: IntTuple) -> IntTuple` Apply a function to each integer value in an `IntTuple`. This function recursively applies the given function to each integer value in a potentially nested `IntTuple` structure, preserving the structure. **Parameters:** * ​func (`fn(Int) capturing -> Int`): Function to apply to each integer value. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to transform. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure but with each integer value transformed by the function.
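A sketch of `apply` with a capturing closure. Mojo's closure syntax has evolved, so the `@parameter fn` form used here is an assumption that may need adjusting for your compiler version:

```mojo
from layout.int_tuple import IntTuple, apply

fn main():
    var scale = 2

    # A capturing closure matching the `fn(Int) capturing -> Int` parameter.
    @parameter
    fn times_scale(x: Int) -> Int:
        return x * scale

    # Structure is preserved: (1, (2, 3)) -> (2, (4, 6))
    print(apply[times_scale](IntTuple(1, IntTuple(2, 3))))
```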
--- ## apply_predicate
`apply_predicate[predicate: fn(IntTuple, IntTuple) -> Bool](a: IntTuple, b: IntTuple) -> Bool` Apply a predicate function recursively to two `IntTuple`s. This function traverses two `IntTuple`s with the same structure and applies a predicate function to corresponding elements. The predicate is applied only to the leaf nodes (integer values). Note: If the structures of the two `IntTuple`s don't match (different nesting or length), the function returns False without applying the predicate. **Parameters:** * ​predicate (`fn(IntTuple, IntTuple) -> Bool`): A function that takes two `IntTuple`s (containing integer values) and returns a boolean result. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the predicate returns True for all corresponding elements and the structures match, False otherwise.
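For example, a leaf-wise comparison (a sketch based on the signature and behavior described above):

```mojo
from layout.int_tuple import IntTuple, apply_predicate

fn leaf_le(a: IntTuple, b: IntTuple) -> Bool:
    # The predicate is invoked only on leaves, so value() is safe here.
    return a.value() <= b.value()

fn main():
    var a = IntTuple(1, IntTuple(2, 3))
    var b = IntTuple(2, IntTuple(3, 4))
    print(apply_predicate[leaf_le](a, b))  # True: holds for every leaf pair
    # Mismatched structures return False without invoking the predicate:
    print(apply_predicate[leaf_le](IntTuple(1, 2), b))  # False
```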
--- ## apply_zip
`apply_zip[func: fn(IntTuple, IntTuple) -> IntTuple](t1: IntTuple, t2: IntTuple) -> IntTuple` Apply a function to pairs of elements from two `IntTuple`s. This function zips two `IntTuple`s together and applies the given function to each pair of elements, creating a new `IntTuple` with the results. **Parameters:** * ​func (`fn(IntTuple, IntTuple) -> IntTuple`): Function that takes two `IntTuple`s and returns an `IntTuple`. **Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each pair. `apply_zip[func: fn(IntTuple, IntTuple) capturing -> IntTuple](t1: IntTuple, t2: IntTuple) -> IntTuple` Apply a capturing function to pairs of elements from two `IntTuple`s. This overload allows the function to capture variables from its environment. **Parameters:** * ​func (`fn(IntTuple, IntTuple) capturing -> IntTuple`): Capturing function that takes two `IntTuple`s and returns an `IntTuple`. **Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each pair. `apply_zip[func: fn(IntTuple, IntTuple, IntTuple) -> IntTuple](t1: IntTuple, t2: IntTuple, t3: IntTuple) -> IntTuple` Apply a function to triplets of elements from three `IntTuple`s. This function zips three `IntTuple`s together and applies the given function to each triplet of elements, creating a new `IntTuple` with the results. **Parameters:** * ​func (`fn(IntTuple, IntTuple, IntTuple) -> IntTuple`): Function that takes three `IntTuple`s and returns an `IntTuple`. 
**Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. * ​t3 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Third `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each triplet. `apply_zip[func: fn(IntTuple, IntTuple, IntTuple) capturing -> IntTuple](t1: IntTuple, t2: IntTuple, t3: IntTuple) -> IntTuple` Apply a capturing function to triplets of elements from three `IntTuple`s. This overload allows the function to capture variables from its environment. **Parameters:** * ​func (`fn(IntTuple, IntTuple, IntTuple) capturing -> IntTuple`): Capturing function that takes three `IntTuple`s and returns an `IntTuple`. **Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. * ​t3 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Third `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each triplet.
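For example, a pairwise sum over two flat tuples (a sketch based on the first overload; with flat tuples each element passed to the function is a single-value `IntTuple`, so `value()` is safe):

```mojo
from layout.int_tuple import IntTuple, apply_zip

fn add_pair(a: IntTuple, b: IntTuple) -> IntTuple:
    return IntTuple(a.value() + b.value())

fn main():
    var t1 = IntTuple(1, 2, 3)
    var t2 = IntTuple(10, 20, 30)
    # Pairwise sums: (11, 22, 33)
    print(apply_zip[add_pair](t1, t2))
```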
--- ## compact_order
`compact_order(shape: IntTuple, order: IntTuple) -> IntTuple` Create a compact stride based on shape and order. This function generates a stride tuple where lower order numbers imply faster varying strides. The resulting shape and stride form a bijective layout. Performance: * Always inlined for optimal performance in tight loops. * Flattens inputs and re-nests results for consistent behavior. Example: ```mojo from layout import IntTuple from layout.int_tuple import compact_order # Create a compact layout with dimensions (2,3,4,5) and ordering (1,4,3,5) var x = compact_order(IntTuple(2,3,4,5), IntTuple(1,4,3,5)) # returns (1,8,2,24) # Create a compact layout with nested dimensions and corresponding ordering var y = compact_order(IntTuple(2,IntTuple(3,4),5), IntTuple(1,IntTuple(2,3),4)) # returns (1,(2,6),24) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape tuple defining dimensions. * ​order ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The order tuple defining the relative ordering of dimensions. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A stride tuple that creates a compact memory layout according to the specified order.
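The ordering rule can be modeled outside Mojo: the dimension with the smallest order value gets stride 1, and each subsequent dimension (in increasing order value) gets the running product of the extents assigned so far. A hedged Python sketch of this rule for flat shapes only (the Mojo version also handles nested tuples; the helper name is illustrative, not part of the module):

```python
def compact_order(shape, order):
    """Model of the stride rule: lower order value => faster-varying stride."""
    strides = [0] * len(shape)
    running = 1
    # Visit dimensions from smallest order value to largest,
    # accumulating a running product of the extents seen so far.
    for i in sorted(range(len(shape)), key=lambda i: order[i]):
        strides[i] = running
        running *= shape[i]
    return tuple(strides)

print(compact_order((2, 3, 4, 5), (1, 4, 3, 5)))  # (1, 8, 2, 24)
```

This reproduces the first documented example: dimension 0 (order 1) gets stride 1, dimension 2 (order 3) gets stride 2, dimension 1 (order 4) gets stride 8, and dimension 3 (order 5) gets stride 24.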
--- ## compatible
`compatible(a: IntTuple, b: IntTuple) -> Bool` Test if two shapes are compatible for tensor operations. This function checks if shape A is compatible with shape B, meaning: 1. The total size of A and B are the same 2. Any coordinate into A can also be used as a coordinate into B Compatible can also be thought of as a partial order on A and B: A <= B. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if shape A is compatible with shape B, False otherwise.
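The two conditions can be read recursively: an integer shape is compatible with an equal integer, or with a tuple of the same total size (a linear coordinate can index it), while a tuple shape is never compatible with an integer, and tuples recurse element-wise. A hedged Python model using plain ints and tuples in place of `IntTuple` (a sketch of the semantics described above, not the Mojo implementation):

```python
from math import prod

def total(shape):
    """Total size of a possibly nested shape."""
    return shape if isinstance(shape, int) else prod(total(s) for s in shape)

def compatible(a, b):
    """True if sizes match and every coordinate into `a` also indexes `b`."""
    if isinstance(a, int) and isinstance(b, int):
        return a == b
    if isinstance(a, int):
        # A linear coordinate 0..a-1 can index a tuple shape of equal size.
        return a == total(b)
    if isinstance(b, int):
        # A multi-dimensional coordinate cannot index a scalar shape.
        return False
    return len(a) == len(b) and all(compatible(x, y) for x, y in zip(a, b))

print(compatible(12, (3, 4)))      # True: same size, linear coords still valid
print(compatible((3, 4), (3, 4)))  # True
print(compatible((3, 4), 12))      # False
```

This asymmetry is what makes `compatible` a partial order (A <= B) rather than an equivalence.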
--- ## congruent
`congruent(a: IntTuple, b: IntTuple) -> Bool` Test if two `IntTuple`s have the same hierarchical structure. This function checks if two `IntTuple`s have identical nesting patterns, regardless of the actual integer values they contain. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if both `IntTuple`s have the same hierarchical structure, False otherwise.
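Since congruence depends only on the nesting pattern and not on the values, it can be modeled with plain Python ints and tuples (a hedged sketch of the semantics, not the Mojo implementation):

```python
def congruent(a, b):
    """True if `a` and `b` have identical nesting patterns."""
    if isinstance(a, int) or isinstance(b, int):
        # Two leaves match regardless of value; a leaf never matches a tuple.
        return isinstance(a, int) and isinstance(b, int)
    return len(a) == len(b) and all(congruent(x, y) for x, y in zip(a, b))

print(congruent((2, (3, 4)), (7, (8, 9))))  # True: same structure
print(congruent((2, (3, 4)), (2, 3, 4)))    # False: different nesting
```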
--- ## crd2idx
`crd2idx(crd: IntTuple, shape: IntTuple) -> Int` Map a logical coordinate to a linear index. This function converts a multi-dimensional coordinate to a linear index based on the shape. It uses default strides computed from the shape. **Args:** * ​crd ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The coordinate tuple to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The linear index corresponding to the coordinate. `crd2idx(crd: IntTuple, shape: IntTuple, _stride: IntTuple) -> Int` Map a logical coordinate to a linear index with custom strides. This function converts a multi-dimensional coordinate to a linear index based on the shape and stride information. If no stride is provided, it computes default strides from the shape. The function handles various input combinations: * Tuple coordinates with tuple shapes and strides * Single integer coordinate with tuple shapes and strides * Single integer coordinate with single integer shape and stride Aborts: ``` - If coordinate and shape dimensions don't match. - If shape and stride dimensions don't match. - If input type combinations are invalid. ``` **Args:** * ​crd ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The coordinate(s) to convert, can be a single value or a tuple of coordinates. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array, can be a single value or a tuple of dimensions. * ​\_stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Optional custom strides, defaults to row-major strides if not provided. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The linear index corresponding to the coordinate.
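For flat shapes, the mapping is a dot product of coordinate and stride, where the default strides are the exclusive prefix product of the shape (column-major: dimension 0 varies fastest). A Python sketch of this flat case (illustrative helpers, not the Mojo API):

```python
def default_strides(shape):
    """Exclusive prefix product: dimension 0 gets stride 1."""
    strides, running = [], 1
    for s in shape:
        strides.append(running)
        running *= s
    return tuple(strides)

def crd2idx(crd, shape, stride=None):
    stride = stride or default_strides(shape)
    return sum(c * s for c, s in zip(crd, stride))

print(crd2idx((1, 2), (3, 4)))          # default strides (1, 3): 1*1 + 2*3 = 7
print(crd2idx((1, 2), (3, 4), (4, 1)))  # custom row-major strides: 1*4 + 2*1 = 6
```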
--- ## depth
`depth(src: IntTuple) -> Int` Calculates the maximum nesting depth of an `IntTuple`. This function recursively traverses the `IntTuple` structure to determine its maximum nesting depth. A scalar value has depth 0, a flat tuple has depth 1, and nested tuples increase the depth accordingly. Example: ```mojo from layout import IntTuple, depth print(depth(IntTuple(1))) # prints 0 print(depth(IntTuple(1, 2))) # prints 1 print(depth(IntTuple(IntTuple(1, 2)))) # prints 2 ``` **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to measure the depth of. **Returns:** [`Int`](/mojo/std/builtin/int/Int): An integer representing the maximum nesting depth.
--- ## fill_like
`fill_like(src: IntTuple, val: Int) -> IntTuple` Creates an `IntTuple` with the same structure as the source but filled with a specified value. This function recursively traverses the source `IntTuple` and creates a new `IntTuple` with identical structure, but with all leaf values replaced by the specified value. **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The source `IntTuple` whose structure will be copied. * ​val ([`Int`](/mojo/std/builtin/int/Int)): The integer value to fill the new `IntTuple` with. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure as src but filled with val.
--- ## flatten
`flatten(t: IntTuple) -> IntTuple` Flatten a nested `IntTuple` into a single-level `IntTuple`. This function converts a hierarchical `IntTuple` structure into a flat sequence of integer values, preserving the order of elements. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The nested `IntTuple` to flatten. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing all integer values in a flat structure.
--- ## idx2crd
`idx2crd(idx: IntTuple, shape: IntTuple) -> IntTuple` Converts a linear index to a coordinate tuple within a given shape. This function splits an index into a coordinate within a Shape via a colexicographical enumeration of coordinates in Shape. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the coordinates corresponding to the linear index. `idx2crd(idx: IntTuple, shape: IntTuple, _stride: IntTuple) -> IntTuple` Converts a linear index to a coordinate tuple within a given shape using custom strides. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. * ​\_stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Custom strides to use for the conversion. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the coordinates corresponding to the linear index.
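Colexicographical enumeration means dimension 0 cycles fastest: each coordinate is `(idx // stride[i]) % shape[i]` with the default exclusive-prefix-product strides. A Python sketch of the flat case (illustrative, not the Mojo API):

```python
def idx2crd(idx, shape):
    """Split a linear index into a coordinate, dimension 0 varying fastest."""
    crd, running = [], 1
    for s in shape:
        crd.append((idx // running) % s)
        running *= s
    return tuple(crd)

# Enumerate shape (3, 4) colexicographically:
print(idx2crd(0, (3, 4)))  # (0, 0)
print(idx2crd(1, (3, 4)))  # (1, 0)  -- dimension 0 advances first
print(idx2crd(7, (3, 4)))  # (1, 2)
```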
--- ## idx2crd2
`idx2crd2(idx: IntTuple, shape: IntTuple, _stride: IntTuple) -> IntTuple` Convert a linear index to coordinates. This function handles the actual conversion logic for different input combinations. Notes: * Handles four cases: tuple-tuple-tuple, tuple-int-int, int-tuple-tuple, and int-int-int. * When input shapes don't match, `abort()` will be called. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. * ​\_stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Custom strides to use for the conversion. If empty, strides are computed from the shape using prefix\_product. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new IntTuple containing the coordinates corresponding to the linear index.
--- ## int_tuple
Hierarchical integer tuple data structures for high-performance tensor operations. This module provides a flexible, memory-efficient implementation of nested integer tuples optimized for tensor shape, stride, and index operations in high-performance computing. The core data structures support both flat and hierarchical representations with efficient memory sharing and zero-copy views. Key components: * `IntArray`: Low-level register-passable array with direct memory management * `IntTuple`: Hierarchical nested tuple with efficient memory layout and operations * Utility functions for tensor shape manipulation, coordinate transformations, and layout operations Performance features: * Register-passable data structures for optimal compiler optimizations * Zero-copy views for efficient memory sharing * Specialized memory layout for nested structures * Optimized algorithms for common tensor operations Common operations: * Shape manipulation: `flatten`, `to_nest`, `apply`, `product`, `sum` * Coordinate transformations: `idx2crd`, `crd2idx` * Layout operations: `compact_order`, `prefix_product` * Structural comparisons: `congruent`, `compatible`, `weakly_congruent` Example usage: ```mojo from layout import IntTuple from layout.int_tuple import flatten, compact_order, size # Create nested tuples var shape = IntTuple(2, IntTuple(3, 4), 5) # Represents shape (2, (3, 4), 5) # Flatten a nested tuple var flat = flatten(shape) # Results in (2, 3, 4, 5) # Create compact strides for a given shape and order var order = IntTuple(1, IntTuple(2, 3), 4) var strides = compact_order(shape, order) # Results in (1, (2, 6), 24) # Calculate total size (product of all elements) var total_size = size(shape) # Results in 120 ``` ## `comptime` values ### `IntList` `comptime IntList = List[Int]` A type alias for a List of integers with ownership. This alias defines a List that contains Int values and has ownership of its data. 
It's used throughout the module for storing and manipulating collections of integers, particularly for operations like permutations and indices. ### `UNKNOWN_VALUE` `comptime UNKNOWN_VALUE = -1` Special value indicating an unknown or unspecified dimension. This constant is used throughout the `IntTuple` system to represent dimensions that are not known at compile time or have not been specified. ## Structs * [​`IntArray`](./IntArray): A memory-efficient, register-passable array of integers. * [​`IntTuple`](./IntTuple): A hierarchical, nested tuple of integers with efficient memory management. ## Functions * [​`abs`](./abs): Compute the absolute value of each element in an `IntTuple`. * [​`apply`](./apply): Apply a function to each integer value in an `IntTuple`. * [​`apply_predicate`](./apply_predicate): Apply a predicate function recursively to two `IntTuple`s. * [​`apply_zip`](./apply_zip): Apply a function to pairs of elements from two `IntTuple`s. * [​`compact_order`](./compact_order): Create a compact stride based on shape and order. * [​`compatible`](./compatible): Test if two shapes are compatible for tensor operations. * [​`congruent`](./congruent): Test if two `IntTuple`s have the same hierarchical structure. * [​`crd2idx`](./crd2idx): Map a logical coordinate to a linear index. * [​`depth`](./depth): Calculates the maximum nesting depth of an `IntTuple`. * [​`fill_like`](./fill_like): Creates an `IntTuple` with the same structure as the source but filled with a specified value. * [​`flatten`](./flatten): Flatten a nested `IntTuple` into a single-level `IntTuple`. * [​`idx2crd`](./idx2crd): Converts a linear index to a coordinate tuple within a given shape. * [​`idx2crd2`](./idx2crd2): Convert a linear index to coordinates. * [​`inner_product`](./inner_product): Compute the inner product of two `IntTuple`s. * [​`is_flat`](./is_flat): Check if an `IntTuple` is flat. * [​`is_int`](./is_int): Check if an `IntTuple` represents a single integer value. 
* [​`is_tuple`](./is_tuple): Check if an `IntTuple` represents a nested tuple. * [​`mul`](./mul): Multiply each element in an `IntTuple` by a scalar value. * [​`prefix_product`](./prefix_product): Compute the exclusive prefix product of an `IntTuple`. * [​`product`](./product): Calculate the product of all values in an `IntTuple`. * [​`product_each`](./product_each): Compute the product of elements in each sub-tuple of an `IntTuple`. * [​`propagate_unknown`](./propagate_unknown): Propagates unknown dimensions from the target `IntTuple` to the source `IntTuple`. * [​`reduce`](./reduce): Apply a reduction function to an `IntTuple` with an initial value. * [​`reverse`](./reverse): Reverses the order of elements in an `IntTuple`, recursively. * [​`shallow_apply`](./shallow_apply): Apply a function to each top-level element of an `IntTuple`. * [​`shape_div`](./shape_div): Performs division operation between shape tuples. * [​`signum`](./signum): Calculate the sign of an integer. * [​`size`](./size): Calculate the total size (product of all elements) of an `IntTuple`. * [​`sorted`](./sorted): Sort an IntTuple using the provided comparison function. * [​`sum`](./sum): Calculate the sum of all values in an `IntTuple`. * [​`to_index_list`](./to_index_list): Converts an IntTuple to a flattened IndexList with the same values. * [​`to_nest`](./to_nest): Nests a flat `IntTuple` according to the structure of a nested `IntTuple`. * [​`to_unknown`](./to_unknown): Create an `IntTuple` with the same structure but filled with `UNKNOWN_VALUE`. * [​`tuple_max`](./tuple_max): Calculate the maximum value in an `IntTuple`. * [​`tuple_min`](./tuple_min): Compute the element-wise minimum of two `IntTuple`s. * [​`weakly_compatible`](./weakly_compatible): Test if shape A is weakly compatible with shape B. * [​`weakly_congruent`](./weakly_congruent): Test if two IntTuples have similar hierarchical structures.
--- ## inner_product
`inner_product(a: IntTuple, b: IntTuple) -> Int` Compute the inner product of two `IntTuple`s. For flat tuples, this is the sum of element-wise products. For nested tuples, the function recurses into corresponding nested elements. Note: If the input tuples have different lengths, an assertion fails. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The inner product as an `Int`.
--- ## is_flat
`is_flat(t: IntTuple) -> Bool` Check if an `IntTuple` is flat. This function checks if the `IntTuple` is flat, meaning it has no nested elements. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `IntTuple` is flat, False otherwise.
--- ## is_int
`is_int(t: IntTuple) -> Bool` Check if an `IntTuple` represents a single integer value. This function determines whether the given `IntTuple` contains a single integer value rather than a nested tuple structure. Example: ```mojo from layout.int_tuple import is_int, IntTuple var single_value = IntTuple(5) var nested_tuple = IntTuple(1, 2, 3) var result1 = is_int(single_value) # Returns True var result2 = is_int(nested_tuple) # Returns False ``` **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `IntTuple` contains a single integer value, False if it's a nested tuple.
--- ## is_tuple
`is_tuple(t: IntTuple) -> Bool` Check if an `IntTuple` represents a nested tuple. This function determines whether the given `IntTuple` contains nested elements rather than a single integer value. It is the complement of the `is_int` function. Example: ```mojo from layout.int_tuple import is_tuple, IntTuple var single_value = IntTuple(5) var nested_tuple = IntTuple(1, 2, 3) var result1 = is_tuple(single_value) # Returns False var result2 = is_tuple(nested_tuple) # Returns True ``` **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `IntTuple` contains nested elements, False if it's a single integer value.
--- ## mul
`mul(lhs: IntTuple, rhs: Int) -> IntTuple` Multiply each element in an `IntTuple` by a scalar value. This function creates a new `IntTuple` where each element (at any nesting level) is multiplied by the provided integer value. **Args:** * ​lhs ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` whose elements will be multiplied. * ​rhs ([`Int`](/mojo/std/builtin/int/Int)): The scalar integer to multiply each element by. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure as the input but with all elements multiplied by the scalar value.
--- ## prefix_product
`prefix_product(a: IntTuple) -> IntTuple` Compute the exclusive prefix product of an `IntTuple`. This is a convenience wrapper that initializes the prefix product with 1. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The input `IntTuple` to compute the prefix product for. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the exclusive prefix product of the input. `prefix_product(a: IntTuple, init: Int) -> IntTuple` Compute the exclusive prefix product of an `IntTuple` with an initial value. This function delegates to the implementation in prefix\_product2. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The input `IntTuple` to compute the prefix product for. * ​init ([`Int`](/mojo/std/builtin/int/Int)): The initial value(s) for the prefix product, defaults to 1. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the exclusive prefix product of the input.
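"Exclusive" means each position holds the product of all elements before it, seeded with the initial value; this is exactly how default column-major strides are derived from a shape. A Python sketch of the flat case (illustrative, not the Mojo API):

```python
def prefix_product(a, init=1):
    """Exclusive prefix product: out[i] = init * a[0] * ... * a[i-1]."""
    out, running = [], init
    for x in a:
        out.append(running)
        running *= x
    return tuple(out)

print(prefix_product((2, 3, 4)))     # (1, 2, 6) -- default strides for shape (2,3,4)
print(prefix_product((2, 3, 4), 8))  # (8, 16, 48)
```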
--- ## product
`product(t: IntTuple) -> Int` Calculate the product of all values in an `IntTuple`. This function recursively computes the product of all integer values in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to multiply. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The product of all integer values, or `UNKNOWN_VALUE` if any value in the tuple is `UNKNOWN_VALUE`.
--- ## product_each
`product_each(t: IntTuple) -> IntTuple` Compute the product of elements in each sub-tuple of an `IntTuple`. For each immediate child of the input tuple, this function computes the product of all elements within that child. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` containing sub-tuples. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` where each element is the product of the corresponding sub-tuple in the input.
--- ## propagate_unknown
`propagate_unknown(src: IntTuple, target: IntTuple) -> IntTuple` Propagates unknown dimensions from the target `IntTuple` to the source `IntTuple`. This function creates a new `IntTuple` by combining the source and target `IntTuple`s, preserving unknown dimensions (UNKNOWN\_VALUE) from the target while using values from the source for known dimensions. **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The source `IntTuple` containing known dimension values. * ​target ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The target `IntTuple` that may contain unknown dimensions (UNKNOWN\_VALUE). **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with unknown dimensions from target and known dimensions from src.
--- ## reduce
`reduce[reducer: fn(a: Int, b: IntTuple) capturing -> Int](t: IntTuple, initializer: Int) -> Int` Apply a reduction function to an `IntTuple` with an initial value. This function iterates through each element of the `IntTuple` and applies the provided reduction function cumulatively, starting with the initializer. **Parameters:** * ​reducer (`fn(a: Int, b: IntTuple) capturing -> Int`): A function that combines the accumulated result with the next element. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to reduce. * ​initializer ([`Int`](/mojo/std/builtin/int/Int)): The initial value for the reduction operation. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The final accumulated result after applying the reduction function to all elements in the `IntTuple`.
--- ## reverse
`reverse(src: IntTuple) -> IntTuple` Reverses the order of elements in an `IntTuple`, recursively. This function reverses the top-level elements of the `IntTuple` and recursively reverses any nested `IntTuple`s. Example: ```mojo from layout.int_tuple import IntTuple, reverse var t = IntTuple(1, 2, IntTuple(3, 4)) var reversed = reverse(t) # returns ((4, 3), 2, 1) ``` **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The source `IntTuple` to reverse. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with elements in reversed order.
--- ## shallow_apply
`shallow_apply[func: fn(IntTuple) -> Int](t: IntTuple) -> IntTuple` Apply a function to each top-level element of an `IntTuple`. Unlike `apply()`, this function only operates on the immediate children of the input tuple without recursing into nested tuples. **Parameters:** * ​func (`fn(IntTuple) -> Int`): Function that takes an `IntTuple` and returns an `Int`. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` whose elements will be transformed. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the function applied to each top-level element.
--- ## shape_div
`shape_div[check: Bool = False](a: IntTuple, b: IntTuple) -> IntTuple` Performs division operation between shape tuples. Handles four cases: 1. tuple-tuple: Performs shape\_div element-wise when dimensions match 2. tuple-int: Folds the division of b across each element of a Example: `shape_div((4,5,6),40)` -> `shape_div((1,5,6),10)` -> `shape_div((1,1,6),2)` -> `(1,1,3)` 3. int-tuple: Returns `shape_div(a, product(b))` 4. int-int: Enforces the divisibility condition `a % b == 0 || b % a == 0` when possible Returns `a / b` with rounding away from `0` (that is, `1` or `-1` when `a < b`) Notes: * When tuple sizes don't match in the tuple-tuple case, `abort()` will be called. * When values are incompatible (neither divides the other) in the int-int case, `abort()` will be called. **Parameters:** * ​check ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to check for incompatible shapes. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The dividend `IntTuple`. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The divisor `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the result of the division operation
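The int-int and tuple-int cases can be modeled in a few lines (a hedged sketch of the rules described above, using plain ints/tuples; the sign handling covers the away-from-zero rounding for completeness and is not a claim about the Mojo internals):

```python
def shape_div_int(a, b):
    """Either a divides b or b divides a; round away from zero when a < b."""
    assert a % b == 0 or b % a == 0, "incompatible shapes"
    return a // b if a % b == 0 else (1 if (a > 0) == (b > 0) else -1)

def shape_div(shape, b):
    """Fold the scalar divisor across the dimensions, left to right."""
    out = []
    for s in shape:
        out.append(shape_div_int(s, b))
        b = shape_div_int(b, s)  # remaining divisor for later dimensions
    return tuple(out)

print(shape_div((4, 5, 6), 40))  # (1, 1, 3)
```

This reproduces the documented folding example: `(4,5,6) / 40` consumes 4 (divisor becomes 10), then 5 (divisor becomes 2), leaving `6 / 2 = 3`.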
--- ## signum
`signum(a: Int) -> Int` Calculate the sign of an integer. This function determines the sign of the input integer and returns a corresponding indicator value. Example: ```mojo from layout.int_tuple import signum var result1 = signum(5) # Returns 1 var result2 = signum(-10) # Returns -1 var result3 = signum(0) # Returns 0 ``` **Args:** * ​a ([`Int`](/mojo/std/builtin/int/Int)): The integer value to determine the sign of. **Returns:** [`Int`](/mojo/std/builtin/int/Int): 1 if `a` > 0, -1 if `a` < 0, 0 if `a` == 0.
--- ## size
`size(a: IntTuple) -> Int` Calculate the total size (product of all elements) of an `IntTuple`. This function computes the product of all integer values in the `IntTuple`, regardless of nesting level. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` whose elements will be multiplied together. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The product of all elements in the `IntTuple`.
--- ## sorted
`sorted[cmp: fn(IntTuple, IntTuple) -> Bool = __lt__](tuple: IntTuple) -> IntTuple` Sort an IntTuple using the provided comparison function. This function implements a merge sort algorithm to efficiently sort the elements of an IntTuple. The sorting is stable and has `O(n log n)` time complexity. **Parameters:** * ​cmp (`fn(IntTuple, IntTuple) -> Bool`): A comparison function that takes two `IntTuple` elements and returns True if the first should come before the second. Defaults to `__lt__`, which performs lexicographical ordering. **Args:** * ​tuple ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to be sorted. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the same elements as the input but sorted according to the comparison function.
--- ## sum
`sum(t: IntTuple) -> Int` Calculate the sum of all values in an `IntTuple`. This function recursively computes the sum of all integer values in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to sum. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The sum of all integer values, or `UNKNOWN_VALUE` if any value in the tuple is `UNKNOWN_VALUE`.
--- ## to_index_list
`to_index_list[rank: Int, element_type: DType = DType.int64](t: IntTuple) -> IndexList[rank, element_type=element_type]` Converts an IntTuple to a flattened IndexList with the same values. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the resulting IndexList. * ​element\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Element type, must be integer type. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` defining the values. **Returns:** `IndexList`: An IndexList filled with the values of t.
--- ## to_nest
`to_nest(nested: IntTuple, flat: IntTuple) -> IntTuple` Nests a flat `IntTuple` according to the structure of a nested `IntTuple`. This function reshapes a flat sequence of values into a hierarchical structure that matches the pattern of a template nested `IntTuple`. Example: ```mojo from layout import IntTuple from layout.int_tuple import to_nest var result = to_nest(IntTuple(2, IntTuple(3, 4), 5), IntTuple(1, 2, 3, 4)) # returns IntTuple(1, (2, 3), 4) ``` **Args:** * ​nested ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The template `IntTuple` defining the desired structure. * ​flat ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The flat `IntTuple` containing the values to be nested. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the values from flat arranged in the structure of nested.
--- ## to_unknown
`to_unknown(t: IntTuple) -> IntTuple` Create an `IntTuple` with the same structure but filled with `UNKNOWN_VALUE`. This function preserves the hierarchical structure of the input `IntTuple` but replaces all integer values with `UNKNOWN_VALUE`. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The template `IntTuple` defining the structure. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure as t but with all values replaced by `UNKNOWN_VALUE`.
--- ## tuple_max
`tuple_max(t: IntTuple) -> Int` Calculate the maximum value in an `IntTuple`. This function recursively finds the maximum integer value in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to search. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The maximum integer value found in the tuple.
--- ## tuple_min
`tuple_min(a: IntTuple, b: IntTuple) -> IntTuple` Compute the element-wise minimum of two `IntTuple`s. This function compares corresponding elements of two `IntTuple`s and returns a new `IntTuple` containing the minimum value at each position. Aborts: If the input tuples have different lengths. Note: If either input contains `UNKNOWN_VALUE`, the result will be `UNKNOWN_VALUE`. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with each element being the minimum of the corresponding elements in a and b.
--- ## weakly_compatible
`weakly_compatible(a: IntTuple, b: IntTuple) -> Bool` Test if shape A is weakly compatible with shape B. A shape A is weakly compatible with shape B if there exists a shape C congruent to A such that compatible(elem\_scale(A,C), B). This establishes a partial order relation between shapes where A <= B. Specifically, this checks if the size of B is divisible by the size of A, which is a necessary condition for weak compatibility. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if shape A is weakly compatible with shape B, False otherwise.
--- ## weakly_congruent
`weakly_congruent(a: IntTuple, b: IntTuple) -> Bool` Test if two IntTuples have similar hierarchical structures. This function establishes a partial order relation between IntTuples based on their hierarchical structure. It's less strict than congruent. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First IntTuple to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second IntTuple to compare. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if a's structure is compatible with b's structure, False otherwise.
--- ## Layout (Layout)
`struct Layout` Represents a memory layout for multi-dimensional data. The Layout struct is the primary implementation of the LayoutTrait, providing a concrete representation of memory layouts using shape and stride information. It maps between logical coordinates and linear memory indices, enabling efficient access to multi-dimensional data. A Layout consists of: * shape: Defines the dimensions of the logical coordinate space * stride: Defines the step sizes in memory for each dimension The Layout struct supports various operations including: * Creation of row-major and column-major layouts * Conversion between coordinates and indices * Composition with other layouts * Iteration over sub-layouts Layouts can be hierarchical, with nested shapes and strides, allowing for complex memory access patterns like blocked or tiled layouts. ## Fields * ​shape (`IntTuple`): The dimensions of the layout. This field defines the size of each dimension in the logical coordinate space. For example, a shape of (3, 4) represents a 3x4 grid of elements. * ​stride (`IntTuple`): The memory step sizes for each dimension. This field defines how many elements to skip in memory when moving one unit in each dimension. For example, in a row-major 3x4 layout, the strides might be (4, 1), meaning moving one unit in the first dimension requires skipping 4 elements in memory, while moving one unit in the second dimension requires skipping 1 element. 
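The shape/stride pair fully determines the coordinate-to-index map: a row-major 3x4 layout has strides (4, 1), a column-major one (1, 3), and indexing is a dot product of coordinate and stride. A minimal Python model of this flat core (a sketch only; the Mojo struct additionally supports nesting, composition, and iteration):

```python
class Layout:
    """Minimal flat model: index(coord) = dot(coord, stride)."""

    def __init__(self, shape, stride):
        self.shape, self.stride = shape, stride

    @staticmethod
    def col_major(*dims):
        # First dimension varies fastest: strides are the prefix product.
        stride, running = [], 1
        for d in dims:
            stride.append(running)
            running *= d
        return Layout(tuple(dims), tuple(stride))

    @staticmethod
    def row_major(*dims):
        # Last dimension varies fastest.
        stride, running = [], 1
        for d in reversed(dims):
            stride.insert(0, running)
            running *= d
        return Layout(tuple(dims), tuple(stride))

    def __call__(self, *coord):
        return sum(c * s for c, s in zip(coord, self.stride))

print(Layout.col_major(3, 4).stride)  # (1, 3)
print(Layout.row_major(3, 4).stride)  # (4, 1)
print(Layout.row_major(3, 4)(1, 2))   # 1*4 + 2*1 = 6
```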
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Iterable`](/mojo/std/iter/Iterable), [`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait), [`Movable`](/mojo/std/builtin/value/Movable), [`Sized`](/mojo/std/builtin/len/Sized), [`Stringable`](/mojo/std/builtin/str/Stringable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `has_shape` `comptime has_shape = True` Indicates whether the layout has a valid shape. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = _LayoutIter[origin_of((muttoimm iterable_origin._mlir_origin))]` The iterator type for Layout iteration. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self)` Initializes an empty layout with no dimensions. Creates a layout with empty shape and stride tuples, which can be populated later using append operations. `__init__(out self, shape: IntTuple)` Initializes a layout with the given shape and column-major strides. Creates a layout with the specified shape and automatically calculates column-major strides (where the first dimension varies fastest in memory). **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The dimensions of the layout. `__init__(out self, shape: IntTuple, stride: IntTuple)` Initializes a layout with the given shape and stride. 
Creates a layout with explicitly specified shape and stride values. If an empty stride is provided, column-major strides are calculated. **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The dimensions of the layout. * ​stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The memory step size for each dimension, or empty for column-major. ### `__getitem__` `__getitem__(self, index: Int) -> Self` Returns a sub-layout for the specified dimension. **Args:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to extract. **Returns:** `Self`: A Layout containing the shape and stride for the specified dimension. ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if this layout is equal to another layout. Two layouts are considered equal if they have identical shape and stride tuples. **Args:** * ​other (`Self`): The layout to compare with. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the layouts are equal, False otherwise. ### `idx2crd` `idx2crd(self, idx: IntTuple) -> IntTuple` Converts a linear index to logical coordinates. This is the inverse operation of the `__call__` method, mapping from a memory index back to the corresponding logical coordinates. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): The logical coordinates corresponding to the given index. ### `col_major` `static col_major(*dims: Int) -> Self` Creates a column-major layout with the specified dimensions. In a column-major layout, the first dimension varies fastest in memory, which is the default layout in languages like Fortran and MATLAB. Example: ```mojo from layout import Layout # Create a 3x4 column-major layout var layout = Layout.col_major(3, 4) # Result: Layout with shape (3,4) and stride (1,3) ``` **Args:** * ​\*dims ([`Int`](/mojo/std/builtin/int/Int)): Variable number of dimension sizes. 
**Returns:** `Self`: A column-major Layout with the specified dimensions. `static col_major(shape: IntTuple) -> Self` Creates a column-major layout with the specified shape. In a column-major layout, the first dimension varies fastest in memory, which is the default layout in languages like Fortran and MATLAB. Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Create a 3x4 column-major layout var layout = Layout.col_major(IntTuple(3, 4)) # Result: Layout with shape (3,4) and stride (1,3) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): An IntTuple specifying the dimensions. **Returns:** `Self`: A column-major Layout with the specified shape. `static col_major[rank: Int](dims: DimList) -> Self` Creates a col-major layout from a DimList with compile-time rank. This method creates a col-major layout where the first dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. Example: ```mojo from layout import Layout from layout.layout import DimList # Create a col-major layout with compile-time rank var dims = DimList(3, 4) var layout = Layout.col_major[2](dims) # Result: Layout with shape (3,4) and stride (1,3) ``` **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​dims ([`DimList`](/mojo/kernels/buffer/dimlist/DimList)): A DimList containing the dimensions of the layout. **Returns:** `Self`: A col-major Layout with the specified dimensions and computed strides. `static col_major[rank: Int](tuple: IndexList[rank]) -> Self` Creates a col-major layout from an IndexList with compile-time rank. This method creates a col-major layout where the first dimension varies fastest in memory. 
It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. Example: ```mojo from layout import Layout from utils import IndexList # Create a col-major layout with compile-time rank var idx_list = IndexList[2](3, 4) var layout = Layout.col_major[2](idx_list) # Result: Layout with shape (3,4) and stride (1,3) ``` **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​tuple ([`IndexList`](/mojo/std/utils/index_/IndexList)): An IndexList containing the dimensions of the layout. **Returns:** `Self`: A col-major Layout with the specified dimensions and computed strides. ### `row_major` `static row_major(*dims: Int) -> Self` Creates a row-major layout with the specified dimensions. In a row-major layout, the last dimension varies fastest in memory, which is the default layout in languages like C, C++, and Python. Example: ```mojo from layout import Layout # Create a 3x4 row-major layout var layout = Layout.row_major(3, 4) # Result: Layout with shape (3,4) and stride (4,1) ``` **Args:** * ​\*dims ([`Int`](/mojo/std/builtin/int/Int)): Variable number of dimension sizes. **Returns:** `Self`: A row-major Layout with the specified dimensions. `static row_major[rank: Int](dims: DimList) -> Self` Creates a row-major layout from a DimList with compile-time rank. This method creates a row-major layout where the last dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. 
Example: ```mojo from layout import Layout from layout.layout import DimList # Create a row-major layout with compile-time rank var dims = DimList(3, 4) var layout = Layout.row_major[2](dims) # Result: Layout with shape (3,4) and stride (4,1) ``` **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​dims ([`DimList`](/mojo/kernels/buffer/dimlist/DimList)): A DimList containing the dimensions of the layout. **Returns:** `Self`: A row-major Layout with the specified dimensions and computed strides. `static row_major[rank: Int](tuple: IndexList[rank]) -> Self` Creates a row-major layout from an IndexList with compile-time rank. This method creates a row-major layout where the last dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. Example: ```mojo from layout import Layout from utils import IndexList # Create a row-major layout with compile-time rank var idx_list = IndexList[2](3, 4) var layout = Layout.row_major[2](idx_list) # Result: Layout with shape (3,4) and stride (4,1) ``` **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​tuple ([`IndexList`](/mojo/std/utils/index_/IndexList)): An IndexList containing the dimensions of the layout. **Returns:** `Self`: A row-major Layout with the specified dimensions and computed strides. `static row_major[rank: Int]() -> Self` Creates a row-major layout with unknown values for each axis from a compile-time rank. Example: ```mojo from layout import Layout var layout = Layout.row_major[2]() # Result: Layout with shape (UNKNOWN_VALUE, UNKNOWN_VALUE) ``` **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. 
**Returns:** `Self`: A row-major Layout with the given rank. `static row_major(shape: IntTuple) -> Self` Creates a row-major layout from an IntTuple of dimensions. In a row-major layout, the last dimension varies fastest in memory. This method computes the appropriate strides for a row-major layout given the input shape. Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Create a row-major layout from a shape tuple var shape = IntTuple(3, 4) var layout = Layout.row_major(shape) # Result: Layout with shape (3,4) and stride (4,1) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): An IntTuple containing the dimensions of the layout. **Returns:** `Self`: A row-major Layout with the specified shape and computed strides. ### `make_shape_unknown` `make_shape_unknown[axis: Int = -1](self) -> Self` Creates a new Layout with unknown shape dimensions. This method creates a copy of the current Layout but marks either all dimensions or a specific dimension as unknown, while preserving the original strides. This is useful for tiling tensors with runtime sizes where the tile's shape is unknown but the memory layout (strides) remains constant. Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Mark all dimensions as unknown var layout = Layout(IntTuple(2, 3)) var unknown = layout.make_shape_unknown() # Result: Layout with shape (?, ?) and original strides # Mark only first dimension as unknown var partial = layout.make_shape_unknown[0]() # Result: Layout with shape (?, 3) and original strides ``` **Parameters:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The dimension to mark as unknown. If UNKNOWN\_VALUE (default), all dimensions are marked as unknown. **Returns:** `Self`: A new Layout with the specified dimension(s) marked as unknown and original strides preserved. ### `__str__` `__str__(self) -> String` Converts the layout to a string representation. 
**Returns:** `String`: A string representation of the layout in the format "(shape:stride)". ### `write_to` `write_to(self, mut writer: T)` Writes the layout to the specified writer. Formats the layout as "(shape:stride)" and writes it to the provided writer. **Args:** * ​writer (`T`): The writer to output the layout representation to. ### `__len__` `__len__(self) -> Int` Returns the number of dimensions in the layout. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of elements in the shape tuple. ### `__iter__` `__iter__(ref self) -> _LayoutIter[origin_of((muttoimm self_is_origin))]` Returns an iterator over the layout's dimensions. Each iteration yields a Layout containing the shape and stride for one dimension. **Returns:** `_LayoutIter`: An iterator over the layout's dimensions. ### `size` `size(self) -> Int` Returns the total number of elements in the layout's domain. Calculates the product of all dimensions in the shape. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total number of elements in the layout. ### `cosize` `cosize(self) -> Int` Returns the size of the memory region spanned by the layout. Calculates the maximum memory index plus one, representing the total memory footprint required by the layout. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the memory region required by the layout. ### `rank` `rank(self) -> Int` Returns the number of dimensions in the layout. This is equivalent to `__len__` and returns the number of elements in the shape tuple. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of dimensions in the layout. ### `__call__` `__call__(self, idx: IntTuple) -> Int` Maps logical coordinates to a linear memory index. This is the core functionality of a layout, converting multi-dimensional coordinates to a linear memory location. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The logical coordinates to map. 
**Returns:** [`Int`](/mojo/std/builtin/int/Int): The linear memory index corresponding to the given coordinates. ### `append` `append(mut self, item: Self)` Appends another layout to this layout. This method adds the shape and stride from the provided layout to this layout, effectively increasing its dimensionality. **Args:** * ​item (`Self`): The layout to append to this layout. ### `all_dims_known` `all_dims_known(self) -> Bool` Checks if all dimensions in the layout have known values. A dimension is considered unknown if its shape or stride is set to the special `UNKNOWN_VALUE` constant. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if all dimensions have known shape and stride values, False otherwise. ### `known_shape` `known_shape(self) -> Bool` Checks if all shape dimensions in the layout have known values. A dimension is considered unknown if its shape is set to the special `UNKNOWN_VALUE` constant. This method only checks shapes, not strides. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if all shape dimensions have known values, False otherwise. ### `transpose` `transpose(self) -> Self` Transposes the layout by reversing the order of dimensions. For an n-dimensional layout, this reverses the order of both shapes and strides. For nested layouts, only the top-level dimensions are transposed, not the hierarchical structure within nested tuples. 
Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Simple 2D transpose (row-major to column-major) var layout = Layout.row_major(3, 4) # shape (3,4), stride (4,1) var transposed = layout.transpose() # shape (4,3), stride (1,4) # 3D transpose var layout3d = Layout.row_major(2, 3, 4) # shape (2,3,4), stride (12,4,1) var trans3d = layout3d.transpose() # shape (4,3,2), stride (1,4,12) # Nested layout - only top level transposed var nested = Layout( IntTuple(IntTuple(2, 3), 4), IntTuple(IntTuple(12, 4), 1) ) var trans_nested = nested.transpose() # Result: shape (4, (2,3)), stride (1, (12,4)) ``` **Returns:** `Self`: A new Layout with transposed dimensions.
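Because transposing only reverses the shape and stride tuples, a coordinate (i, j) in the original layout and (j, i) in the transposed layout address the same memory location. A short Python sketch of that invariant (illustrative only, not the Mojo implementation):

```python
def index(coords, strides):
    # Standard stride arithmetic: dot product of coordinates and strides.
    return sum(c * s for c, s in zip(coords, strides))

shape, stride = (3, 4), (4, 1)                  # row-major (3, 4)
t_shape, t_stride = shape[::-1], stride[::-1]   # transposed: shape (4, 3), stride (1, 4)

# (i, j) in the original hits the same element as (j, i) in the transpose.
for i in range(shape[0]):
    for j in range(shape[1]):
        assert index((i, j), stride) == index((j, i), t_stride)
```

No data moves: only the interpretation of the coordinates changes, which is why `transpose` is a zero-copy view operation.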
--- ## LayoutTrait
Defines the interface for mapping between logical coordinates and memory indices. The `LayoutTrait` provides a common interface for all layout types, including basic layouts, swizzles, and composed layouts. It enables mapping from multi-dimensional logical coordinates to linear memory indices, which is essential for tensor operations. Implementations of this trait must provide methods for: 1. Mapping coordinates to indices via the `__call__` method 2. Calculating the total size of the layout's domain 3. Calculating the size of the layout's codomain (memory footprint) 4. Indicating whether the layout has a valid shape This trait serves as the foundation for the layout system, allowing different layout implementations to be used interchangeably in algorithms. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. 
### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `has_shape` `comptime has_shape` Indicates whether the layout has a valid shape. Layouts and ComposedLayouts with at least one Layout have valid shapes and can be used in layout algebra. Swizzles don't have shapes and should be excluded from layout algebra. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `__call__` `__call__(self: _Self, index: IntTuple) -> Int` Maps a logical coordinate to a linear memory index. **Args:** * ​index ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): An IntTuple representing the logical coordinates to map. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The linear memory index corresponding to the given coordinates. ### `size` `size(self: _Self) -> Int` Returns the total number of elements in the layout's domain. For a layout with shape (m, n), this returns m \* n, representing the total number of valid coordinates in the layout. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total number of elements in the layout. ### `cosize` `cosize(self: _Self) -> Int` Returns the size of the memory region spanned by the layout. 
For a layout with shape `(m, n)` and stride `(r, s)`, this returns `(m-1)*r + (n-1)*s + 1`, representing the memory footprint. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the memory region required by the layout. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
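The cosize formula above generalizes to any rank: the largest reachable index is the sum of `(n - 1) * stride` over all modes, and cosize is that plus one. A small Python sketch contrasting `size` (domain) and `cosize` (memory footprint), for illustration only:

```python
def size(shape):
    # Total number of valid coordinates: product of all dimensions.
    total = 1
    for n in shape:
        total *= n
    return total

def cosize(shape, stride):
    # Maximum memory index plus one: each mode contributes (n - 1) * stride.
    return sum((n - 1) * s for n, s in zip(shape, stride)) + 1

assert size((3, 4)) == 12
assert cosize((3, 4), (4, 1)) == 12   # dense row-major: cosize == size
assert cosize((3, 4), (8, 1)) == 20   # padded rows (stride 8): cosize > size
```

The padded-row case shows why `cosize` rather than `size` determines how much memory a layout actually requires.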
--- ## MakeLayoutList
`MakeLayoutList(var v0: Layout, var v1: Layout) -> LayoutList` Creates a list containing two layouts. This is a convenience function for creating a LayoutList with two elements. **Args:** * ​v0 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout to include in the list. * ​v1 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout to include in the list. **Returns:** `LayoutList`: A LayoutList containing the two provided layouts.
--- ## MakeTileLayoutList
`MakeTileLayoutList[*tile_sizes: Int]() -> LayoutList` Creates a list of layouts for tiling operations. This function creates a list of simple layouts, each with a shape from the provided tile\_sizes and a stride of 1. These layouts can be used for tiling operations. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): Variable number of integer tile dimensions. **Returns:** `LayoutList`: A LayoutList containing layouts for each tile size.
--- ## apply_tiler
`apply_tiler[func: fn(var Layout, var Layout) -> Layout](var layout_a: Layout, tiler: List[Layout]) -> Layout` Applies a layout transformation function to each element of a layout with a tiler. This utility function applies the specified transformation function to each corresponding pair of elements from the layout and tiler list. It's a generic mechanism for implementing various tiling operations. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import apply_tiler, logical_divide # Apply logical_divide to each element of a layout with a tiler var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2), IntTuple(1, 2))) var result = apply_tiler[logical_divide](base^, tilers) ``` **Parameters:** * ​func (`fn(var Layout, var Layout) -> Layout`): A function that takes two layouts and returns a transformed layout. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to transform. * ​tiler ([`List`](/mojo/std/collections/list/List)): A list of layouts to use in the transformation. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout resulting from applying the transformation function to each pair.
--- ## blocked_product
`blocked_product(var layout_a: Layout, var layout_b: Layout, coalesce_output: Bool = False) -> Layout` Creates a blocked layout by combining two layouts. This function creates a hierarchical blocked layout by combining an inner (block) and an outer (base) layout. The result is a layout where each element of the outer layout is replaced by a block defined by the inner layout. This is particularly useful for creating tiled layouts for efficient cache utilization in tensor operations like matrix multiplication. Example: ```mojo from layout import Layout from layout.layout import blocked_product # Create a 2x3 matrix layout var matrix = Layout.row_major(2, 3) # Define 2x2 blocks var block = Layout.row_major(2, 2) # Create a blocked layout with 2x2 blocks var blocked = blocked_product(block^, matrix^) ``` Output: ```plaintext (((2, 2), (2, 3)):((2, 12), (1, 4))) 0 1 2 3 4 5 +----+----+----+----+----+----+ 0 | 0 | 1 | 4 | 5 | 8 | 9 | +----+----+----+----+----+----+ 1 | 2 | 3 | 6 | 7 | 10 | 11 | +----+----+----+----+----+----+ 2 | 12 | 13 | 16 | 17 | 20 | 21 | +----+----+----+----+----+----+ 3 | 14 | 15 | 18 | 19 | 22 | 23 | +----+----+----+----+----+----+ ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): Inner layout. The layout for an individual block, or tile. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): Outer layout. The layout of the tiles in the output layout. * ​coalesce\_output ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to coalesce the output layout. Default is False. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the blocked structure
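The printed table can be recomputed from the result layout `(((2, 2), (2, 3)):((2, 12), (1, 4)))` by splitting each row and column index into a block coordinate and a tile coordinate, with the first entry of each nested mode varying fastest (colexicographic order). A Python sketch that reproduces the indices above, under that ordering assumption:

```python
def blocked_index(r, c):
    # Row r splits into in-block coordinate (r % 2) and block row (r // 2);
    # column c splits the same way. Strides: ((2, 12), (1, 4)).
    return (r % 2) * 2 + (r // 2) * 12 + (c % 2) * 1 + (c // 2) * 4

table = [[blocked_index(r, c) for c in range(6)] for r in range(4)]
for row in table:
    print(row)
# First printed row: [0, 1, 4, 5, 8, 9], matching the table above.
```

Each 2x2 block occupies a contiguous-by-stride patch, which is the cache-friendly access pattern blocked layouts are designed for.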
--- ## coalesce
`coalesce(layout: Layout, keep_rank: Bool = False) -> Layout` Simplifies a layout by combining dimensions with contiguous strides. This function reduces the rank of a layout by merging dimensions that have contiguous memory layouts, resulting in a simpler but equivalent layout. Example: ```mojo from layout import Layout, IntTuple from layout.layout import coalesce # A layout with shape (2, (1, 4)) and stride (1, (4, 2)) can be coalesced var layout = Layout(IntTuple(2, IntTuple(1, 4)), IntTuple(1, IntTuple(4, 2))) var coalesced = coalesce(layout) # Result: Layout with shape (8) and stride (1) ``` **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to coalesce. * ​keep\_rank ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, maintains the original rank of the layout. Default is False. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A simplified layout with reduced rank where possible.
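Coalescing preserves the coordinate-to-index mapping: the example layout and its coalesced form cover exactly the same set of memory indices. A quick Python check of that equivalence (illustrative only):

```python
# Original: shape (2, (1, 4)), stride (1, (4, 2)).
original = sorted(
    i * 1 + j * 4 + k * 2
    for i in range(2)   # outer mode: size 2, stride 1
    for j in range(1)   # size-1 mode: contributes nothing, dropped by coalescing
    for k in range(4)   # inner mode: size 4, stride 2
)
# Coalesced: shape (8,), stride (1,) -- the same indices 0 through 7.
assert original == list(range(8))
```

The size-1 mode vanishes and the remaining modes interleave into one contiguous run, which is why the simpler layout is equivalent.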
--- ## complement
`complement(layout: Layout, size: Int = 1) -> Layout` Computes the complement layout for a given layout. This function creates a layout that represents the "gaps" or complementary structure of the input layout. It's useful for creating hierarchical layouts where you need to fill in the spaces between existing layout elements. Example: ```mojo from layout import Layout, IntTuple from layout.layout import complement # Compute the complement of a layout var base = Layout(IntTuple(2, 3), IntTuple(3, 1)) var comp = complement(base, 10) # Result: A layout that fills the gaps in the original layout ``` **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The input layout to compute the complement for. * ​size ([`Int`](/mojo/std/builtin/int/Int)): The total size of the memory region to consider. Defaults to 1. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the complement of the input layout.
--- ## composition
`composition(var layout_a: Layout, var layout_b: Layout) -> Layout` Composes two layouts to create a new layout. This function creates a new layout by composing two layouts, where the first layout defines the outer structure and the second layout defines the inner structure. The new layout is compatible with `layout_b` (that is, it has the same `size` and every set of coordinates in `layout_b` has an equivalent in the new layout). You can think of `layout_b` as selecting a subset of elements from `layout_a`. Example: ```mojo from layout.layout import Layout, IntTuple from layout.layout import composition # Compose a row-major layout with a tiling layout var base = Layout.row_major(6, 8) var tiling = Layout(IntTuple(3, 2), IntTuple(1, 3)) var composed = composition(base^, tiling^) # Result: A layout that represents a 3x2 tile from # layout_a ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The outer layout. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The inner layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the composition of the two layouts. `composition(var layout_a: Layout, tiler: List[Layout]) -> Layout` Composes a layout with a list of layouts to create a hierarchical layout. This function creates a new layout by composing each element of the first layout with the corresponding element in the tiler list. If the tiler list is shorter than the layout, the remaining elements from the layout are appended unchanged. 
Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import composition # Compose a layout with a list of tiling layouts var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2), IntTuple(1, 2))) tilers.append(Layout(IntTuple(3, 3), IntTuple(1, 3))) var composed = composition(base^, tilers^) # Result: A layout with hierarchical tiling based on the tiler list ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to compose with the tiler. * ​tiler ([`List`](/mojo/std/collections/list/List)): A list of layouts to compose with the base layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the composition of the base layout with the tiler.
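Composition chains the two mappings: `(A ∘ B)(c) = A(B(c))`, where B's linear result is reinterpreted as a coordinate in A's domain. A hedged Python sketch of the first example above, assuming the colexicographic (first-mode-fastest) linearization convention used by this layout algebra:

```python
def a_index(coord):
    # Base layout: row-major (6, 8), strides (8, 1).
    return coord[0] * 8 + coord[1]

def lin_to_coord(n):
    # Reinterpret a linear index as a coordinate of the (6, 8) domain,
    # with the first dimension varying fastest.
    return (n % 6, n // 6)

def b_linear(coord):
    # Tiling layout ((3, 2):(1, 3)).
    return coord[0] * 1 + coord[1] * 3

def composed(coord):
    # (A o B)(c): map through B, reinterpret in A's domain, map through A.
    return a_index(lin_to_coord(b_linear(coord)))

# The composition behaves like a (3, 2) tile with strides (8, 24) into the base.
assert composed((1, 0)) == 8
assert composed((0, 1)) == 24
assert composed((2, 1)) == 2 * 8 + 24
```

The helper names and the stride result `(8, 24)` are illustrative derivations, not values stated by the API reference.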
--- ## cosize
`cosize(l: Layout) -> Int` Returns the size of the memory region spanned by the layout. This is a standalone function equivalent to the Layout.cosize() method. **Args:** * ​l ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to calculate the cosize for. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the memory region required by the layout.
--- ## downcast
`downcast(layout: Layout, factor: Int) -> Layout` Splits elements in a layout to create a finer layout without changing the total number of elements so that the alignment is preserved. This function is useful for converting between different data type granularities, such as from uint128 to bf16. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to downcast. * ​factor ([`Int`](/mojo/std/builtin/int/Int)): The number of elements to split into. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout with adjusted shape and stride for the finer granularity.
--- ## expand_modes_alike
`expand_modes_alike(shape_a: IntTuple, stride_a: IntTuple, shape_b: IntTuple, stride_b: IntTuple) -> InlineArray[IntTuple, 3]` Aligns two shape-stride pairs to have the same hierarchical structure. This function is used to make two layouts compatible for operations by ensuring they have the same hierarchical structure, expanding scalar values into tuples as needed. **Args:** * ​shape\_a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first shape tuple. * ​stride\_a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first stride tuple. * ​shape\_b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second shape tuple. * ​stride\_b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second stride tuple. **Returns:** [`InlineArray`](/mojo/std/collections/inline_array/InlineArray): An array containing three tuples: the common shape, the expanded stride\_a, and the expanded stride\_b. `expand_modes_alike(layout_a: Layout, layout_b: Layout) -> InlineArray[Layout, 2]` Aligns two layouts to have the same hierarchical structure. This function tiles both layouts so they mirror each other's structure, making them compatible for operations that require matching hierarchies. 
Example: Given layouts with different structures: * layout\_0: (((3, (5, 2)), 4):((1, (24, 12)), 3)) * layout\_1: ((30, (2, 2)):(2, (60, 1))) The result would be two layouts with matching structures: * (((3, (5, 2)), (2, 2)):((1, (24, 12)), (3, 6))) * (((3, (5, 2)), (2, 2)):((2, (6, 30)), (60, 1))) ```mojo from layout import Layout, IntTuple from layout.layout import expand_modes_alike comptime layout_0 = Layout( IntTuple(IntTuple(3, IntTuple(5, 2)), 4), IntTuple(IntTuple(1, IntTuple(24, 12)), 3), ) comptime layout_1 = Layout( IntTuple(30, IntTuple(2, 2)), IntTuple(2, IntTuple(60, 1)) ) comptime uc = expand_modes_alike(layout_0, layout_1) print(uc[0]) # (((3, (5, 2)), (2, 2)):((1, (24, 12)), (3, 6))) print(uc[1]) # (((3, (5, 2)), (2, 2)):((2, (6, 30)), (60, 1))) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout to align. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout to align. **Returns:** [`InlineArray`](/mojo/std/collections/inline_array/InlineArray): An array containing two layouts with matching hierarchical structures.
--- ## expand_strides
`expand_strides(shape: IntTuple, stride: Int) -> IntTuple` Expands a scalar stride into a stride tuple matching a shape tuple. This function creates a stride tuple that matches the structure of a shape tuple, with each stride value calculated based on the cumulative product of shape dimensions. **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape tuple to match. * ​stride ([`Int`](/mojo/std/builtin/int/Int)): The base stride value to expand. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A stride tuple matching the structure of the shape tuple.
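Based on the description above, the expansion scales a running product of the shape dimensions by the base stride. A hypothetical Python sketch of that cumulative-product rule (the exact behavior of the Mojo function, especially for nested tuples and unknown dimensions, may differ):

```python
def expand_strides(shape, stride):
    # Each mode gets stride * (product of all preceding shape entries),
    # so the expanded strides match the shape's structure.
    out, current = [], stride
    for n in shape:
        out.append(current)
        current *= n
    return tuple(out)

# Expanding base stride 2 over shape (3, 4) yields strides (2, 6).
assert expand_strides((3, 4), 2) == (2, 6)
```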
--- ## format_layout
`format_layout[W: Writer](layout: Layout, mut writer: W)` Formats a 2D layout as a table and writes it to the specified writer. This function creates a visual representation of a 2D layout as a table showing the memory indices for each logical coordinate. **Parameters:** * ​W ([`Writer`](/mojo/std/format/Writer)): Type parameter representing a Writer implementation. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The 2D layout to format. * ​writer (`W`): The writer to output the formatted layout to.
--- ## hierarchical_unzip
`hierarchical_unzip(layout_a: Layout, tiler: List[Layout]) -> Layout` Hierarchically unzips a layout according to a list of layouts. This function creates a hierarchical layout by unzipping the first layout according to the layouts in the tiler list. It's useful for decomposing a layout into hierarchical components for more efficient memory access patterns or to enable specialized tensor operations. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import hierarchical_unzip # Create a layout to unzip var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2))) var result = hierarchical_unzip(base, tilers) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be unzipped. * ​tiler ([`List`](/mojo/std/collections/list/List)): A list of layouts defining the unzipping patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical unzipping with components from both the original layout and the tiler layouts. `hierarchical_unzip(layout_a: Layout, layout_b: Layout) -> Layout` Hierarchically unzips a layout according to another layout. This function creates a hierarchical layout by unzipping the first layout according to the second layout. It's a fundamental operation for decomposing a layout into hierarchical components, which enables more efficient memory access patterns for various tensor operations. Example: ```mojo from layout import Layout, IntTuple from layout.layout import hierarchical_unzip # Create layouts var base = Layout.row_major(6, 8) var pattern = Layout(IntTuple(2, 2)) var result = hierarchical_unzip(base, pattern) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be unzipped. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout defining the unzipping pattern. 
**Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical unzipping of layout\_a according to the pattern defined by layout\_b.
--- ## layout (3)
Provides a high-performance tensor layout system for memory mapping and indexing. The layout module implements a comprehensive system for describing memory layouts of multi-dimensional tensors, enabling efficient mapping between logical tensor coordinates and physical memory locations. This is a critical component for high-performance tensor operations in machine learning and scientific computing. These low-level primitives require careful use to avoid errors. Understanding the relationship between tensor shapes, strides, and memory layout is essential for effective use. Key components: * `LayoutTrait`: Core trait defining the interface for all layout types * `Layout`: Primary struct implementing memory layout with shape and stride information * Layout algebra: Functions for composing, dividing, and transforming layouts * Tiling operations: Functions for hierarchical decomposition of layouts Performance features: * Zero-cost abstractions for mapping between logical and physical indices * Support for both compile-time and runtime-determined shapes * Efficient memory access patterns through layout transformations * Hierarchical tiling for cache-friendly memory access Common use cases: * Defining memory layouts for tensors with different storage formats (row-major, column-major) * Implementing efficient tensor operations with optimal memory access patterns * Supporting hardware-specific memory layouts for accelerators * Enabling zero-copy tensor views and reshaping operations Example: ```mojo from layout import Layout, IntTuple from layout.layout import blocked_product # Create a 3x4 row-major layout var layout = Layout.row_major(3, 4) # Access the memory location for logical coordinates (1, 2) var memory_idx = layout([1, 2]) # Create a tiled layout for blocked matrix multiplication var tiled = blocked_product(layout^, Layout([2, 2])) ``` ## `comptime` values ### `LayoutList` `comptime LayoutList = List[Layout]` Type alias for a list of Layout objects. 
## Structs * [​`Layout`](./Layout): Represents a memory layout for multi-dimensional data. ## Traits * [​`LayoutTrait`](./LayoutTrait): Defines the interface for mapping between logical coordinates and memory indices. ## Functions * [​`apply_tiler`](./apply_tiler): Applies a layout transformation function to each element of a layout with a tiler. * [​`blocked_product`](./blocked_product): Creates a blocked layout by combining two layouts. * [​`coalesce`](./coalesce): Simplifies a layout by combining dimensions with contiguous strides. * [​`complement`](./complement): Computes the complement layout for a given layout. * [​`composition`](./composition): Composes two layouts to create a new layout. * [​`cosize`](./cosize): Returns the size of the memory region spanned by the layout. * [​`downcast`](./downcast): Splits elements in a layout to create a finer layout without changing the total number of elements so that the alignment is preserved. * [​`expand_modes_alike`](./expand_modes_alike): Aligns two shape-stride pairs to have the same hierarchical structure. * [​`expand_strides`](./expand_strides): Expands a scalar stride into a stride tuple matching a shape tuple. * [​`format_layout`](./format_layout): Formats a 2D layout as a table and writes it to the specified writer. * [​`hierarchical_unzip`](./hierarchical_unzip): Hierarchically unzips a layout according to a list of layouts. * [​`is_contiguous_dim`](./is_contiguous_dim): Checks if a flat layout is contiguous in a specific dimension. * [​`is_row_major`](./is_row_major): Checks if a layout has row-major ordering for the specified rank. * [​`logical_divide`](./logical_divide): Divides a layout into blocks according to another layout. * [​`logical_product`](./logical_product): Creates a product of two layouts. * [​`make_layout`](./make_layout): Creates a composite layout by concatenating multiple layouts. 
* [​`make_ordered_layout`](./make_ordered_layout): Creates a layout with strides ordered according to a specified traversal order. * [​`MakeLayoutList`](./MakeLayoutList): Creates a list containing two layouts. * [​`MakeTileLayoutList`](./MakeTileLayoutList): Creates a list of layouts for tiling operations. * [​`print_layout`](./print_layout): Prints a 2D layout to the standard output. * [​`right_inverse`](./right_inverse): Creates a right inverse of a layout. * [​`size`](./size): Returns the total number of elements in the layout's domain. * [​`sublayout`](./sublayout): Creates a sublayout by selecting specific dimensions from a layout. * [​`tile_to_shape`](./tile_to_shape): Creates a layout by tiling a base layout to match a target shape. * [​`upcast`](./upcast): Fuses consecutive elements in a layout to create a coarser layout. * [​`zip_modes`](./zip_modes): Combines corresponding modes from two layouts. * [​`zipped_divide`](./zipped_divide): Divides a layout into blocks according to another layout.
--- ## is_contiguous_dim
`is_contiguous_dim(layout: Layout, dim: Int) -> Bool` Checks if a flat layout is contiguous in a specific dimension. A dimension is considered contiguous if it has a positive stride, or a zero stride with a single element; the latter case is necessary for handling coalesced layouts. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to check. * ​dim ([`Int`](/mojo/std/builtin/int/Int)): The dimension to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the layout is contiguous in the specified dimension, False otherwise.
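Example (a minimal sketch following the module's other examples; the results depend on the stride conditions described above, so no output is claimed):

```mojo
from layout import Layout
from layout.layout import is_contiguous_dim

# A row-major 3x4 layout has strides (4, 1).
var layout = Layout.row_major(3, 4)
print(is_contiguous_dim(layout, 0))
print(is_contiguous_dim(layout, 1))
```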
--- ## is_row_major
`is_row_major[rank: Int](layout: Layout) -> Bool` Checks if a layout has row-major ordering for the specified rank. A row-major layout has strides that decrease from left to right, with the rightmost dimension having a stride of 1. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The expected rank of the layout. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the layout has row-major ordering for the specified rank, False otherwise.
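Example (a minimal sketch in the style of the module's other examples; `Layout.col_major` is assumed to exist as the column-major counterpart of `Layout.row_major`):

```mojo
from layout import Layout
from layout.layout import is_row_major

# Row-major 3x4 has strides (4, 1): decreasing left to right,
# rightmost stride 1, so this is row-major.
var rm = Layout.row_major(3, 4)
print(is_row_major[2](rm))  # should print True

# Column-major 3x4 has strides (1, 3): not row-major ordering.
var cm = Layout.col_major(3, 4)
print(is_row_major[2](cm))  # should print False
```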
--- ## logical_divide
`logical_divide(layout_a: Layout, _layout_b: Layout) -> Layout` Divides a layout into blocks according to another layout. This function creates a hierarchical layout by dividing the first layout according to the second layout. It's useful for creating blocked or tiled representations of tensors. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​\_layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout defining the division pattern. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division. `logical_divide(layout_a: Layout, tiler: List[Layout]) -> Layout` Divides a layout into blocks according to a list of layouts. This is a variant of logical\_divide that works with a list of layouts for more complex tiling patterns. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​tiler ([`List`](/mojo/std/collections/list/List)): A list of layouts defining the division patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division.
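Example (a minimal sketch mirroring the `zipped_divide` example later in this module):

```mojo
from layout import Layout, IntTuple
from layout.layout import logical_divide

# Divide a 6x8 row-major layout into 2x2 blocks.
var base = Layout.row_major(6, 8)
var pattern = Layout(IntTuple(2, 2))
var blocked = logical_divide(base, pattern)
```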
--- ## logical_product
`logical_product(var _layout_a: Layout, var layout_b: Layout) -> Layout` Creates a product of two layouts. This function creates a hierarchical layout by taking the logical product of two layouts. It's a fundamental operation for creating blocked or tiled layouts. **Args:** * ​\_layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the logical product of the two layouts. `logical_product(var layout_a: Layout, tiler: List[Layout]) -> Layout` Creates a product of a layout with a list of layouts. This is a variant of logical\_product that works with a list of layouts for more complex tiling patterns. It applies the logical\_product operation to each element of the layout with the corresponding element in the tiler list. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import logical_product # Create a product of a layout with a list of layouts var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2))) var result = logical_product(base^, tilers) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to create products with. * ​tiler ([`List`](/mojo/std/collections/list/List)): A list of layouts defining the product patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the logical product with the tiler layouts.
--- ## make_layout
`make_layout(*layouts: Layout) -> Layout` Creates a composite layout by concatenating multiple layouts. This function combines multiple layouts into a single layout by concatenating their shapes and strides. The resulting layout represents a hierarchical structure where each input layout becomes a component of the output layout. Example: ```mojo from layout import Layout, IntTuple from layout.layout import make_layout var layout1 = Layout(IntTuple(2, 3), IntTuple(3, 1)) var layout2 = Layout(IntTuple(4, 5), IntTuple(5, 1)) var combined = make_layout(layout1, layout2) # Result: Layout with shape ((2, 3), (4, 5)) and stride ((3, 1), (5, 1)) ``` **Args:** * ​\*layouts ([`Layout`](/mojo/kernels/layout/layout/Layout)): Variable number of `Layout` objects to combine. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new Layout with concatenated shapes and strides from the input layouts. `make_layout(layout_a: Layout, layout_b: Layout) -> Layout` Creates a composite layout from two layouts. This is a specialized version of make\_layout that takes exactly two layouts and combines them into a single layout. This function exists as a workaround for compiler limitations. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout to include in the composite. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout to include in the composite. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new `Layout` with concatenated shapes and strides from the input layouts.
--- ## make_ordered_layout
`make_ordered_layout(shape: IntTuple, order: IntTuple) -> Layout` Creates a layout with strides ordered according to a specified traversal order. This function generates a compact (bijective) layout where the stride values follow the traversal order specified by the order parameter. This allows creating layouts with custom memory traversal patterns while maintaining a compact memory representation. Example: ```mojo from layout import IntTuple, Layout from layout.layout import make_ordered_layout # Create a layout with shape (2,3,4,5) where dimensions are traversed # in the order: dim0, dim3, dim2, dim1 var layout = make_ordered_layout( IntTuple(2, 3, 4, 5), IntTuple(1, 4, 3, 2) ) # Result: Layout with shape (2,3,4,5) and stride (1,40,10,2) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the layout. * ​order ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The traversal order priority (lower values indicate higher priority). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A `Layout` with the specified shape and strides ordered according to the traversal order.
--- ## print_layout
`print_layout(layout: Layout)` Prints a 2D layout to the standard output. This function visualizes a 2D layout by printing a formatted table showing the memory indices for each logical coordinate. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The 2D layout to print.
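Example (a minimal sketch following the module's other examples):

```mojo
from layout import Layout
from layout.layout import print_layout

# Print the memory-index table for a 2x4 row-major layout
# to standard output.
var layout = Layout.row_major(2, 4)
print_layout(layout)
```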
--- ## right_inverse
`right_inverse(layout: Layout) -> Layout` Creates a right inverse of a layout. The right inverse of a layout maps memory indices back to logical coordinates. This is useful for converting between different memory layouts. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to invert. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the right inverse of the input layout.
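Example (a minimal sketch inferred from the signature; the exact shape of the resulting inverse layout is not claimed here):

```mojo
from layout import Layout, IntTuple
from layout.layout import right_inverse

# Build a column-major 3x4 layout (strides (1, 3)), then compute its
# right inverse, which maps memory indices back to logical coordinates.
var layout = Layout(IntTuple(3, 4), IntTuple(1, 3))
var inv = right_inverse(layout)
```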
--- ## size (Layout)
`size(l: Layout) -> Int` Returns the total number of elements in the layout's domain. This is a standalone function equivalent to the Layout.size() method. **Args:** * ​l ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to calculate the size for. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total number of elements in the layout.
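Example (a minimal sketch following the module's other examples):

```mojo
from layout import Layout
from layout.layout import size

# A 3x4 layout's domain contains 3 * 4 = 12 elements.
var layout = Layout.row_major(3, 4)
print(size(layout))  # 12
```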
--- ## sublayout
`sublayout(layout: Layout, *modes: Int) -> Layout` Creates a sublayout by selecting specific dimensions from a layout. This function extracts a subset of dimensions from a layout to create a new layout with lower rank. For example, from a 3D layout, you could extract a 2D layout containing only the first and third dimensions. Example: From a layout with shape (3,4,5), sublayout(layout, 0, 2) would create a layout with shape (3,5). **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The source layout to extract dimensions from. * ​\*modes ([`Int`](/mojo/std/builtin/int/Int)): The indices of dimensions to include in the sublayout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout containing only the specified dimensions.
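The prose example above can be written out as code (a sketch; `Layout.row_major` is assumed to accept three dimensions, matching its variadic use elsewhere in this module):

```mojo
from layout import Layout
from layout.layout import sublayout

# Keep only the first and third dimensions of a rank-3 layout.
var layout = Layout.row_major(3, 4, 5)
var sub = sublayout(layout, 0, 2)  # shape (3, 5)
```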
--- ## tile_to_shape
`tile_to_shape(var tile: Layout, target_shape: IntTuple, order: Optional[IntTuple] = None) -> Layout` Creates a layout by tiling a base layout to match a target shape. This function creates a hierarchical layout by repeating a tile layout to match a target shape. It calculates how many times the tile needs to be repeated in each dimension to reach the target shape, and creates a tiler layout with this information. Example: ```mojo from layout import Layout, IntTuple from layout.layout import tile_to_shape # Create a 2x2 tile layout var tile = Layout.row_major(2, 2) # Tile it to create a 6x4 layout var tiled = tile_to_shape(tile^, IntTuple(6, 4)) # Result: A layout with 3x2 tiles of size 2x2 each ``` **Args:** * ​tile ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to be tiled. * ​target\_shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The desired final shape to tile to. * ​order ([`Optional`](/mojo/std/collections/optional/Optional)): Optional memory ordering for the tiler layout. If None, defaults to column-major ordering. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the tiled structure that matches the target shape.
--- ## upcast
`upcast[check: Bool = True](var layout: Layout, factor: Int) -> Layout` Fuses consecutive elements in a layout to create a coarser layout. This function is useful for converting between different data type granularities, such as from bytes to larger data types like bfloat16 or tf32. **Parameters:** * ​check ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to check for incompatible factors. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to upcast. * ​factor ([`Int`](/mojo/std/builtin/int/Int)): The number of consecutive elements to fuse into one. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout with adjusted shape and stride for the coarser granularity.
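Example (a minimal sketch inferred from the signature; the `^` transfer sigil matches how other `var` arguments are passed in this module's examples):

```mojo
from layout import Layout
from layout.layout import upcast

# Fuse pairs of consecutive byte-granularity elements to view the
# layout at a 2-byte granularity (factor = 2).
var byte_layout = Layout.row_major(4, 8)
var fused = upcast(byte_layout^, 2)
```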
--- ## zip_modes
`zip_modes(layout_a: Layout, layout_b: Layout) -> Layout` Combines corresponding modes from two layouts. This function creates a new layout by combining corresponding dimensions from two layouts. If a dimension in layout\_b has a non-positive shape, the corresponding dimension from layout\_a is used directly. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout with combined dimensions from both input layouts.
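Example (a minimal sketch following the module's other examples; the shape of the combined result is not claimed here):

```mojo
from layout import Layout, IntTuple
from layout.layout import zip_modes

# Combine corresponding dimensions (modes) of two layouts.
var a = Layout.row_major(6, 8)
var b = Layout(IntTuple(2, 2))
var zipped = zip_modes(a, b)
```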
--- ## zipped_divide
`zipped_divide(layout_a: Layout, layout_b: Layout) -> Layout` Divides a layout into blocks according to another layout. This function creates a hierarchical layout by dividing the first layout according to the second layout. It's an alias for hierarchical\_unzip that provides a more intuitive name for the division operation. This is useful for creating blocked or tiled representations of tensors. Example: ```mojo from layout import Layout, IntTuple from layout.layout import zipped_divide # Create layouts var base = Layout.row_major(6, 8) var pattern = Layout(IntTuple(2, 2)) var result = zipped_divide(base, pattern) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout defining the division pattern. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division of layout\_a according to layout\_b. `zipped_divide(layout_a: Layout, tiler: List[Layout]) -> Layout` Divides a layout into blocks according to a list of layouts. This function creates a hierarchical layout by dividing the first layout according to the layouts in the tiler list. It's an alias for hierarchical\_unzip that provides a more intuitive name for the division operation when working with multiple tiling patterns. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import zipped_divide # Create layouts var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2))) var result = zipped_divide(base, tilers) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​tiler ([`List`](/mojo/std/collections/list/List)): A list of layouts defining the division patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division of layout\_a according to the patterns in tiler.
--- ## LayoutTensor
`@register_passable(trivial)` `struct LayoutTensor[mut: Bool, //, dtype: DType, layout: Layout, origin: Origin[mut=mut], /, *, address_space: AddressSpace = AddressSpace.GENERIC, element_layout: Layout = Layout(IntTuple(1), IntTuple(1)), layout_int_type: DType = _get_layout_type(layout, address_space), linear_idx_type: DType = _get_index_type(layout, address_space), masked: Bool = False, alignment: Int = align_of[dtype]()]` A high-performance tensor with explicit memory layout and hardware-optimized access patterns. `LayoutTensor` provides a powerful abstraction for multi-dimensional data with precise control over memory organization. It supports various memory layouts (row-major, column-major, tiled), hardware-specific optimizations, and efficient parallel access patterns. Example: ```mojo from layout import Layout, LayoutTensor # Create tensor on CPU using InlineArray to allocate storage space. var storage = InlineArray[Float32, 5 * 4](uninitialized=True) var tensor_5x4 = LayoutTensor[DType.float32, Layout.row_major(5, 4)](storage) ``` ## Parameters * ​mut ([`Bool`](/mojo/std/builtin/bool/Bool)): The inferred mutability of the underlying pointer. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the underlying pointer. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of the tensor. * ​origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): The origin of the underlying pointer. * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The address space of the underlying pointer. * ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of each element in the tensor. * ​layout\_int\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of each dimension of runtime layout. * ​linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of the index pointing to memory locations. 
* ​masked ([`Bool`](/mojo/std/builtin/bool/Bool)): If true the tensor is masked and runtime layouts determine the shape. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Alignment of the data pointer. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin]`): Pointer to the underlying memory buffer containing the tensor data. This pointer respects the specified address space, alignment, mutability, and origin tracking for memory safety and performance optimization. * ​runtime\_layout (`LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].RuntimeLayoutType`): Runtime representation of the tensor's memory layout. Handles both compile-time and runtime-determined dimensions, enabling efficient mapping between logical tensor coordinates and physical memory locations. * ​runtime\_element\_layout (`LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].RuntimeElementLayoutType`): Runtime representation of each element's internal layout. Used when elements themselves have structure, such as in blocked or tiled layouts. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable), [`_Expable`](/mojo/std/math/math/_Expable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AddressSpaceCastType` `comptime AddressSpaceCastType[address_space: AddressSpace = address_space] = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for address-space-cast result tensors. #### Parameters * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The target address space for the result tensor. ### `BitcastType` `comptime BitcastType[new_dtype: DType, /, address_space: AddressSpace = address_space, element_layout: Layout = element_layout] = LayoutTensor[new_dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for bitcast result tensors. #### Parameters * ​new\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The target data type to cast to. * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The address space for the result tensor. 
* ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The element layout for the result tensor. ### `CoalesceType` `comptime CoalesceType[element_layout: Layout] = LayoutTensor[dtype, coalesce(layout, False), origin, address_space=address_space, element_layout=element_layout]` Type alias for coalesced result tensors. #### Parameters * ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The element layout for the coalesced tensor. ### `CompositionType` `comptime CompositionType[rhs_layout: Layout, dst_layout: Layout = composition(layout, rhs_layout)] = LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for composed layout tensor types. #### Parameters * ​rhs\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to compose with. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The resulting composed layout. ### `CornerCoordsType` `comptime CornerCoordsType = IndexList[len[IntTuple](flatten(layout.shape)), element_type=layout_int_type]` Index list type for corner coordinates. ### `device_type` `comptime device_type = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` The device-side type representation. ### `DistributeType` `comptime DistributeType[threads_layout: Layout, axis: Optional[Int] = None] = LayoutTensor[dtype, _compute_distribute_layout[layout, threads_layout, axis]()[1], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _distribute_is_masked[layout, threads_layout, axis]() if is_nvidia_gpu() else False]` Type alias for distributed tensor types. 
#### Parameters * ​threads\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout describing thread distribution. * ​axis ([`Optional`](/mojo/std/collections/optional/Optional)): Optional axis to distribute along. ### `DynamicSplitType` `comptime DynamicSplitType[axis: Int = 0] = LayoutTensor[dtype, layout.make_shape_unknown[axis](), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for dynamic split result tensors. #### Parameters * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which to split. ### `element_size` `comptime element_size = element_layout.size()` The number of scalar values in each element of the tensor. ### `element_type` `comptime element_type = SIMD[dtype, LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_size]` The SIMD vector type used for vectorized operations on tensor elements. ### `FlattenedType` `comptime FlattenedType = LayoutTensor[dtype, Layout(IntTuple(-1)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for flattened tensor types. ### `GenericAddressSpaceLayoutTensor` `comptime GenericAddressSpaceLayoutTensor = LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` LayoutTensor variant using generic address space. ### `GenericLayoutTensorType` `comptime GenericLayoutTensorType = LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` LayoutTensor type with generic address space. 
### `idx_list_t` `comptime idx_list_t[rank: Int = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank] = IndexList[rank, element_type=linear_idx_type]` Type alias for index lists of the tensor's rank. #### Parameters * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the index list. ### `MutableAnyType` `comptime MutableAnyType = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Mutable LayoutTensor type with MutAnyOrigin. ### `num_strides` `comptime num_strides = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].RuntimeLayoutType.StrideType.scalar_length` Number of stride values in the layout. ### `OriginCastType` `comptime OriginCastType[mut: Bool, //, origin: Origin[mut=mut]] = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for origin-cast result tensors. #### Parameters * ​mut ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the result tensor is mutable. * ​origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): The origin for the result tensor. ### `rank` `comptime rank = layout.rank()` The number of dimensions in the tensor's layout. 
### `ReshapeType` `comptime ReshapeType[dst_layout: Layout] = LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for reshaped tensor types. #### Parameters * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped tensor. ### `RuntimeElementLayoutType` `comptime RuntimeElementLayoutType = RuntimeLayout[element_layout, element_type=DType.int32, linear_idx_type=linear_idx_type]` Type alias for the runtime element layout. ### `RuntimeLayoutType` `comptime RuntimeLayoutType = RuntimeLayout[layout, element_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for the runtime layout. ### `ShapeVectorizedType` `comptime ShapeVectorizedType[origin: ImmutOrigin, vector_shape: IntTuple, linear_vectorize: Bool] = LayoutTensor[dtype, coalesce(LayoutTensor._tuple_divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](vector_shape, linear_vectorize)[1], True), origin, address_space=address_space, element_layout=LayoutTensor._tuple_divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](vector_shape, linear_vectorize)[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for shape-vectorized tensor types. #### Parameters * ​origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): The origin of the result tensor. * ​vector\_shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of each vector unit. * ​linear\_vectorize ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to vectorize in a linear manner. 
### `SIMDTileType` `comptime SIMDTileType[tile_size: Int] = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size, simd_width_of[dtype]()]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_size, simd_width_of[dtype]()](), alignment=alignment]` Type alias for SIMD-sized tile tensors. #### Parameters * ​tile\_size ([`Int`](/mojo/std/builtin/int/Int)): The size of the tile along the tiled axis. ### `SIMDVectorizedType` `comptime SIMDVectorizedType = LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, 1, simd_width_of[dtype]()]()[1], True), origin, address_space=address_space, element_layout=LayoutTensor._divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, 1, simd_width_of[dtype]()]()[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Result type for SIMD-width vectorization. ### `SliceType` `comptime SliceType[d0_slice: Slice, d1_slice: Slice] = LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for 2D slice result tensors. #### Parameters * ​d0\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the first dimension. * ​d1\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the second dimension. 
### `SliceType1D` `comptime SliceType1D[d0_slice: Slice, slice_indices: IndexList[1], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 1)] = LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, slice_indices.__getitem__[1, DType.int64, Int](0)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for 1D slice result tensors from higher-rank tensors. #### Parameters * ​d0\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the selected dimension. * ​slice\_indices ([`IndexList`](/mojo/std/utils/index_/IndexList)): Index of the dimension to slice. * ​\_\_offset\_dims ([`Int`](/mojo/std/builtin/int/Int)): Number of fixed dimensions. ### `SliceType2D` `comptime SliceType2D[d0_slice: Slice, d1_slice: Slice, slice_indices: IndexList[2], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)] = LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice, slice_indices.__getitem__[2, DType.int64, Int](0), slice_indices.__getitem__[2, DType.int64, Int](1)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for 2D slice result tensors from higher-rank tensors. 
#### Parameters * ​d0\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the first selected dimension. * ​d1\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the second selected dimension. * ​slice\_indices ([`IndexList`](/mojo/std/utils/index_/IndexList)): Indices of the two dimensions to slice. * ​\_\_offset\_dims ([`Int`](/mojo/std/builtin/int/Int)): Number of fixed dimensions. ### `SplitElementType` `comptime SplitElementType[count: Int, axis: Int = 0] = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size=(layout.shape[axis].value() // count), axis=axis]()[0], MutAnyOrigin, address_space=address_space, element_layout=element_layout, alignment=alignment]` Type alias for split element tensors. #### Parameters * ​count ([`Int`](/mojo/std/builtin/int/Int)): Number of portions to split into. * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which to split. ### `StackTensorType` `comptime StackTensorType = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` LayoutTensor type for stack-allocated tensors. ### `StaticSplitType` `comptime StaticSplitType[count: Int, axis: Int = 0] = StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size=(layout.shape[axis].value() // count), axis=axis]()[0], MutAnyOrigin, address_space=address_space, element_layout=element_layout, alignment=alignment], count]` Type alias for static split result tuples. #### Parameters * ​count ([`Int`](/mojo/std/builtin/int/Int)): Number of portions to split into. 
* ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which to split. ### `storage_size` `comptime storage_size = (size_of[dtype]() * layout.size())` Total storage size in bytes for the tensor data. ### `TiledIteratorType` `comptime TiledIteratorType[*tile_sizes: Int, *, axis: Int = 0] = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes]()]` Type alias for tiled iterator types. #### Parameters * ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of each tile along each axis. * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which to iterate. ### `TileType` `comptime TileType[*tile_sizes: Int] = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes](), alignment=alignment]` The tile type returned by the `tile()` method given the specified set of tile sizes. #### Parameters * ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. ### `TransposeType` `comptime TransposeType = LayoutTensor[dtype, layout.transpose(), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Result type for transpose operations. 
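To make the tile-related aliases above concrete, here is a minimal sketch of where `TileType` and `TransposeType` show up in practice. The 8×8 shape, buffer size, and tile coordinates are illustrative assumptions, not fixed by this API.

```mojo
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
comptime layout = Layout.row_major(8, 8)

var ctx = DeviceContext()
var host_buf = ctx.enqueue_create_host_buffer[dtype](64)
ctx.synchronize()

var tensor = LayoutTensor[dtype, layout](host_buf)

# tile[4, 4](i, j) returns the (i, j)-th 4x4 tile of the tensor as a
# view; its type is the TileType[4, 4] alias described above.
var t01 = tensor.tile[4, 4](0, 1)

# transpose() returns a view with the transposed layout; its type is
# the TransposeType alias described above.
var tt = tensor.transpose()
```

Because these aliases are views over the same underlying data, tiling and transposing do not copy the tensor's elements.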
### `VectorizedType` `comptime VectorizedType[*vector_shape: Int] = LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[1], True), origin, address_space=address_space, element_layout=LayoutTensor._divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for vectorized tensor types. #### Parameters * ​\*vector\_shape ([`Int`](/mojo/std/builtin/int/Int)): The shape of each vector unit along each axis. ## Methods ### `__init__` `__init__(span: Span[Scalar[dtype], origin]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericAddressSpaceLayoutTensor` Create a `LayoutTensor` with a `Span`. **Constraints:** Layout must be fully static. **Args:** * ​span ([`Span`](/mojo/std/memory/span/Span)): The `Span` pointing to the underlying data. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(span: Span[Scalar[dtype], origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericAddressSpaceLayoutTensor` Create a `LayoutTensor` with a `Span` and a runtime layout for the tensor. The runtime layout element type will be cast to the layout tensor layout integer type. **Constraints:** * Element layout must be fully static. **Args:** * ​span ([`Span`](/mojo/std/memory/span/Span)): The `Span` pointing to the underlying data. 
* ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(span: Span[Scalar[dtype], origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericAddressSpaceLayoutTensor` Create a `LayoutTensor` with a `Span`, a runtime layout of the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor layout integer type. **Constraints:** * Runtime layout and `LayoutTensor` must have the same bitwidth and index type. **Args:** * ​span ([`Span`](/mojo/std/memory/span/Span)): The `Span` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin]) -> Self` Create a `LayoutTensor` with an `UnsafePointer`. **Constraints:** Layout must be fully static. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): The `UnsafePointer` pointing to the underlying data. 
`__init__(unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> Self` Create a `LayoutTensor` with an `UnsafePointer` and a runtime layout for the tensor. The runtime layout element type will be cast to the layout tensor layout integer type. **Constraints:** Element layout must be fully static. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): The `UnsafePointer` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. `__init__(unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> Self` Create a `LayoutTensor` with an `UnsafePointer`, a runtime layout for the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor layout integer type. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): The `UnsafePointer` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. 
`__init__(ref[origin] device_buffer: DeviceBuffer[dtype]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `DeviceBuffer`. The layout must have statically known dimensions. Note that the device buffer memory is on the accelerator device (GPU global memory). Code running on the CPU can use the [`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext) to allocate a `DeviceBuffer` and use that to construct a `LayoutTensor` that can be accessed on the GPU. You cannot directly access data in the `DeviceBuffer` or `LayoutTensor` from the CPU. The following example shows a typical pattern for using `DeviceBuffer` to construct a `LayoutTensor` that you can use on the GPU.

```mojo
from gpu.host import DeviceContext, DeviceBuffer
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
var ctx = DeviceContext()

# Allocate buffers
var dev_buf = ctx.enqueue_create_buffer[dtype](16)
var host_buf = ctx.enqueue_create_host_buffer[dtype](16)

# Ensure buffers have been created
ctx.synchronize()

# Initialize host buffer and copy to device buffer
for i in range(16):
    host_buf[i] = i
ctx.enqueue_copy(dev_buf, host_buf)

# Create LayoutTensor to use on device
comptime layout = Layout.row_major(4, 4)
var tensor = LayoutTensor[dtype, layout](dev_buf)
...
```

**Constraints:** * Layout must be fully static. **Args:** * ​device\_buffer ([`DeviceBuffer`](/mojo/std/gpu/host/device_context/DeviceBuffer)): Contains the underlying data to point to. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref[origin] host_buffer: HostBuffer[dtype]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `HostBuffer`. The layout must have statically known dimensions. The resulting tensor's data can only be accessed on the CPU.

```mojo
from gpu.host import DeviceContext, HostBuffer
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
var ctx = DeviceContext()
var host_buf = ctx.enqueue_create_host_buffer[dtype](16)

comptime layout = Layout.row_major(4, 4)
var tensor = LayoutTensor[dtype, layout](host_buf)
```

**Constraints:** * Layout must be fully static. **Args:** * ​host\_buffer ([`HostBuffer`](/mojo/std/gpu/host/device_context/HostBuffer)): Contains the underlying data to point to. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref[origin] device_buffer: DeviceBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `DeviceBuffer` and a runtime layout. The runtime layout element type will be cast to the layout tensor layout integer type. The resulting tensor's data can only be accessed on the GPU. **Constraints:** * Element layout must be fully static. **Args:** * ​device\_buffer ([`DeviceBuffer`](/mojo/std/gpu/host/device_context/DeviceBuffer)): The `DeviceBuffer` containing the underlying data. 
* ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref[origin] host_buffer: HostBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `HostBuffer` and a runtime layout. The runtime layout element type will be cast to the layout tensor layout integer type. The resulting tensor's data can only be accessed on the CPU. **Constraints:** * Element layout must be fully static. **Args:** * ​host\_buffer ([`HostBuffer`](/mojo/std/gpu/host/device_context/HostBuffer)): The `HostBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref[origin] device_buffer: DeviceBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `DeviceBuffer`, a runtime layout for the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor layout integer type. The resulting tensor's data can only be accessed on the GPU. 
**Args:** * ​device\_buffer ([`DeviceBuffer`](/mojo/std/gpu/host/device_context/DeviceBuffer)): The `DeviceBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref[origin] host_buffer: HostBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `HostBuffer`, a runtime layout for the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor layout integer type. The resulting tensor's data can only be accessed on the CPU. **Args:** * ​host\_buffer ([`HostBuffer`](/mojo/std/gpu/host/device_context/HostBuffer)): The `HostBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `__getitem__` `__getitem__[*Tys: Indexer](self, *args: *Tys) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_type` Retrieves a single element from the tensor at the specified indices. This method provides array-like indexing for the tensor. The number of indices provided must match the rank of the tensor, otherwise an error will occur at runtime. **Parameters:** * ​\*Tys ([`Indexer`](/mojo/std/builtin/int/Indexer)): The type of the indices. Must implement the `Indexer` trait, and match the rank of the tensor. **Args:** * ​\*args (`*Tys`): The indices specifying the element's position in the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The element at the specified position with the tensor's data type. `__getitem__(self, crd: RuntimeTuple[S, element_type=element_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_type` Retrieves a single element from the tensor at the specified indices. This method provides array-like indexing for the tensor. The number of indices provided must match the rank of the tensor, otherwise an error will occur at runtime. **Args:** * ​crd ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The coordinate specifying the element's position in each dimension. For example, in a 3D tensor, you would use (i, j, k). **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The element at the specified position with the tensor's data type. 
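A short sketch of array-style element access on a host-backed tensor (buffer setup follows the constructor examples above; the indices and value here are illustrative assumptions):

```mojo
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
comptime layout = Layout.row_major(4, 4)

var ctx = DeviceContext()
var host_buf = ctx.enqueue_create_host_buffer[dtype](16)
ctx.synchronize()

var tensor = LayoutTensor[dtype, layout](host_buf)

# The number of indices must match the tensor's rank (2 here).
tensor[1, 2] = 42.0      # element assignment
var v = tensor[1, 2]     # reads the element back
```

Passing fewer or more indices than the tensor's rank is an error at runtime, as noted above.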
### `__setitem__` `__setitem__[*Tys: Indexer](self, *args: *Tys, *, val: SIMD[dtype, LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_size])` Sets a single element in a tensor at the specified indices. This method provides array-like element assignment for tensors. Notes: * Bounds checking is NOT currently supported for `__setitem__` due to complications with certain layout types and mutation contexts. Use `__getitem__`, `load`, or `store` methods for bounds-checked access. In the future, this restriction will be lifted. **Parameters:** * ​\*Tys ([`Indexer`](/mojo/std/builtin/int/Indexer)): The type of the indices. Must implement the `Indexer` trait, and match the rank of the tensor. **Args:** * ​\*args (`*Tys`): The indices specifying the element's position in the tensor. * ​val ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The value to write to the tensor at the specified position. ### `__add__` `__add__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Add a scalar value to each element of the tensor. Performs an elementwise addition operation, adding the scalar value to each element in the tensor. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the addition. * For in-place addition, use the `__iadd__` method instead (`+=` operator). **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to add to each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the addition operation. 
`__add__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Add another tensor to this tensor elementwise. Performs an elementwise addition between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation creates a copy of the tensor before performing the addition. * For in-place addition, use the `__iadd__` method instead (`+=` operator). **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to add to this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the addition operation. ### `__sub__` `__sub__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Subtract a scalar value from each element of the tensor. Performs an elementwise subtraction operation, subtracting the scalar value from each element in the tensor. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the subtraction. 
* For in-place subtraction, use the `__isub__` method instead (`-=` operator). **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to subtract from each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the subtraction operation. `__sub__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Subtract another tensor from this tensor elementwise. Performs an elementwise subtraction between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation creates a copy of the tensor before performing the subtraction. * For in-place subtraction, use the `__isub__` method instead (`-=` operator). **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to subtract from this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the subtraction operation. 
### `__mul__` `__mul__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Multiply each element of the tensor by a scalar value. Performs an elementwise multiplication operation, multiplying each element in the tensor by the scalar value. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the multiplication. * For in-place multiplication, use the `__imul__` method instead (`*=` operator). **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to multiply with each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the multiplication operation. `__mul__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Multiply this tensor with another tensor elementwise. Performs an elementwise multiplication (Hadamard product) between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Note: This is NOT a matrix multiplication operation. For matrix multiplication, use the appropriate matmul function instead. 
Performance: * This operation creates a copy of the tensor before performing the multiplication. * For in-place multiplication, use the `__imul__` method instead (`*=` operator). **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to multiply with this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the elementwise multiplication. ### `__truediv__` `__truediv__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Divide each element of the tensor by a scalar value. Performs an elementwise division operation, dividing each element in the tensor by the scalar value. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the division. * For in-place division, use the `__itruediv__` method instead (`/=` operator). Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to divide each element by. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the division operation. 
`__truediv__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Divide this tensor by another tensor elementwise. Performs an elementwise division between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation creates a copy of the tensor before performing the division. * For in-place division, use the `__itruediv__` method instead (`/=` operator). Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to divide this tensor by. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the division operation. ### `__iadd__` `__iadd__(self, other: Scalar[dtype])` Add a scalar value to each element of the tensor in-place. Performs an elementwise addition operation, adding the scalar value to each element in the tensor. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. 
**Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to add to each element. `__iadd__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Add another tensor to this tensor elementwise in-place. Performs an elementwise addition between this tensor and another tensor. This operation modifies the tensor in-place. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation modifies the tensor directly without creating a copy. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to add to this tensor. ### `__isub__` `__isub__(self, other: Scalar[dtype])` Subtract a scalar value from each element of the tensor in-place. Performs an elementwise subtraction operation, subtracting the scalar value from each element in the tensor. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to subtract from each element. `__isub__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Subtract another tensor from this tensor elementwise in-place. Performs an elementwise subtraction between this tensor and another tensor. This operation modifies the tensor in-place. 
Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation modifies the tensor directly without creating a copy. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to subtract from this tensor. ### `__imul__` `__imul__(self, other: Scalar[dtype])` Multiply each element of the tensor by a scalar value in-place. Performs an elementwise multiplication operation, multiplying each element in the tensor by the scalar value. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to multiply with each element. `__imul__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Multiply this tensor with another tensor elementwise in-place. Performs an elementwise multiplication (Hadamard product) between this tensor and another tensor. This operation modifies the tensor in-place. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Note: This is NOT a matrix multiplication operation. For matrix multiplication, use the appropriate matmul function instead. Performance: * This operation modifies the tensor directly without creating a copy. 
**Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to multiply with this tensor. ### `__itruediv__` `__itruediv__(self, other: Scalar[dtype])` Divide each element of the tensor by a scalar value in-place. Performs an elementwise division operation, dividing each element in the tensor by the scalar value. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Args:** * ​other ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to divide each element by. `__itruediv__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Divide this tensor by another tensor elementwise in-place. Performs an elementwise division between this tensor and another tensor. This operation modifies the tensor in-place. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation modifies the tensor directly without creating a copy. Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to divide this tensor by. 
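The limited broadcasting rule above can be modeled with a short Python sketch (illustration only, not the Mojo implementation). We assume here that a rank-1 operand broadcasts across the rows of a rank-2 tensor, with its length matching the column count:

```python
# Python sketch of elementwise in-place division with the limited
# broadcasting rule described above. Assumption: the rank-1 operand is
# matched against the last (column) dimension and repeated across rows.
def itruediv(a, b):
    rank_a = 2 if isinstance(a[0], list) else 1
    rank_b = 2 if isinstance(b[0], list) else 1
    if rank_a == 1 and rank_b == 1:
        # Same rank: shapes must match exactly.
        assert len(a) == len(b), "shapes must match exactly"
        for i in range(len(a)):
            a[i] /= b[i]
    elif rank_a == 2 and rank_b == 1:
        # Rank-1 -> rank-2 broadcast: lengths must line up.
        assert all(len(row) == len(b) for row in a), "dimension mismatch"
        for row in a:
            for j in range(len(b)):
                row[j] /= b[j]
    elif rank_a == 2 and rank_b == 2:
        assert len(a) == len(b) and len(a[0]) == len(b[0]), "shape mismatch"
        for ra, rb in zip(a, b):
            for j in range(len(ra)):
                ra[j] /= rb[j]
    else:
        raise ValueError("unsupported broadcast")
```

Like `__itruediv__`, the sketch mutates its first argument in place rather than allocating a result.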
### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). **Returns:** `String`: The host type's name. ### `__merge_with__` `__merge_with__[other_type: AnyStruct[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]]](self) -> LayoutTensor[dtype, layout, origin_of((mutcast origin._mlir_origin), (mutcast origin._mlir_origin)), address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Returns a tensor merged with the specified `other_type`. **Parameters:** * ​other\_type (`AnyStruct`): The type of the tensor to merge with. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor merged with the specified `other_type`. ### `bitcast` `bitcast[new_dtype: DType, /, target_address_space: AddressSpace = address_space, _element_layout: Layout = element_layout](self) -> LayoutTensor[new_dtype, layout, origin, address_space=target_address_space, element_layout=_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Bitcast the underlying pointer to a new data type. **Parameters:** * ​new\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The new data type to cast to. * ​target\_address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The address space of the returned `LayoutTensor`. * ​\_element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The element layout of the returned `LayoutTensor`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new `LayoutTensor` with the same memory location but with the specified data type, address space, and element layout. 
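A bitcast reinterprets the same bytes under a new data type; no values are converted or copied. The effect can be sketched in Python with the standard `struct` module (illustration only, not the Mojo implementation):

```python
import struct

def bitcast_f32_to_u32(x):
    """Reinterpret a float32's bit pattern as a uint32 (no conversion)."""
    # Pack the float into its 4 raw little-endian bytes, then unpack
    # those same bytes as an unsigned 32-bit integer.
    return struct.unpack("<I", struct.pack("<f", x))[0]

# Under IEEE 754, float32 1.0 has the bit pattern 0x3F800000.
```

Note the contrast with a value cast: a value cast of `1.0` to an integer dtype yields `1`, while a bitcast yields the raw bit pattern.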
### `as_any_origin` `as_any_origin(self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Casts the origin of the mutable `LayoutTensor` to `MutAnyOrigin`. This requires the tensor to already be mutable, as casting mutability is inherently unsafe. It is usually preferred to maintain concrete origin values instead of using `MutAnyOrigin`. However, if it is needed, keep in mind that `MutAnyOrigin` can alias any memory value, so Mojo's ASAP destruction will not apply during the lifetime of the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor with the origin set to `MutAnyOrigin`. `as_any_origin(self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, ImmutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Casts the origin of the immutable `LayoutTensor` to `ImmutAnyOrigin`. It is usually preferred to maintain concrete origin values instead of using `ImmutAnyOrigin`. However, if it is needed, keep in mind that `ImmutAnyOrigin` can alias any memory value, so Mojo's ASAP destruction will not apply during the lifetime of the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor with the origin set to `ImmutAnyOrigin`. 
### `address_space_cast` `address_space_cast[target_address_space: AddressSpace = address_space](self) -> LayoutTensor[dtype, layout, origin, address_space=target_address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Changes the address space of the `LayoutTensor`. **Parameters:** * ​target\_address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The new address space. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new `LayoutTensor` object with the same type and origin as the original `LayoutTensor`, and the specified new address space. ### `get_immutable` `get_immutable(self) -> LayoutTensor[dtype, layout, origin_of((muttoimm origin._mlir_origin)), address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Return an immutable version of this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A `LayoutTensor` covering the same elements, but without mutability. ### `ptr_at_offset` `ptr_at_offset(self, coords: IndexList[size, element_type=element_type]) -> LegacyUnsafePointer[Scalar[dtype], address_space=address_space]` Get a pointer offset at the given flattened coordinates. **Args:** * ​coords ([`IndexList`](/mojo/std/utils/index_/IndexList)): A flattened list of the offset coordinates. **Returns:** [`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer): A pointer offset at the given flattened coordinates. ### `__exp__` `__exp__(self) -> Self` Computes the element-wise exponential function. Returns a new tensor containing the [element-wise exponential](/mojo/std/math/math/exp/) of the input tensor. **Returns:** `Self`: A new tensor containing the element-wise exponential. 
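The element-wise exponential applies `e**x` independently to every element and produces a new container, leaving the input unchanged. A minimal Python sketch of the semantics (illustration only, not the Mojo implementation):

```python
import math

def elementwise_exp(t):
    """Return a new rank-2 container with exp applied per element."""
    return [[math.exp(x) for x in row] for row in t]
```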
### `load_scalar` `load_scalar[*Tys: Indexer](self, *args: *Tys) -> Scalar[dtype]` Retrieves a single scalar from the tensor at the specified indices. This method provides scalar element access for the tensor, which is useful in generic contexts where `__getitem__` returns a SIMD vector of `element_size` elements. This method always returns a single scalar value (the 0th lane of the element). The number of indices provided must match the rank of the tensor, otherwise an error will occur at runtime. **Parameters:** * ​\*Tys ([`Indexer`](/mojo/std/builtin/int/Indexer)): The type of the indices. Must implement the `Indexer` trait, and match the rank of the tensor. **Args:** * ​\*args (`*Tys`): The indices specifying the element's position in the tensor. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The scalar value at the specified position with the tensor's dtype. `load_scalar(self, crd: RuntimeTuple[S, element_type=element_type]) -> Scalar[dtype]` Retrieves a single scalar from the tensor at the specified coordinates. This method provides scalar element access for the tensor, which is useful in generic contexts where `__getitem__` returns a SIMD vector of `element_size` elements. This method always returns a single scalar value (the 0th lane of the element). **Args:** * ​crd ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The coordinate specifying the element's position in each dimension. For example, in a 3D tensor, you would use (i, j, k). **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The scalar value at the specified position with the tensor's dtype. ### `load` `load[width: Int, load_alignment: Int = alignment](self, m: Int, n: Int) -> SIMD[dtype, width]` Load a SIMD vector from the tensor at the specified 2D coordinates. Performs a vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at position (m, n). This method enables efficient SIMD operations on tensor data. 
Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use `aligned_load` instead when data alignment is guaranteed. * The load operation is optimized based on the tensor's memory layout. Notes: * Bounds checking is performed via debug\_assert for the base coordinate and the full SIMD width range. Enable assertions with `-D ASSERT=all` to catch out-of-bounds accesses during development. * The elements are loaded according to the tensor's stride configuration. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. * ​load\_alignment ([`Int`](/mojo/std/builtin/int/Int)): The alignment to use. Defaults to Self.alignment. **Args:** * ​m ([`Int`](/mojo/std/builtin/int/Int)): The row index (first dimension). * ​n ([`Int`](/mojo/std/builtin/int/Int)): The column index (second dimension). **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. `load[width: Int, load_alignment: Int = alignment](self, coords: IndexList[size, element_type=element_type]) -> SIMD[dtype, width]` Load a SIMD vector from the tensor at the specified coordinates. Performs a vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at the position specified by `coords`. This method enables efficient SIMD operations on tensor data and works with tensors of any rank. Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use `aligned_load` instead when data alignment is guaranteed. * The load operation is optimized based on the tensor's memory layout. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * The elements are loaded according to the tensor's stride configuration. 
**Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. * ​load\_alignment ([`Int`](/mojo/std/builtin/int/Int)): The alignment to use. Defaults to Self.alignment. **Args:** * ​coords ([`IndexList`](/mojo/std/utils/index_/IndexList)): The coordinates to index. Must have the same size as the tensor's rank. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. ### `prefetch` `prefetch(self, m: Int, n: Int)` Prefetch tensor data at the specified 2D coordinates into cache. Issues a software prefetch hint to the processor to load the data at position (m, n) into the cache hierarchy. This can improve performance by reducing memory latency for subsequent accesses to the same location. Performance: * Prefetching is a performance hint and does not guarantee data will be cached. * Most effective when issued sufficiently ahead of the actual data access. * Uses high locality prefetch to the data cache, optimized for data that will be accessed multiple times. * Can reduce memory access latency by 50-90% when used correctly. Notes: * Excessive prefetching can pollute the cache and degrade performance. * Most beneficial for predictable access patterns that would otherwise cause cache misses. * No operation is performed on the prefetched data. **Args:** * ​m ([`Int`](/mojo/std/builtin/int/Int)): The row index (first dimension). * ​n ([`Int`](/mojo/std/builtin/int/Int)): The column index (second dimension). `prefetch(self, coords: IndexList[size, element_type=element_type])` Prefetch tensor data at the specified coordinates into cache. Issues a software prefetch hint to the processor to load the data at coords into the cache hierarchy. This can improve performance by reducing memory latency for subsequent accesses to the same location. 
Performance: * Prefetching is a performance hint and does not guarantee data will be cached. * Most effective when issued sufficiently ahead of the actual data access. * Uses high locality prefetch to the data cache, optimized for data that will be accessed multiple times. * Can reduce memory access latency by 50-90% when used correctly. Notes: * Excessive prefetching can pollute the cache and degrade performance. * Most beneficial for predictable access patterns that would otherwise cause cache misses. * No operation is performed on the prefetched data. **Args:** * ​coords ([`IndexList`](/mojo/std/utils/index_/IndexList)): The indices. ### `aligned_load` `aligned_load[width: Int](self, m: Int, n: Int) -> SIMD[dtype, width]` Load a SIMD vector with alignment guarantees from the tensor. Performs an aligned vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at position (m, n). The alignment is automatically calculated based on the SIMD width and dtype. Performance: * Uses aligned memory access which is faster than unaligned access on most architectures. * The alignment is automatically calculated based on the SIMD width and dtype. * Can be up to 2x faster than unaligned loads on architectures that require alignment. Notes: * The caller must ensure that the memory at (m, n) is properly aligned. Misaligned access with this method may cause hardware exceptions on some architectures. * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. **Args:** * ​m ([`Int`](/mojo/std/builtin/int/Int)): The row index (first dimension). * ​n ([`Int`](/mojo/std/builtin/int/Int)): The column index (second dimension). 
**Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. `aligned_load[width: Int](self, coords: IndexList[size, element_type=element_type]) -> SIMD[dtype, width]` Load a SIMD vector with alignment guarantees from the tensor. Performs an aligned vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at the position specified by `coords`. The alignment is automatically calculated based on the SIMD width and dtype. This method enables efficient SIMD operations on tensor data and works with tensors of any rank. Performance: * Uses aligned memory access which is faster than unaligned access on most architectures. * The alignment is automatically calculated based on the SIMD width and dtype. * Can be up to 2x faster than unaligned loads on architectures that require alignment. Notes: * The caller must ensure that the memory at the specified coordinates is properly aligned. Misaligned access with this method may cause hardware exceptions on some architectures. * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * The elements are loaded according to the tensor's stride configuration. * The last dimension must have unit stride (stride == 1) for this operation to be valid. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. **Args:** * ​coords ([`IndexList`](/mojo/std/utils/index_/IndexList)): The coordinates to index. Must have the same size as the tensor's rank. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. 
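The load methods above map coordinates to a flat memory offset via the tensor's strides, then read `width` consecutive elements from that offset. The address arithmetic can be sketched in Python (illustration only, not the Mojo implementation), assuming the last dimension has unit stride as the aligned variant requires:

```python
# Python sketch of a strided SIMD-style load: coordinates are dotted
# with the strides to get a flat base offset, then `width` consecutive
# elements are read from that offset.
def load(data, strides, coords, width):
    base = sum(c * s for c, s in zip(coords, strides))
    return data[base:base + width]

# A 2x4 row-major tensor stored flat; its strides are (4, 1).
flat = [0, 1, 2, 3, 10, 11, 12, 13]
```

For example, loading 4 elements at (m, n) = (1, 0) reads from flat offset 1*4 + 0*1 = 4, i.e. the second row.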
### `store` `store[width: Int, store_alignment: Int = alignment](self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], m: Int, n: Int, val: SIMD[dtype, width])` Store a SIMD vector to the tensor at the specified 2D coordinates. Performs a vectorized store operation to the tensor's memory, writing 'width' consecutive elements starting at position (m, n). This method enables efficient SIMD operations on tensor data. Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use aligned\_store instead when data alignment is guaranteed. * The store operation is optimized based on the tensor's memory layout. Notes: * Bounds checking is performed via debug\_assert for the base coordinate and the full SIMD width range. Enable assertions with `-D ASSERT=all` to catch out-of-bounds accesses during development. * The elements are stored according to the tensor's stride configuration. * This operation modifies the tensor's data in-place. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements in the SIMD vector to store. Should match the target hardware's vector width for optimal performance. * ​store\_alignment ([`Int`](/mojo/std/builtin/int/Int)): The alignment to use. Defaults to Self.alignment. **Args:** * ​m ([`Int`](/mojo/std/builtin/int/Int)): The row index (first dimension) where the store operation begins. * ​n ([`Int`](/mojo/std/builtin/int/Int)): The column index (second dimension) where the store operation begins. * ​val ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The SIMD vector containing the values to store in the tensor. `store[width: Int, store_alignment: Int = alignment](self, coords: IndexList[size, element_type=element_type], val: SIMD[dtype, width])` Store a SIMD vector to the tensor at the specified ND coordinates. 
Performs a vectorized store operation to the tensor's memory, writing `width` consecutive elements starting at the position specified by `coords`. This method enables efficient SIMD operations on tensor data. Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use aligned\_store instead when data alignment is guaranteed. * The store operation is optimized based on the tensor's memory layout. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * The elements are stored according to the tensor's stride configuration. * This operation modifies the tensor's data in-place. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements in the SIMD vector to store. Should match the target hardware's vector width for optimal performance. * ​store\_alignment ([`Int`](/mojo/std/builtin/int/Int)): The alignment to use. Defaults to Self.alignment. **Args:** * ​coords ([`IndexList`](/mojo/std/utils/index_/IndexList)): The coordinates to index. * ​val ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The SIMD vector containing the values to store in the tensor. ### `aligned_store` `aligned_store[width: Int](self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], m: Int, n: Int, val: SIMD[dtype, width])` Store a SIMD vector with alignment guarantees to the tensor. Performs an aligned vectorized store operation to the tensor's memory, writing `width` consecutive elements starting at position (m, n). The alignment is automatically calculated based on the SIMD width and dtype. Performance: * Uses aligned memory access which is faster than unaligned access on most architectures. * The alignment is automatically calculated based on the SIMD width and dtype. 
* Can be up to 2x faster than unaligned stores on architectures that require alignment. * Particularly important for streaming stores that bypass the cache. Notes: * The caller must ensure that the memory at (m, n) is properly aligned. Misaligned access with this method may cause hardware exceptions on some architectures. * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * This operation modifies the tensor's data in-place. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): The number of elements in the SIMD vector to store. Should match the target hardware's vector width for optimal performance. **Args:** * ​m ([`Int`](/mojo/std/builtin/int/Int)): The row index (first dimension) where the store operation begins. * ​n ([`Int`](/mojo/std/builtin/int/Int)): The column index (second dimension) where the store operation begins. * ​val ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The SIMD vector containing the values to store in the tensor. ### `size` `size(self) -> Int` Get the total number of elements that the tensor can contain. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The total number of elements that can be stored in the tensor. ### `stack_allocation` `static stack_allocation[*, stack_alignment: Int = alignment]() -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].StackTensorType` Allocates stack memory for a `LayoutTensor` with a fully static layout. Creates a new `LayoutTensor` instance with memory allocated on the stack rather than the heap. This provides deterministic memory management and potentially better performance for tensors with known sizes at compile time. Performance: * Stack allocation is typically faster than heap allocation. * Proper alignment can significantly improve memory access performance, especially for vectorized operations. 
* No dynamic memory management overhead (no malloc/free calls). Notes: * Only works with tensors that have fully static layouts known at compile time. * Stack memory is limited, so this should only be used for reasonably sized tensors. * The allocated memory is automatically freed when the function returns. **Constraints:** * The layout must be fully static (all dimensions known at compile time). * The alignment must be a multiple of the tensor's minimum required alignment. **Parameters:** * ​stack\_alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment value for the allocation in bytes. Must be a multiple of the tensor's minimum required alignment. Default is the tensor's natural alignment based on its data type and layout. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new `LayoutTensor` instance with memory allocated on the stack. ### `null` `static null() -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].StackTensorType` Returns a null `LayoutTensor` object. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A null `LayoutTensor` object. ### `to_device_buffer` `to_device_buffer(self, ctx: DeviceContext) -> DeviceBuffer[dtype]` Convert the tensor to a `DeviceBuffer`. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The device context to use. **Returns:** [`DeviceBuffer`](/mojo/std/gpu/host/device_context/DeviceBuffer): A `DeviceBuffer` containing the tensor's data. ### `is_static_shape` `static is_static_shape[idx: Int]() -> Bool` Returns whether the specified dimension is statically known. Performance: * This is a compile-time operation with no runtime cost when used with static dimensions. 
Notes: * This is a static method that operates on the tensor's type information, not on a specific tensor instance. **Parameters:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape \[10, UNKNOWN\_VALUE, 30]: * `is_static_shape[0]()` returns True (first dimension). * `is_static_shape[1]()` returns False (second dimension). * `is_static_shape[2]()` returns True (third dimension). **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the dimension is statically known, False otherwise. ### `shape` `static shape[idx: Int]() -> Int` Returns the size of the tensor along the specified dimension. Provides static access to the tensor's shape information. This method returns the size of a specific dimension without requiring an instance of the tensor, as the shape is part of the tensor's static type information. Performance: * This is a compile-time operation with no runtime cost when used with static dimensions. Notes: * This is a static method that operates on the tensor's type information, not on a specific tensor instance. **Parameters:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape \[10, 20, 30]: * `shape[0]()` returns 10 (first dimension). * `shape[1]()` returns 20 (second dimension). * `shape[2]()` returns 30 (third dimension). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the tensor along the specified dimension as an integer. ### `get_shape` `get_shape(self) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Get the flattened shape of a LayoutTensor. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The flattened shape of a LayoutTensor. 
### `get_stride` `get_stride(self) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Get the flattened stride of a LayoutTensor. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The flattened stride of a LayoutTensor. ### `stride` `static stride[idx: Int]() -> Int` Returns the memory stride of the tensor along the specified dimension. Provides static access to the tensor's stride information. The stride represents the number of elements to skip in memory to move one position along a particular dimension. This method returns the stride without requiring an instance of the tensor, as the stride is part of the tensor's static type information. Performance: * This is a compile-time operation with no runtime cost when used with static dimensions. * Understanding stride patterns is crucial for optimizing memory access patterns in performance-critical code. Notes: * Strides depend on the memory layout (row-major, column-major, or custom). * For non-contiguous tensors (e.g., tensor slices), strides may not follow a simple pattern. **Parameters:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to query (0-based). For example, in a 2D tensor with shape \[10, 20] and row-major layout: * `stride[0]()` might return 20 (moving one row requires skipping 20 elements). * `stride[1]()` might return 1 (moving one column requires skipping 1 element). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The memory stride of the tensor along the specified dimension as an integer. `stride(self, idx: Int) -> Int` Returns the runtime stride of the tensor along the specified axis. Unlike the static `stride` method, this instance method takes a runtime dimension index. **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to query (0-based). 
For example, in a row-major 3D tensor with shape `[10, 20, 30]`: * `stride(0)` returns 600 (first dimension). * `stride(1)` returns 30 (second dimension). * `stride(2)` returns 1 (third dimension). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The stride of the tensor along the specified axis as an integer. ### `dim` `dim(self, idx: Int) -> Int` Returns the runtime dimension size of the tensor along the specified axis. Unlike the static `dim` method, this instance method takes a runtime dimension index. **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape `[10, 20, 30]`: * `dim(0)` returns 10 (first dimension). * `dim(1)` returns 20 (second dimension). * `dim(2)` returns 30 (third dimension). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the tensor along the specified axis as an integer. `dim[idx: Int](self) -> Int` Returns the dimension size of the tensor along the specified axis. Unlike the static `shape` method, this instance method provides access to the tensor's actual dimension sizes. If the dimension is unknown, the runtime layout is used to get the dimension size. Performance: * For static dimensions known at compile time, prefer the static `shape` method when possible for better performance. Notes: * This method works with both static and dynamic dimensions. * For tensors with masked or partial views, this returns the actual size of the view, not the original tensor. **Constraints:** * Only works with tensors that have depth-1 layouts (no nested shapes). **Parameters:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape `[10, 20, 30]`: * `dim[0]()` returns 10 (first dimension). * `dim[1]()` returns 20 (second dimension). * `dim[2]()` returns 30 (third dimension). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the tensor along the specified dimension as an integer. 
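The relationship between a row-major shape and its strides, as used in the examples above, can be sketched in Python (illustration only, not the Mojo implementation): each stride is the product of all later dimension sizes.

```python
# Python sketch of row-major stride computation: stride[i] is the
# number of elements skipped when index i advances by one, i.e. the
# product of all dimension sizes after i.
def row_major_strides(shape):
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides
```

For the document's shape `[10, 20, 30]`, this yields strides `[600, 30, 1]`, matching the `stride(0)`/`stride(1)`/`stride(2)` example.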
### `coalesce` `coalesce(self) -> LayoutTensor[dtype, coalesce(layout, False), origin, address_space=address_space, element_layout=element_layout]` Creates a tensor with a coalesced memory layout from this tensor. Coalescing a tensor's layout means reorganizing its memory representation to be as contiguous as possible, which can improve memory access patterns and performance. This operation does not move or copy data; it only changes how the same memory is interpreted. Performance: * Coalesced layouts typically provide better cache utilization and memory access patterns. * This operation is zero-cost at runtime as it only changes the layout information, not the actual data. * Particularly beneficial before operations that perform sequential memory access or vectorized operations. Notes: * The coalesced tensor shares the same memory as the original tensor, so modifications to one will affect the other. * The shape of the tensor remains the same, only the stride information is optimized. * For already optimally coalesced tensors, this operation has no effect. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor with the same data but with a coalesced memory layout. The returned tensor has type `LayoutTensor` with the same dtype but with a coalesced layout. ### `tile` `tile[*tile_sizes: Int](self, *tile_coords: Int) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes](), alignment=alignment]` Extract a tile (sub-tensor) from this tensor with specified dimensions and position. 
Tiling is a fundamental operation for high-performance tensor computations that divides a tensor into smaller blocks for better cache locality and parallelism. This method extracts a specific tile at the given coordinates without copying data. Example: For a 4x4 tensor with values: ``` [1 2 3 4] [2 3 4 5] [5 4 3 2] [1 1 1 1] ``` `tile[2, 2](1, 0)` will extract the tile: ``` [5 4] [1 1] ``` Performance: * Creates a view without copying data, making it very efficient. * Optimized for both static and dynamic layouts with different code paths. * Properly handles edge cases where tiles may be partially outside the tensor. * Maintains stride information for efficient memory access within the tile. Notes: * The resulting tile is a view into the original tensor, so modifications to the tile will affect the original tensor. * For tiles at the edges of the tensor, the actual dimensions may be smaller than the requested tile\_sizes if masking is enabled. * The implementation automatically selects between static and dynamic tiling based on the tensor's layout properties. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. For example, in a 2D tensor, `tile[32, 32]` creates 32x32 tiles. **Args:** * ​\*tile\_coords ([`Int`](/mojo/std/builtin/int/Int)): The coordinates of the specific tile to extract. For example, `tile[32, 32](1, 2)` extracts the tile at position (1, 2) in the grid of 32x32 tiles. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view into the original tensor representing the specified tile. 
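The tile extraction above can be sketched in code. This assumes the same `InlineArray`-backed construction used elsewhere in these docs; the fill values are illustrative:

```mojo
from layout import Layout, LayoutTensor

def main():
    alias layout = Layout.row_major(4, 4)
    var storage = InlineArray[Float32, layout.size()](fill=0)
    var t = LayoutTensor[DType.float32, layout](storage)

    # Extract the 2x2 tile at tile-grid coordinates (1, 0),
    # i.e., rows 2-3 and columns 0-1 of the original tensor.
    var block = t.tile[2, 2](1, 0)

    # The tile is a view: this write is visible through `t` as well.
    block[0, 0] = 42
```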
### `simd_tile` `simd_tile[tile_size: Int](self, tile_idx: Int) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size, simd_width_of[dtype]()]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_size, simd_width_of[dtype]()](), alignment=alignment]` Return a SIMD\[dtype] sized tile of size `tile_size` at `tile_idx`. **Parameters:** * ​tile\_size ([`Int`](/mojo/std/builtin/int/Int)): The size of the tile along the tiled axis used for vectorization. **Args:** * ​tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): The index of the tile to extract along the tiled axis. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A SIMD\[dtype] tile of size `tile_size` at `tile_idx` ### `tile_with_offset` `tile_with_offset[*tile_sizes: Int](self, *tile_coords: Int) -> Tuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes](), alignment=alignment], LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].CornerCoordsType, Scalar[linear_idx_type]]` Similar to `tile`, but also returns the corner coordinates of the tile as well as the offset. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. 
**Args:** * ​\*tile\_coords ([`Int`](/mojo/std/builtin/int/Int)): The coordinates of the specific tile to extract. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple): A tuple containing: * The extracted tile as a `LayoutTensor`. * The corner coordinates of the tile. * The offset of the tile. ### `tiled_iterator` `tiled_iterator[*tile_sizes: Int, *, axis: Int = 0](self, *tile_coords: Int) -> LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes]()]` Create an iterator that traverses tiles along a specified axis. This method creates an iterator that allows efficient traversal of tiles within a tensor. The iterator starts at the specified tile coordinates and can move along the specified axis, providing access to consecutive tiles. Performance: * Provides efficient sequential access to tiles with good cache locality. * Optimized for both static and dynamic layouts with different code paths. * Maintains stride information for efficient memory access within each tile. * Properly handles edge cases where tiles may be partially outside the tensor. Notes: * The iterator provides views into the original tensor, so modifications through the iterator will affect the original tensor. * For tiles at the edges of the tensor, the actual dimensions may be smaller than the requested tile\_sizes if masking is enabled. * The iterator is not circular by default, meaning it will not wrap around when reaching the end of the tensor along the iteration axis. * The implementation automatically selects between static and dynamic tiling based on the tensor's layout properties. 
Example:

```mojo
var iter = tensor.tiled_iterator[16, 16, axis=0](0, 0)
for i in range(num_tiles_along_axis):
    var tile = iter.get()
    # Process tile
    iter.next()
```

**Parameters:**

* ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. For example, in a 2D tensor, `tiled_iterator[32, 32]` creates an iterator over 32x32 tiles.
* ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which the iterator will traverse. Default is 0 (first dimension). For example, with axis=0, the iterator will move vertically through tiles.

**Args:**

* ​\*tile\_coords ([`Int`](/mojo/std/builtin/int/Int)): The starting coordinates of the tile where iteration begins.

**Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A `LayoutTensorIter` that can be used to traverse tiles along the specified axis.

### `split`

`split[count: Int, axis: Int = 0](self) -> StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size=(layout.shape[axis].value() // count), axis=axis]()[0], MutAnyOrigin, address_space=address_space, element_layout=element_layout, alignment=alignment], count]`

Split the `LayoutTensor` along an axis and return a `StaticTuple` of `LayoutTensor`s.

**Parameters:**

* ​count ([`Int`](/mojo/std/builtin/int/Int)): The number of partitions to split the tensor into.
* ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which the split is applied.

**Returns:** `StaticTuple`: A `StaticTuple` containing `count` `LayoutTensor`s, each representing an equal-sized partition of the original tensor along the specified axis. Each partition has the same data type and memory characteristics as the original tensor, but with a reduced size along the split axis.
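For illustration, splitting an 8x4 tensor into two 4x4 halves along axis 0 might look like the following sketch (the tensor construction is an assumption, as in the other examples):

```mojo
from layout import Layout, LayoutTensor

def main():
    alias layout = Layout.row_major(8, 4)
    var storage = InlineArray[Float32, layout.size()](fill=0)
    var t = LayoutTensor[DType.float32, layout](storage)

    # Two equal partitions along axis 0 (the default); each is a
    # 4x4 view sharing memory with `t`.
    var halves = t.split[2]()
    var top = halves[0]
    var bottom = halves[1]
    _ = top
    _ = bottom
```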
`split[axis: Int = 0, split_alignment: Int = 1](self, count: Int, idx: Int) -> LayoutTensor[dtype, layout.make_shape_unknown[axis](), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Retrieve a specific partition of the tensor after splitting along a specified axis. This method divides the tensor into 'count' partitions along the specified axis and returns the partition at index 'idx'. The partitioning is done with alignment considerations to optimize memory access patterns. Unlike the overloaded split method that returns all partitions, this method returns only a single partition, making it more memory-efficient for cases where only one partition is needed at a time. Notes: * The shape along the split axis becomes unknown at compile time. * Only works with dimensions that have statically known sizes. * The last partition may be smaller than others if the dimension size is not evenly divisible by `count`. * Partition sizes are aligned up to the specified alignment value, which can improve performance for vectorized operations. Performance: * Uses aligned partitioning to improve memory access patterns. * Avoids creating all partitions in memory, reducing memory usage. * Maintains the original tensor's stride information for efficient element access within the partition. **Constraints:** * The dimension being split must have a statically known size. * Cannot split dimensions with unknown or dynamic sizes. **Parameters:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis along which to split the tensor. Defaults to 0 (first dimension). * ​split\_alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment value for the partition size. Defaults to 1. **Args:** * ​count ([`Int`](/mojo/std/builtin/int/Int)): The number of partitions to divide the tensor into. * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The index of the partition to return (0-based). 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A `LayoutTensor` representing the requested partition. ### `distribute` `distribute[threads_layout: Layout, axis: Optional[Int] = None, swizzle: Optional[Swizzle] = None, submode_axis: Optional[Int] = None](self, thread_id: Scalar[DType.uint]) -> LayoutTensor[dtype, _compute_distribute_layout[layout, threads_layout, axis]()[1], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _distribute_is_masked[layout, threads_layout, axis]() if is_nvidia_gpu() else False]` Distribute tensor workload across multiple threads in a structured pattern. This method partitions a tensor across multiple threads for parallel processing, assigning each thread a specific portion of the tensor. The distribution pattern is determined by the threads\_layout parameter, which defines the logical arrangement of threads. Example: For a 4x4 row-major tensor distributed across 4 threads in a 2x2 row-major grid: * Thread 0 will receive a LayoutTensor with a view into (0,0), (0,2), (2,0), (2,2) of the original tensor. * Thread 1 will receive a LayoutTensor with a view into (0,1), (0,3), (2,1), (2,3) of the original tensor. * Thread 2 will receive a LayoutTensor with a view into (1,0), (1,2), (3,0), (3,2) of the original tensor. * Thread 3 will receive a LayoutTensor with a view into (1,1), (1,3), (3,1), (3,3) of the original tensor. If axis=0 is specified with the same setup: * Thread (0, 0) and Thread (0, 1) would get the same data (top half) * Thread (1, 0) and Thread (1, 1) would get the same data (bottom half) Performance: * Creates a view without copying data, making it very efficient for parallel processing. * The swizzle parameter can significantly improve cache locality and memory access patterns. * Optimized for both static and dynamic layouts with different code paths. 
Notes: * The resulting tensor is a view into the original tensor, so modifications will affect the original tensor. * For optimal performance, the `threads_layout` should match the hardware's thread organization (e.g., warp/wavefront size and shape). * When using swizzling, carefully consider the memory access patterns to avoid cache thrashing or bank conflicts. * This function is particularly useful for GPU programming where threads are organized in structured grids. **Constraints:** * For dynamic layouts, the shape must be known at runtime and the threads\_layout must be fully static. **Parameters:** * ​threads\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Defines the logical arrangement of threads (e.g., 2x2 grid of 4 threads). This layout determines how the tensor is partitioned. * ​axis ([`Optional`](/mojo/std/collections/optional/Optional)): Optional. If specified, restricts distribution to only this axis. For example, with axis=0 in a 2D thread layout, threads that differ only in their second coordinate will receive the same data. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional. A function that remaps the distribution pattern to improve memory access patterns or cache locality. * ​submode\_axis ([`Optional`](/mojo/std/collections/optional/Optional)): Optional. Specifies an axis for specialized distribution modes. **Args:** * ​thread\_id ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The ID of the current thread (0-based). **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view into the original tensor representing the portion assigned to this thread. 
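The 2x2-thread-grid example above might appear inside a GPU kernel body roughly as follows. This is a hedged sketch: the kernel launch boilerplate is omitted, the parameter signature is simplified, and `thread_idx` is assumed to come from the `gpu` package:

```mojo
from gpu import thread_idx
from layout import Layout, LayoutTensor

fn distribute_example(t: LayoutTensor[DType.float32, Layout.row_major(4, 4), MutAnyOrigin]):
    # 2x2 grid of 4 threads; each thread receives a strided fragment
    # view of the 4x4 tensor, as described in the example above.
    alias threads = Layout.row_major(2, 2)
    var fragment = t.distribute[threads](thread_idx.x)
    # `fragment` is a view: writes through it land in the original tensor.
    _ = fragment
```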
### `distribute_with_offset` `distribute_with_offset[threads_layout: Layout, axis: Optional[Int] = None, swizzle: Optional[Swizzle] = None, submode_axis: Optional[Int] = None](self, thread_id: Scalar[DType.uint]) -> Tuple[LayoutTensor[dtype, _compute_distribute_layout[layout, threads_layout, axis]()[1], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _distribute_is_masked[layout, threads_layout, axis]() if is_nvidia_gpu() else False], IndexList[threads_layout.rank(), element_type=layout_int_type], Scalar[linear_idx_type]]` Similar to `distribute`, but also returns the corner coordinates of the tile as well as the offset. **Parameters:** * ​threads\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the threads. * ​axis ([`Optional`](/mojo/std/collections/optional/Optional)): The axis to distribute along. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): An optional swizzle function. * ​submode\_axis ([`Optional`](/mojo/std/collections/optional/Optional)): An optional submode axis. **Args:** * ​thread\_id ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The ID of the current thread (0-based). **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple): A tuple containing: * The distributed tensor. * The corner coordinates of the tile. * The offset of the tile. 
### `vectorize` `vectorize[*vector_shape: Int](self) -> LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[1], True), origin, address_space=address_space, element_layout=LayoutTensor._divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Reshape a tensor into a vectorized form for efficient SIMD operations. This method transforms the tensor's logical layout to enable efficient vectorized processing, treating blocks of elements as vector units. The transformation is particularly useful for SIMD (Single Instruction Multiple Data) operations and hardware acceleration. Example: For a 16x16 tensor, `vectorize[4, 4]` will produce a 4x4 tensor where each element represents a 4x4 block from the original tensor. Performance: * Creates a view without copying data, making it very efficient. * Enables hardware-accelerated vector operations on blocks of data. * Improves cache locality by grouping related elements together. * Particularly beneficial for operations that can leverage SIMD instructions. Notes: * The tensor dimensions must be divisible by the corresponding vector dimensions. * For dimensions with unknown size, the corresponding vector dimension must be 1. * The resulting tensor has the same data but a different logical organization. * Modifications to the vectorized tensor affect the original tensor. * This transformation is particularly useful for GPU and vector processor optimizations. **Constraints:** * Each tensor dimension must be divisible by the corresponding vector dimension. * Vector dimensions must be smaller than or equal to the corresponding tensor dimensions. * For dimensions with unknown size, the vector dimension must be 1. 
**Parameters:**

* ​\*vector\_shape ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of each vector unit along each axis of the tensor. For example, in a 2D tensor, `vectorize[4, 4]` treats 4x4 blocks as vector units.

**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with a vectorized layout, where each element in the resulting tensor represents a vector of elements from the original tensor.

`vectorize(self) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].SIMDVectorizedType`

Return a SIMD\[dtype] vectorized view of this tensor.

**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A `Self.VectorizedType[1, simd_width_of[Self.dtype]()]` view whose width equals the SIMD width for the tensor's dtype.

### `slice`

`slice[d0_slice: Slice, d1_slice: Slice](self) -> LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]`

Extract a slice from a rank-2 tensor using slice objects. This method creates a view into a subset of the tensor defined by the slice specifications for each dimension. The slice is a contiguous region of the tensor with no gaps (step size must be 1).

Example: For a 4x4 tensor `t` with values:

```
[1 2 3 4]
[5 6 7 8]
[9 10 11 12]
[13 14 15 16]
```

```mojo
t.slice[Slice(1, 3), Slice(0, 2)]()
```

will extract:

```
[5 6]
[9 10]
```

Performance:

* Creates a view without copying data, making it very efficient.
* Maintains the original tensor's stride information for efficient memory access.
* Zero-cost abstraction at runtime when used with compile-time constant slices.
Notes: * The slice is a view into the original tensor, so modifications to the slice will affect the original tensor. * Only supports rank-2 tensors. For higher-rank tensors, use the overloaded version with slice indices. * The step size must be 1 (no gaps allowed in the slice). * Slice bounds are not checked at runtime; accessing out-of-bounds indices will result in undefined behavior. **Constraints:** * Only works with rank-2 tensors. **Parameters:** * ​d0\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the first dimension (rows). Defines the start and end indices for the slice along this dimension. * ​d1\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the second dimension (columns). Defines the start and end indices for the slice along this dimension. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view into the original tensor representing the specified slice. `slice[d0_slice: Slice, d1_slice: Slice, slice_indices: IndexList[2], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)](self, offsets: IndexList[__offset_dims]) -> LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice, slice_indices.__getitem__[2, DType.int64, Int](0), slice_indices.__getitem__[2, DType.int64, Int](1)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Extract a 2D slice from a higher-rank tensor at specific indices. This method creates a view into a 2D subset of a higher-rank tensor: Selecting two dimensions to slice using the slice\_indices parameter. 
Applying slice specifications to those dimensions, and using fixed offsets for all other dimensions.

Example: Given a 3x4x5 tensor `t`, the following example extracts a 2x2 slice from dimensions 0 and 2, with dimension 1 fixed at index 1.

```mojo
var s = t.slice[Slice(1, 3), Slice(0, 2), IndexList[2](0, 2)](1)
```

Performance:

* Creates a view without copying data, making it very efficient.
* Maintains the original tensor's stride information for efficient memory access.
* Zero-cost abstraction at runtime when used with compile-time constant slices.

Notes:

* The slice is a view into the original tensor, so modifications to the slice will affect the original tensor.
* The slice indices must be ordered (e.g., \[0, 2] is valid, \[2, 0] is not).
* The step size must be 1 (no gaps allowed in the slice).
* Slice bounds are not checked at runtime; accessing out-of-bounds indices will result in undefined behavior.

**Constraints:**

* Slice step size must be 1 (no gaps).
* Slice indices must be ordered (ascending).
* Tensor rank must be at least 2.

**Parameters:**

* ​d0\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the first selected dimension.
* ​d1\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the second selected dimension.
* ​slice\_indices ([`IndexList`](/mojo/std/utils/index_/IndexList)): Indices of the two dimensions to slice (must be ordered).
* ​\_\_offset\_dims ([`Int`](/mojo/std/builtin/int/Int)): Internal parameter representing the number of fixed dimensions.

**Args:**

* ​offsets ([`IndexList`](/mojo/std/utils/index_/IndexList)): Fixed index values for all dimensions not being sliced.

**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A 2D view into the original tensor representing the specified slice.
### `slice_1d` `slice_1d[d0_slice: Slice, slice_indices: IndexList[1], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 1)](self, offsets: IndexList[__offset_dims]) -> LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, slice_indices.__getitem__[1, DType.int64, Int](0)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Extract a 1D slice from a higher-rank tensor at a specific index. This method creates a view into a 1D subset of a higher-rank tensor by: 1. Selecting one dimension to slice using the slice\_indices parameter 2. Applying a slice specification to that dimension 3. Using fixed offsets for all other dimensions Example: For a 3x4x5 tensor, `t`, the following example extracts a 1D slice from dimension 0, with dimensions 1 and 2 fixed at indices 1 and 2: ```mojo t.slice_1d[Slice(1, 3), IndexList[1](0)](1, 2) ``` Performance: * Creates a view without copying data, making it very efficient. * Maintains the original tensor's stride information for efficient memory access. * Zero-cost abstraction at runtime when used with compile-time constant slices. Notes: * The slice is a view into the original tensor, so modifications to the slice will affect the original tensor. * The step size must be 1 (no gaps allowed in the slice). * Slice bounds are not checked at runtime; accessing out-of-bounds indices will result in undefined behavior. * This function exists as a workaround for compiler limitations with overloading. **Constraints:** * Slice step size must be 1 (no gaps). * Tensor rank must be at least 1. 
**Parameters:** * ​d0\_slice ([`Slice`](/mojo/std/builtin/builtin_slice/Slice)): Slice specification for the selected dimension. * ​slice\_indices ([`IndexList`](/mojo/std/utils/index_/IndexList)): Index of the dimension to slice. * ​\_\_offset\_dims ([`Int`](/mojo/std/builtin/int/Int)): Internal parameter representing number of fixed dimensions. **Args:** * ​offsets ([`IndexList`](/mojo/std/utils/index_/IndexList)): Fixed index values for all dimensions not being sliced. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A 1D view into the original tensor representing the specified slice. ### `transpose` `transpose(self) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].TransposeType` Create a transposed view of a tensor. This method creates a view of the tensor with its dimensions swapped, effectively converting rows to columns and columns to rows. The transposition is performed without copying data, by adjusting the tensor's layout information. Example: For a 2x3 tensor with values: ``` [1 2 3] [4 5 6] ``` `transpose()` will produce a 3x2 tensor: ``` [1 4] [2 5] [3 6] ``` Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Memory access patterns may be less efficient in the transposed view due to non-contiguous memory access, especially for row-major storage. Notes: * The transposed tensor shares the same memory as the original tensor, so modifications to one will affect the other. * For optimal performance when repeatedly accessing the transposed data, consider creating a physical copy with the transposed layout. * Transpose only works with statically known shapes. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with dimensions transposed (rows become columns and vice versa). ### `reshape` `reshape[dst_layout: Layout](self) -> LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Create a view of the tensor with a different shape. This method creates a view of the tensor with a new shape, without changing the underlying data. The total number of elements must remain the same. Example: Given a 2x6 row-major tensor, `reshape[Layout.col_major(3, 4)]()` produces a 3x4 tensor with the same elements in column-major order. Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Memory access patterns may change, potentially affecting performance depending on the original and target layouts. Notes: * The reshaped tensor shares the same memory as the original tensor, so modifications to one will affect the other. * The total number of elements must remain the same after reshaping. * The reshape operation assumes a row-major (C-style) memory layout. * For tensors with complex strides or non-contiguous memory, reshaping may not produce the expected results. * Masked tensors cannot be reshaped. **Constraints:** * Cannot reshape masked tensors. * The total number of elements must be the same in both layouts. **Parameters:** * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped tensor. Must have the same total number of elements as the original tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with the new shape specified by dst\_layout. 
`reshape[dst_layout: Layout](self, runtime_layout: RuntimeLayout[dst_layout]) -> LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Create a view of the tensor with a different shape. This method creates a view of the tensor with a new shape, without changing the underlying data. The total number of elements must remain the same. Example: Given a 2x6 row-major tensor, `reshape[Layout.col_major(3, 4)]()` produces a 3x4 tensor with the same elements in column-major order. Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Memory access patterns may change, potentially affecting performance depending on the original and target layouts. Notes: * The reshaped tensor shares the same memory as the original tensor, so modifications to one will affect the other. * The total number of elements must remain the same after reshaping. * The reshape operation assumes a row-major (C-style) memory layout. * For tensors with complex strides or non-contiguous memory, reshaping may not produce the expected results. * Masked tensors cannot be reshaped. **Constraints:** * Cannot reshape masked tensors. * The total number of elements must be the same in both layouts. **Parameters:** * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped tensor. Must have the same total number of elements as the original tensor. **Args:** * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The target RuntimeLayout for the reshaped tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with the new shape specified by dst\_layout. 
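The 2x6 to column-major 3x4 reshape described above can be sketched as follows (the backing-storage construction is an assumption):

```mojo
from layout import Layout, LayoutTensor

def main():
    alias src = Layout.row_major(2, 6)
    var storage = InlineArray[Float32, src.size()](fill=0)
    var t = LayoutTensor[DType.float32, src](storage)

    # Same 12 elements, reinterpreted through a column-major 3x4
    # layout; no data is moved or copied.
    var r = t.reshape[Layout.col_major(3, 4)]()
    _ = r
```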
### `flatten`

`flatten(self) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].FlattenedType`

Convert a LayoutTensor to a flattened dynamic layout.

**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A LayoutTensor with a flattened dynamic layout.

### `composition`

`composition[rhs_layout: Layout, dst_layout: Layout = composition(layout, rhs_layout)](self) -> LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]`

Create a view of the tensor with a composed layout. This method creates a view of the tensor with a new layout that is the composition of the original layout with another layout. Layout composition allows for complex transformations of the tensor's logical structure without copying data.

Example: For a 4x4 tensor with a standard row-major layout, composing with a layout that represents a 2x2 tiling would result in a tensor that logically views the data as 2x2 blocks.

Performance:

* Creates a view without copying data, making it very efficient.
* The operation is zero-cost at runtime as it only changes the layout information.
* Can be used to optimize memory access patterns for specific algorithms.

Notes:

* The composed tensor shares the same memory as the original tensor, so modifications to one will affect the other.
* Layout composition is a powerful tool for expressing complex data transformations like tiling, transposition, and reshaping in a unified framework.
* Understanding the mathematical properties of layout composition is important for correctly using this function.

**Constraints:**

* The layouts must be compatible for composition.
* The total number of elements must remain the same after composition.
**Parameters:** * ​rhs\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to compose with the tensor's current layout. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The resulting layout after composition. Defaults to the composition of the tensor's layout with rhs\_layout. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with the composed layout. ### `distance` `distance(self, addr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin]) -> Scalar[linear_idx_type]` Calculate the element-wise distance between this tensor's pointer and another pointer. This method computes the number of elements (not bytes) between the tensor's pointer and the provided address. This is useful for determining offsets within a larger memory allocation or for pointer arithmetic operations. Example: If `tensor.ptr` points to an element at index 100 in a buffer, and `addr` points to the element at index 50, then `distance(addr)` returns 50. Performance: * This is a lightweight operation that only involves pointer arithmetic. * The operation is optimized based on the address space, using smaller integer types for shared memory to improve efficiency. Notes: * The distance is calculated in elements, not bytes. * The result can be positive or negative depending on the relative positions of the pointers. * This function is particularly useful for GPU programming where understanding memory offsets is critical for performance. * Care should be taken when using this with pointers from different allocations, as the result would be meaningless. **Args:** * ​addr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): The target pointer to calculate the distance to. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The number of elements between this tensor's pointer and the provided address. The result is of type `Scalar[linear_idx_type]`. 
`distance[_layout: Layout, _uint_dtype: DType = _get_unsigned_type(_layout, address_space)](self, src: LayoutTensor[dtype, _layout, origin, address_space=address_space]) -> Scalar[_uint_dtype]` Calculate the element-wise distance between this tensor and another tensor. This method computes the number of elements (not bytes) between this tensor's pointer and another tensor's pointer. This is useful for determining the relative positions of tensors within a larger memory allocation. Example: If tensor1 points to element at index 100 in a buffer, and tensor2 points to element at index 50, then `tensor1.distance(tensor2)` would return 50. Performance: * This is a lightweight operation that only involves pointer arithmetic. * The operation is optimized based on the address space and layout, using appropriate integer types for efficiency. Notes: * The distance is calculated in elements, not bytes. * The result can be positive or negative depending on the relative positions of the tensors. * This function is particularly useful for GPU programming where understanding memory offsets is critical for performance. * Both tensors must be in the same address space for the result to be meaningful. * This overload is more type-safe than the pointer-based version as it ensures the tensors have compatible data types and address spaces. **Parameters:** * ​\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the source tensor. * ​\_uint\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The unsigned integer type to use for the result. Automatically determined based on the layout and address space. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor to calculate the distance to. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The number of elements between this tensor's pointer and the source tensor's pointer. The result is of type \_uint\_dtype. 
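For example (a sketch that reuses the `tile()` method shown elsewhere in this document), the distance from a tensor's base pointer to one of its tiles recovers the tile's linear element offset:

```mojo
from layout import Layout, LayoutTensor

def main():
    var storage = InlineArray[Float32, 4 * 4](uninitialized=True)
    var tensor = LayoutTensor[
        DType.float32,
        Layout.row_major(4, 4),
    ](storage).fill(0.0)

    # The 2x2 tile at block coordinates (1, 0) starts at row 2,
    # column 0, which is linear element 8 in row-major order.
    var tile = tensor.tile[2, 2](1, 0)
    print(tile.distance(tensor.ptr))  # element offset of the tile: 8
```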
### `copy_from` `copy_from(self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], other: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Copy data from another tensor to this tensor. This method performs an element-by-element copy from the source tensor to this tensor, respecting the layouts of both tensors. The copy operation handles different memory layouts correctly, ensuring that elements are copied to their proper positions regardless of how the data is arranged in memory. Example:

```mojo
from layout import LayoutTensor, Layout

var src_storage = InlineArray[Float32, 2 * 3](uninitialized=True)
var dst_storage = InlineArray[Float32, 3 * 2](uninitialized=True)

var src = LayoutTensor[
    DType.float32,
    Layout([2, 3]),
](src_storage).fill(1.0)

var dst = LayoutTensor[
    DType.float32,
    Layout([3, 2]),
](dst_storage)

dst.copy_from(src)  # Copies all elements from src to dst
```

Performance: * Performs element-by-element copying, which may be less efficient than vectorized or bulk memory operations. * The copy respects the memory layout of both tensors, which may involve non-contiguous memory access patterns. * For optimal performance with large tensors, consider using specialized copy functions that can leverage hardware acceleration. Notes: * Both tensors must have statically known shapes. * The total number of elements must be the same in both tensors. * The element sizes must match between the tensors. 
* This function handles different memory layouts correctly, making it suitable for copying between tensors with different shapes or strides. * The copy is performed element by element, not as a bulk memory copy. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor to copy data from. Must have the same total number of elements as this tensor. ### `copy_from_async` `copy_from_async[is_masked: Bool = False, swizzle: Optional[Swizzle] = None, fill: Fill = Fill.NONE, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_idx_bound: Scalar[linear_idx_type] = 0, base_offset: Scalar[linear_idx_type] = 0)` Asynchronously copy data from another tensor to this tensor using GPU hardware. This method performs an asynchronous copy from the source tensor to this tensor using GPU hardware acceleration. It's specifically designed for copying data from global memory to shared memory in GPU kernels, leveraging hardware-specific asynchronous copy mechanisms for improved performance. For optimal performance, you need to arrange the copy correctly. Use the [`distribute()`](/mojo/kernels/layout/layout_tensor/LayoutTensor/#distribute) method to create thread-local fragments of the source and destination tensors, assigning each thread one or more elements to copy. Optionally, use the [`vectorize()`](/mojo/kernels/layout/layout_tensor/LayoutTensor/#vectorize) method to get vectorized views of both tensors before calling `distribute()`. This allows each thread to copy multiple elements of the tensor. 
For example:

```mojo
var fragment = tensor.vectorize[1, simd_width]().distribute[
    thread_layout
](thread_id)
```

The copy operation is asynchronous, so you must call [`async_copy_wait_all()`](/mojo/std/gpu/memory/memory/async_copy_wait_all/) or [`async_copy_wait_group()`](/mojo/std/gpu/memory/memory/async_copy_wait_group/) to ensure the copy has completed before using the data. Example:

```mojo
from layout import LayoutTensor, Layout
from gpu import thread_idx, block_idx, block_dim
from gpu.memory import async_copy_wait_all

comptime dtype = DType.float32
comptime in_size = 128
comptime block_size = 16
comptime num_blocks = in_size // block_size
comptime input_layout = Layout.row_major(in_size, in_size)

fn kernel(tensor: LayoutTensor[dtype, input_layout, MutAnyOrigin]):
    # Extract a tile from the input tensor.
    var global_tile = tensor.tile[block_size, block_size](block_idx.x, block_idx.y)

    # Allocate a shared memory tile.
    comptime tile_layout = Layout.row_major(block_size, block_size)
    var shared_tile = LayoutTensor[
        dtype,
        tile_layout,
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # Create per-thread tile fragments for copying.
    var tid = thread_idx.y + thread_idx.x * block_dim.x
    comptime thread_layout = Layout.row_major(block_size, block_size)
    var global_fragment = global_tile.distribute[thread_layout](tid)
    var shared_fragment = shared_tile.distribute[thread_layout](tid)

    # Asynchronously copy to shared memory.
    shared_fragment.copy_from_async(global_fragment)
    async_copy_wait_all()

    # ... do something with the shared tile ...
```

Performance: * Supports vectorized copies for 4, 8, or 16-byte elements for better throughput. * Can bypass L1 cache with appropriate eviction policies for specific access patterns. * Swizzling can improve memory access patterns and reduce bank conflicts. Notes: * For vectorized copies, both tensors must have contiguous element layouts. * Asynchronous copies allow computation to overlap with memory transfers. 
* A synchronization barrier is required before using the copied data. **Constraints:** * Destination must be in shared memory. * Source and destination data types must match. * Element size must be 4, 8, or 16 bytes. * Destination tensor must have a static layout. **Parameters:** * ​is\_masked ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to perform a masked copy, where elements outside `src_idx_bound` are either skipped or filled with zeros, depending on the fill policy. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns. * ​fill ([`Fill`](/mojo/std/gpu/memory/memory/Fill)): Fill policy for elements that are not copied (only used with masked copies). * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy for the source data. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor to copy data from. * ​src\_idx\_bound ([`Scalar`](/mojo/std/builtin/simd/#scalar)): For masked copies, the upper bound index for valid source elements. * ​base\_offset ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Base offset for swizzling calculations. ### `fill` `fill[*, use_runtime_layout: Bool = layout.all_dims_known().__bool__().__invert__() if layout.all_dims_known().__bool__().__invert__()._mlir_value else (layout.size() > 2048)](self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], val: Scalar[dtype]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Fill the entire tensor with a single value. This method sets all elements of the tensor to the specified value. 
It works with both statically and dynamically shaped tensors. For statically known layouts, the fill operation is unrolled at compile time. For dynamic layouts, a runtime loop is used. No vectorization is applied, so performance may be suboptimal for large tensors. Consider using hardware-specific fill operations for better performance with large tensors. This method can be used with tensors of any rank and shape. The fill operation respects the tensor's layout, filling all elements regardless of how they are arranged in memory. For tensors with `element_layout`, all elements within each logical element are filled with the same value. Example: ```mojo from layout import Layout, LayoutTensor def main(): var storage = InlineArray[Float32, 3 * 4](uninitialized=True) var tensor = LayoutTensor[ DType.float32, Layout([3, 4]), ](storage).fill(0.0) print(tensor) ``` If not using method chaining, you can either reassign the result to the tensor variable, or assign the result to the discard pattern (`_`) to avoid warnings about an unused value: ```mojo tensor = tensor.fill(0.0) # or _ = tensor.fill(0.0) ``` **Parameters:** * ​use\_runtime\_layout ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to use the runtime layout for filling. This parameter is defaulted to `True` if the layout is not statically known. If loop bounds are too large, it's better to use the runtime layout to avoid long compilation time. **Args:** * ​val ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The value to fill the tensor with. Must be of the same data type as the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The tensor itself (self), allowing for method chaining. ### `__str__` `__str__(self) -> String` Convert the tensor to a string representation. This method converts the tensor to a human-readable string representation by writing its contents to a string. 
It delegates to the `write_to` method which formats the tensor appropriately based on its rank and shape. **Returns:** `String`: A string representation of the tensor. ### `write_to` `write_to(self, mut writer: T)` Format and write the tensor's contents to a writer. This method formats the tensor's contents and writes them to the provided writer. For 2D tensors, it formats the output in a 2D grid. For tensors of other ranks, it prints all values in column-major coordinate order. Example: ```mojo from layout import Layout, LayoutTensor def main(): var storage = InlineArray[Float32, 2 * 3](uninitialized=True) var tensor = LayoutTensor[ DType.float32, Layout([2, 3]), ](storage).fill(1.0) print(tensor) # Internally calls `write_to` with a StringWriter ``` Output for a 2x3 tensor: ``` [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] ``` Notes: * For 2D tensors, the output is formatted as a 2D grid with rows and columns. * For tensors of other ranks, values are printed in column-major coordinate order. * Empty tensors (size 0) produce no output. * This method is used by the `__str__` method to convert the tensor to a string. * The formatting is designed for human readability rather than parsing. * For large tensors, the output may be truncated to avoid excessive output. **Args:** * ​writer (`T`): The writer instance to write the formatted output to.
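Because `write_to` accepts any writer, you can also format a tensor directly into a `String` (a sketch, assuming `String` satisfies the writer requirement, as it does for other `write_to` methods in the standard library):

```mojo
from layout import Layout, LayoutTensor

def main():
    var storage = InlineArray[Float32, 2 * 3](uninitialized=True)
    var tensor = LayoutTensor[
        DType.float32,
        Layout([2, 3]),
    ](storage).fill(1.0)

    var out = String()
    tensor.write_to(out)  # same text that `print(tensor)` produces
    print(out)
```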
--- ## LayoutTensorIter
`@register_passable(trivial)` `struct LayoutTensorIter[mut: Bool, //, dtype: DType, layout: Layout, origin: Origin[mut=mut], /, *, address_space: AddressSpace = AddressSpace.GENERIC, alignment: Int = align_of[dtype](), circular: Bool = False, axis: Optional[Int] = None, layout_int_type: DType = _get_index_type(address_space), linear_idx_type: DType = _get_index_type(address_space), masked: Bool = False]` Iterator for traversing a memory buffer with a specific layout. `LayoutTensorIter` provides a way to iterate through memory according to a specific layout pattern, constructing layout tensors at each position. This enables efficient traversal of multi-dimensional data structures with custom memory layouts. Notes: The returned layout tensor is NOT vectorized. Users should explicitly vectorize if needed for performance-critical operations. ## Parameters * ​mut ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the iterator allows mutation of the underlying data. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout pattern to follow during iteration. * ​origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): Origin tracking for memory safety. * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The memory address space (`GLOBAL`, `SHARED`, etc.). * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment requirement for the data. * ​circular ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether iteration wraps around at boundaries. * ​axis ([`Optional`](/mojo/std/collections/optional/Optional)): Optional axis for dimension-specific operations. * ​layout\_int\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Integer type used for layout indices. * ​linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Integer type used for indexing into memory. 
* ​masked ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to apply bounds masking during iteration. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin]`): Pointer to the memory region being iterated, with appropriate type and memory attributes. * ​offset (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Current offset from the base pointer, representing the iterator's position in memory. * ​stride (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Step size between consecutive elements or blocks in memory during iteration. * ​bound (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Upper bound of the memory region, limiting the iteration range. * ​runtime\_layout (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].RuntimeLayoutType`): Runtime representation of the layout pattern used for mapping logical indices to memory locations. * ​dimension\_bound (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].layout_uint_type`): Boundary value for the current dimension when iterating along a specific axis. 
* ​idx (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Current logical index position within the iteration sequence. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BitcastType` `comptime BitcastType[new_type: DType, *, address_space: AddressSpace = address_space, alignment: Int = alignment] = LayoutTensorIter[new_type, layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for bitcast iterator types. #### Parameters * ​new\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The target data type. * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The target address space. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): The target memory alignment. ### `layout_uint_type` `comptime layout_uint_type = Scalar[_unsigned_integral_type_of[layout_int_type]()]` The unsigned integer type used for layout, based on layout and address space. 
### `LayoutTensorType` `comptime LayoutTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` The LayoutTensor type returned by this iterator. ### `linear_uint_type` `comptime linear_uint_type = Scalar[_unsigned_integral_type_of[linear_idx_type]()]` The unsigned integer type used for indexing into memory. ### `ReshapeType` `comptime ReshapeType[dst_layout: Layout] = LayoutTensorIter[dtype, dst_layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for reshaped iterator types. #### Parameters * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped iterator. ### `RuntimeLayoutType` `comptime RuntimeLayoutType = RuntimeLayout[layout, element_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for the runtime layout. ## Methods ### `__init__` `__init__() -> Self` Initialize an empty iterator. Creates a default iterator with zero values, typically used as a placeholder or default value. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin], bound: Scalar[_unsigned_integral_type_of[linear_idx_type]()], stride: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = layout.size(), offset: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 0) -> Self` Initialize an iterator with a pointer and basic parameters. Creates an iterator for a memory region with the specified bounds and stride. **Constraints:** The layout must have all dimensions known at compile time. **Args:** * ​ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to the beginning of the memory region. * ​bound ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Upper bound of the memory region. 
* ​stride ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Step size between consecutive elements (defaults to layout size). * ​offset ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Initial offset from the base pointer. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin], bound: Int) -> Self` Initialize an iterator with a pointer and `Int` bound. Creates an iterator for a memory region with the specified bound. **Args:** * ​ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to the beginning of the memory region. * ​bound ([`Int`](/mojo/std/builtin/int/Int)): Upper bound of the memory region. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin], bound: Scalar[_unsigned_integral_type_of[linear_idx_type]()], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], stride: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = layout.size() if layout.all_dims_known() else -1, offset: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 0, dimension_bound: Scalar[_unsigned_integral_type_of[layout_int_type]()] = 0, idx: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 0) -> Self` Initialize an iterator with a runtime layout. Creates an iterator with a runtime-determined layout, allowing for more flexible memory traversal patterns. **Constraints:** The runtime layout must have the same bitwidth as specified for the iterator. Circular iteration is not supported when an axis is defined. **Args:** * ​ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to the beginning of the memory region. * ​bound ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Upper bound of the memory region. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): Layout determined at runtime. 
* ​stride ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Step size between consecutive elements. * ​offset ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Initial offset from the base pointer. * ​dimension\_bound ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Bound for the specified dimension when using masked iteration. * ​idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Initial index position. ### `__getitem__` `__getitem__(self) -> LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].LayoutTensorType` Get the layout tensor at the current iterator position. Operator overload that returns a layout tensor representing the data at the current position of the iterator. **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A layout tensor at the current iterator position. ### `__iadd__` `__iadd__[T: Intable](mut self, rhs: T)` Increment the iterator by an integer value. Advances the iterator by the specified number of positions. Notes: This function is unsafe. It omits bound checking for performance reasons. Caller must ensure the index doesn't go out-of-bounds. **Parameters:** * ​T ([`Intable`](/mojo/std/builtin/int/Intable)): A type that can be converted to an integer. **Args:** * ​rhs (`T`): The number of positions to advance. `__iadd__(mut self, rhs: Scalar[_unsigned_integral_type_of[linear_idx_type]()])` Increment the iterator by a uint value. Advances the iterator by the specified number of positions. Notes: This function is unsafe. It omits bound checking for performance reasons. Caller must ensure the index doesn't go out-of-bounds. **Args:** * ​rhs ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The number of positions to advance. 
### `get` `get(self) -> LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].LayoutTensorType` Get the layout tensor at the current iterator position. Returns a layout tensor representing the data at the current position of the iterator. **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A tensor view at the current iterator position with the same type, layout, and memory characteristics as the iterator's `LayoutTensorType`. ### `next` `next[T: Intable](self, rhs: T) -> Self` Return an iterator pointing to a position ahead by `rhs` steps. Creates a new iterator that points `rhs` positions ahead of the current one. **Parameters:** * ​T ([`Intable`](/mojo/std/builtin/int/Intable)): An integer-convertible type for the step size. **Args:** * ​rhs (`T`): The number of positions to advance. **Returns:** `Self`: A new iterator pointing to the advanced position. `next(self, rhs: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 1) -> Self` Return an iterator pointing to a position ahead by `rhs` steps. Creates a new iterator that points `rhs` positions ahead of the current one. **Args:** * ​rhs ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The number of positions to advance (defaults to 1). **Returns:** `Self`: A new iterator pointing to the advanced position. ### `next_unsafe` `next_unsafe(self, rhs: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 1) -> Self` Return an iterator pointing to a position ahead by `rhs` steps (unsafe version). Creates a new iterator that points `rhs` positions ahead of the current one. This is an unsafe version that omits certain checks for performance. **Constraints:** Cannot be used with masked iterators. The caller must ensure `rhs < bound / stride`. **Args:** * ​rhs ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The number of positions to advance (defaults to 1). 
**Returns:** `Self`: A new iterator pointing to the advanced position. ### `reshape` `reshape[dst_layout: Layout](self) -> LayoutTensorIter[dtype, dst_layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Reshape the iterator to a new layout. This method creates a new iterator with a different layout while preserving the underlying data. The new layout must have the same total size as the original. **Constraints:** * The destination layout must have the same total size as the original. * Both layouts must be contiguous. * Both layouts must have compile-time known dimensions. **Parameters:** * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout to reshape to. **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A new iterator with the specified layout. ### `bitcast` `bitcast[new_type: DType, *, target_address_space: AddressSpace = address_space, target_alignment: Int = alignment](self) -> LayoutTensorIter[new_type, layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Reinterpret the iterator's underlying pointer as a different data type. This method performs a bitcast operation, allowing you to view the same memory location as a different data type without copying or converting the data. **Parameters:** * ​new\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The target data type to cast to. * ​target\_address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The memory address space for the new iterator (defaults to current). * ​target\_alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment requirement for the new iterator (defaults to current). 
**Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A new LayoutTensorIter with the same layout but different data type.
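Putting these methods together, a typical traversal constructs a tensor view at each position and then advances the iterator. The following is a rough sketch (the iterator is assumed to be constructed by the caller with one of the `__init__` overloads above, and `num_tiles` kept within the iterator's bound):

```mojo
from layout import Layout
from layout.layout_tensor import LayoutTensorIter

comptime tile_layout = Layout.row_major(4, 4)

fn traverse(
    mut it: LayoutTensorIter[DType.float32, tile_layout, MutAnyOrigin],
    num_tiles: Int,
):
    for _ in range(num_tiles):
        var tile = it[]  # LayoutTensor view at the current position
        # ... compute on `tile` ...
        it += 1          # unsafe advance: caller guarantees bounds
```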
--- ## ThreadScope
`@register_passable(trivial)` `struct ThreadScope` Represents the scope of thread operations in GPU programming. This struct defines the scope at which thread operations are performed, particularly for operations like tensor distribution and synchronization. It provides two main scopes: `BLOCK` and `WARP`, which correspond to different levels of thread grouping in GPU programming models. Example: ```mojo from layout.layout_tensor import copy_dram_to_sram, ThreadScope # Distribute tensor at block level (all threads in block participate) copy_dram_to_sram[layout, thread_scope=ThreadScope.BLOCK](dst, src) # Distribute tensor at warp level (only threads in same warp participate) copy_dram_to_sram[layout, thread_scope=ThreadScope.WARP](dst, src) ``` Performance: * WARP scope operations typically have lower synchronization overhead than BLOCK scope operations. * BLOCK scope operations allow coordination across all threads in a block, which is necessary for certain algorithms. * The choice of scope can significantly impact performance and correctness of parallel algorithms. Notes: * The appropriate scope depends on the specific algorithm and hardware. * WARP scope operations may be more efficient for operations that only require coordination within a warp. * BLOCK scope operations are necessary when threads from different warps need to coordinate. * The actual size of a warp or block is hardware-dependent. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BLOCK` `comptime BLOCK = ThreadScope(0)` Represents operations at the thread block level, where all threads in a block participate. ### `WARP` `comptime WARP = ThreadScope(1)` Represents operations at the warp level, where only threads within the same warp participate. ## Methods ### `__init__` `__init__(value: Int) -> Self` Initialize a `ThreadScope` with the given integer value. **Args:** * ​value ([`Int`](/mojo/std/builtin/int/Int)): An integer representing the thread scope (0 for `BLOCK`, 1 for `WARP`). ### `__eq__` `__eq__(self, other: Self) -> Bool` Compare two `ThreadScope` objects for equality. **Args:** * ​other (`Self`): Another `ThreadScope` object to compare with. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the thread scopes are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Compare two `ThreadScope` objects for inequality. **Args:** * ​other (`Self`): Another `ThreadScope` object to compare with. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the thread scopes are not equal, False otherwise. ### `__str__` `__str__(self) -> String` Convert the `ThreadScope` to a human-readable string representation. Aborts: If the thread scope has an invalid value. **Returns:** `String`: A string representation of the thread scope ("BLOCK" or "WARP"). 
### `__int__` `__int__(self) -> Int` Convert the `ThreadScope` to an integer value. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The integer value of the thread scope (0 for BLOCK, 1 for WARP).
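As a small illustrative sketch (the `main` function and variable names here are hypothetical, not part of this API), the two scopes compare and convert as described above:

```mojo
from layout.layout_tensor import ThreadScope

fn main():
    var scope = ThreadScope.WARP
    # Equality against the comptime aliases selects between code paths.
    if scope == ThreadScope.WARP:
        print(scope.__str__())  # "WARP"
    print(scope.__int__())  # 1
```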
--- ## copy_dram_to_local
`copy_dram_to_local[src_thread_layout: Layout, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1, cache_policy: CacheOperation = CacheOperation.ALWAYS](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_base: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], offset: Optional[UInt] = None)` Efficiently copy data from global memory (DRAM) to registers for AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's buffer\_load intrinsic to efficiently transfer data from global memory to registers while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. Notes: * The offset calculation method significantly impacts performance. Current implementation optimizes for throughput over flexibility. * This function is particularly useful for prefetching data into registers before performing computations, reducing memory access latency. **Constraints:** * Only supported on AMD GPUs. * The destination element layout size must match the SIMD width. * Source fragments must be rank 2 with known dimensions. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. 
* ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `src_thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. * ​cache\_policy ([`CacheOperation`](/mojo/std/gpu/memory/memory/CacheOperation)): The cache policy to use for the copy operation. Defaults to `CacheOperation.ALWAYS`. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in register memory (LOCAL address space). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in global memory (DRAM) to be copied. * ​src\_base ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The original global memory tensor from which src is derived. This is used to construct the buffer struct required by AMD's `buffer_load` intrinsic. * ​offset ([`Optional`](/mojo/std/collections/optional/Optional)): The offset in the global memory. 
`copy_dram_to_local[src_thread_layout: Layout, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1, cache_policy: CacheOperation = CacheOperation.ALWAYS](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], bounds: UInt32)` Efficiently copy data from global memory (DRAM) to registers for AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's buffer\_load intrinsic to efficiently transfer data from global memory to registers while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. Notes: * The offset calculation method significantly impacts performance. Current implementation optimizes for throughput over flexibility. * This function is particularly useful for prefetching data into registers before performing computations, reducing memory access latency. **Constraints:** * Only supported on AMD GPUs. * The destination element layout size must match the SIMD width. * Source fragments must be rank 2 with known dimensions. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `src_thread_layout.size()` will be disabled and not participate in the copy operation. 
* ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. * ​cache\_policy ([`CacheOperation`](/mojo/std/gpu/memory/memory/CacheOperation)): The cache policy to use for the copy operation. Defaults to `CacheOperation.ALWAYS`. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in register memory (LOCAL address space). * ​src\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The source tensor iterator. * ​bounds ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Bounds of the buffer, computed from the base pointer of `src_iter`. `copy_dram_to_local[src_thread_layout: Layout, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Efficiently copy data from global memory (DRAM) to registers. This function implements an optimized memory transfer operation from global memory to register memory. It distributes the copy operation across multiple threads for maximum throughput while handling bounds checking for safety. **Constraints:** * The source tensor must be in GLOBAL address space (DRAM). * The destination tensor must be in LOCAL address space (registers).
* Both tensors must have compatible data types. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `src_thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in register memory (LOCAL address space). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in global memory (DRAM).
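As a minimal sketch of the last (non-AMD-specific) overload, assuming a hypothetical kernel where `src` is a 16x8 `LayoutTensor` in GLOBAL memory (the tile shape and thread layout here are illustrative choices, not requirements):

```mojo
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_dram_to_local

fn kernel():
    ...
    # A 16x8 global tile distributed over a 4x8 thread arrangement:
    # each of the 32 threads owns a 4x1 register fragment.
    var dst_reg = LayoutTensor[
        DType.float32,
        Layout.row_major(4, 1),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation()

    # src is assumed to be a 16x8 LayoutTensor in GLOBAL memory (DRAM).
    copy_dram_to_local[Layout.row_major(4, 8)](dst_reg, src)
```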
--- ## copy_dram_to_sram
`copy_dram_to_sram[src_thread_layout: Layout, dst_thread_layout: Layout = src_thread_layout, swizzle: Optional[Swizzle] = None, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. This function performs a synchronous copy operation from global memory (DRAM) to shared memory (SRAM) in a GPU context, distributing the workload across multiple threads for parallel execution. It uses thread affinity mapping to ensure efficient work distribution and supports vectorized memory operations for optimal performance. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Supports vectorized loads and stores for better memory throughput. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. * Thread affinity mapping ensures efficient work distribution. * For masked tensors, performs bounds checking to handle edge cases correctly. Notes: * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. * This function is synchronous, meaning all threads must complete their copy operations before proceeding. * For optimal performance, the thread layouts should match the memory access patterns of the tensors. * This function is particularly useful in GPU kernels for loading data from global memory to shared memory for faster access. 
**Constraints:** * Source and destination tensors must have the same data type. * Source tensor must be in GENERIC or GLOBAL address space. * Destination tensor must be in SHARED address space. * For non-masked tensors, the fragment sizes must match. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the source tensor. This determines how the workload is distributed among threads. * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the destination tensor. Defaults to the same as `src_thread_layout` if not specified. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `src_thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Scope at which thread operations are performed (`BLOCK` or `WARP`). Defaults to `ThreadScope.BLOCK`, where all threads in a block participate. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM). 
`copy_dram_to_sram[src_thread_layout: Layout, dst_thread_layout: Layout = src_thread_layout, swizzle: Optional[Swizzle] = None, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], bound: Int)` Efficiently copy data from global memory (DRAM) to shared memory (SRAM) on AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's `buffer_load` intrinsic to efficiently transfer data while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the destination tensor across threads. Defaults to the same layout as `src_thread_layout`. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling pattern to apply when distributing the destination tensor. This can improve memory access patterns and reduce bank conflicts. Defaults to None (no swizzling). * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `src_thread_layout.size()` will be disabled and not participate in the copy operation. 
* ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory (SRAM). * ​src\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The source tensor iterator in global memory (DRAM) to be copied. * ​bound ([`Int`](/mojo/std/builtin/int/Int)): The bound of the source tensor iterator. `copy_dram_to_sram[thread_layout: Layout, swizzle: Optional[Swizzle] = None, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], bound: Int)` Synchronously copy data from DRAM to SRAM using a unified thread layout for AMD GPUs. This is a convenience wrapper around the more general `copy_dram_to_sram()` function that uses the same layout for both source and destination tensors. It's specifically designed for AMD GPUs where the buffer\_load intrinsic requires the original base tensor. Performance: * Simplifies API usage when the same thread layout is appropriate for both source and destination tensors. * Optimized for AMD GPUs using buffer\_load intrinsics for efficient memory transfers. 
* Distributes the copy workload across multiple threads for parallel execution. Notes: * This function is only supported on AMD GPUs. * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for both source and destination. This determines how the workload is distributed among threads. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Scope at which thread operations are performed (`BLOCK` or `WARP`). Defaults to `BLOCK`, where all threads in a block participate. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The source tensor iterator, which must be in global or generic memory (DRAM). * ​bound ([`Int`](/mojo/std/builtin/int/Int)): The bound of the source tensor iterator. 
`copy_dram_to_sram[thread_layout: Layout, swizzle: Optional[Swizzle] = None, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from DRAM to SRAM using a unified thread layout. This is a convenience wrapper around the more general `copy_dram_to_sram()` function that uses the same layout for both source and destination tensors. It simplifies the API for the common case where the same thread distribution pattern works well for both tensors. Performance: * Simplifies API usage when the same thread layout is appropriate for both source and destination tensors. * Distributes the copy workload across multiple threads for parallel execution. * Supports vectorized loads and stores for better memory throughput. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. Notes: * The source tensor must be in `GENERIC` or `GLOBAL` address space (DRAM). * The destination tensor must be in `SHARED` address space (SRAM). * Both tensors must have the same data type. * This function is synchronous, meaning all threads must complete their copy operations before proceeding. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for both source and destination. This determines how the workload is distributed among threads. 
* ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Scope at which thread operations are performed (`BLOCK` or `WARP`). Defaults to `ThreadScope.BLOCK`, where all threads in a block participate. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM).
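The unified-layout overload above can be sketched as follows; the kernel body, `src_gmem`, and the 32x32 tile are illustrative assumptions (with an 8x32 thread layout, each of the 256 threads moves a 4x1 slice):

```mojo
from gpu import barrier
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_dram_to_sram

fn kernel():
    ...
    # Stage a 32x32 tile in shared memory (SRAM).
    var smem_tile = LayoutTensor[
        DType.float32,
        Layout.row_major(32, 32),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # src_gmem is assumed to be a 32x32 LayoutTensor in GLOBAL memory.
    copy_dram_to_sram[Layout.row_major(8, 32)](smem_tile, src_gmem)
    barrier()  # all threads finish the copy before the tile is read
```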
--- ## copy_dram_to_sram_async
`copy_dram_to_sram_async[src_thread_layout: Layout, dst_thread_layout: Layout, swizzle: Bool = False, fill: Fill = Fill.NONE, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_threads: Int = src_thread_layout.size(), block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Asynchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. This function performs an asynchronous copy operation from global memory (DRAM) to shared memory (SRAM) in a GPU context, using NVIDIA's cp.async hardware mechanism. It distributes the workload across multiple threads and allows computation to overlap with memory transfers for improved performance. Performance: * Performs asynchronous transfers, allowing computation to overlap with memory operations. * Distributes the copy workload across multiple threads for parallel execution. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. * Supports different cache eviction policies to optimize memory hierarchy usage. * For masked tensors, performs bounds checking to handle edge cases correctly. Notes: * This function requires NVIDIA GPUs with `cp.async` support (compute capability 8.0+). * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. 
* This function is asynchronous, so you must call [`async_copy_wait_all()`](/mojo/std/gpu/memory/memory/async_copy_wait_all/) or [`async_copy_wait_group()`](/mojo/std/gpu/memory/memory/async_copy_wait_group/) to ensure the copy has completed before using the data. * The maximum size of each element that can be copied is 16 bytes. **Constraints:** * Requires NVIDIA GPUs with cp.async support (compute capability 8.0+). * Source tensor must be in `GENERIC` or `GLOBAL` address space. * Destination tensor must be in `SHARED` address space. * Both tensors must have the same data type. * Element size must be 4, 8, or 16 bytes. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the source tensor. This determines how the workload is distributed among threads. * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the destination tensor. * ​swizzle ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to apply swizzling to the destination indices to reduce bank conflicts. Defaults to False. * ​fill ([`Fill`](/mojo/std/gpu/memory/memory/Fill)): Fill policy for handling out-of-bounds accesses. Options include: * `Fill.NONE`: No special handling (default). * `Fill.ZERO`: Fill out-of-bounds elements with zeros. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy for the source data. Options include: * `CacheEviction.EVICT_NORMAL`: Normal eviction (default). * `CacheEviction.EVICT_FIRST`: Evict data after first use. * `CacheEviction.EVICT_LAST`: Keep data in cache until last use. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `src_thread_layout.size()` will be disabled and not participate in the copy operation. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. 
**Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM). `copy_dram_to_sram_async[thread_layout: Layout, swizzle: Bool = False, masked: Bool = False, fill: Fill = Fill.NONE, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_threads: Int = thread_layout.size(), block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Asynchronous copy from DRAM to SRAM with thread affinity mapping. This function performs an asynchronous memory transfer from DRAM (global memory) to SRAM (shared memory) using the specified thread layout for distribution. Notes: This is a convenience wrapper around the more general `copy_dram_to_sram_async()` function, using the same thread layout for both source and destination. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute work across threads. * ​swizzle ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to apply memory access swizzling for better performance. * ​masked ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the copy operation should use masking. * ​fill ([`Fill`](/mojo/std/gpu/memory/memory/Fill)): Fill policy for uninitialized memory regions. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy to use during the transfer. 
* ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `thread_layout.size()` will be disabled and not participate in the copy operation. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in SRAM. * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor in DRAM.
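A minimal sketch of the asynchronous variant, assuming a hypothetical kernel with a 64x64 float32 tile and `src_gmem` in GLOBAL memory (all illustrative choices); the key point is that the copy must be awaited before the shared-memory data is used:

```mojo
from gpu import barrier
from gpu.memory import async_copy_wait_all
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_dram_to_sram_async

fn kernel():
    ...
    var smem_tile = LayoutTensor[
        DType.float32,
        Layout.row_major(64, 64),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # Issue the asynchronous copy (cp.async); a 16x16 thread layout
    # gives each thread a 4x4 slice of the tile.
    copy_dram_to_sram_async[thread_layout = Layout.row_major(16, 16)](
        smem_tile, src_gmem
    )
    # ... independent computation can overlap with the transfer ...
    async_copy_wait_all()  # data is not safe to read until this returns
    barrier()
```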
--- ## copy_local_to_dram
`copy_local_to_dram[dst_thread_layout: Layout, num_threads: Int = dst_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Efficiently copy data from registers (LOCAL) to global memory (DRAM). This function implements a high-performance memory transfer operation from register memory to global memory. It distributes the copy operation across multiple threads for maximum throughput while handling bounds checking for safety. **Constraints:** * The source tensor must be in LOCAL address space (registers). * The destination tensor must be in GENERIC or GLOBAL address space (DRAM). * Both tensors must have compatible data types. **Parameters:** * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the destination tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `dst_thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. 
**Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in global memory (DRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in register memory (LOCAL) to be copied. `copy_local_to_dram[dst_thread_layout: Layout, num_threads: Int = dst_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst_base: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Efficiently copy data from registers (LOCAL) to global memory (DRAM) on AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's buffer\_store intrinsic to efficiently transfer data from registers to global memory while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. Notes: * This function is particularly useful for writing computed results from registers back to global memory with minimal latency. * The offset calculation is optimized for performance rather than flexibility. **Constraints:** * Only supported on AMD GPUs. * Destination tensor must be in GLOBAL address space. * Source tensor must be in LOCAL address space. * Data types must match between source and destination tensors. 
**Parameters:** * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the destination tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `dst_thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in global memory (DRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in register memory (LOCAL address space) to be copied. * ​dst\_base ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The original global memory tensor from which dst is derived. This is used to construct the buffer descriptor required by AMD's `buffer_store` intrinsic.
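A minimal sketch of the first (non-AMD-specific) overload, mirroring the earlier register-load example; `dst_gmem`, the fragment shape, and the thread layout are illustrative assumptions:

```mojo
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_local_to_dram

fn kernel():
    ...
    # Register fragment holding this thread's computed results.
    var results = LayoutTensor[
        DType.float32,
        Layout.row_major(4, 1),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation()

    # ... compute into results ...

    # dst_gmem is assumed to be a 16x8 LayoutTensor in GLOBAL memory;
    # the 4x8 thread layout mirrors the one used to load the fragments.
    copy_local_to_dram[Layout.row_major(4, 8)](dst_gmem, results)
```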
--- ## copy_local_to_local
`copy_local_to_local(dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data between local memory (register) tensors with type conversion. This function performs a synchronous copy operation between register tensors in a GPU context, with support for converting from float32 to half-precision formats (bfloat16/float16). It's particularly optimized for specific tensor layouts commonly used in matrix multiplication operations. Example: ```mojo from layout import LayoutTensor, Layout from layout.layout_tensor import copy_local_to_local fn kernel(): ... var src_reg = LayoutTensor[DType.float32, Layout.row_major(16, 8), MutAnyOrigin, address_space = AddressSpace.LOCAL, ].stack_allocation().fill(1) var dst_reg = LayoutTensor[DType.bfloat16, Layout.row_major(16, 8), MutAnyOrigin, address_space = AddressSpace.LOCAL, ].stack_allocation() # Process data in float32 registers # ... # Convert and copy to bfloat16 registers copy_local_to_local(dst_reg, src_reg) ``` Performance: * Optimized for specific 2D tensor layouts with contiguous inner dimensions. * Special fast path for 2D tensors with specific layouts used in matrix multiplication. * For MMA (Matrix Multiply-Accumulate) operations, efficiently handles the conversion between output fragments and input fragments with different layouts. * Falls back to element-wise copy for general cases. Notes: * Both source and destination tensors must be in `LOCAL` address space (registers). * This function currently only supports copying from float32 to half-precision formats. 
* For 2D tensors with stride\[1] == 1, a specialized fast path is used that's optimized for matrix multiplication patterns. * This function is particularly useful in GPU kernels for converting between different precision formats while keeping data in registers. **Constraints:** * Destination tensor must be in `LOCAL` address space. * Source tensor must be in `LOCAL` address space. * Destination tensor must have a half-precision floating-point data type. * Source tensor must have float32 data type. * Both tensors must have the same total size. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in local memory (registers) and have a half-precision floating-point data type (bfloat16 or float16). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in local memory (registers) and have float32 data type.
--- ## copy_local_to_shared
`copy_local_to_shared[thread_layout: Layout, swizzle: Optional[Swizzle] = None, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1, *, row_major: Bool = False](dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from local memory (registers) to SRAM (shared memory). This function performs a synchronous copy operation from register memory to shared memory in a GPU context, distributing the workload across multiple threads for parallel execution. It's particularly useful for transferring processed data from registers to shared memory for inter-thread communication. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. * Optimized for transferring data from registers to shared memory. * On AMD GPUs, the `row_major` parameter can be used to match the memory access pattern used during prefetching from DRAM to registers. Notes: * The destination tensor must be in `SHARED` address space (SRAM). * The source tensor must be in `LOCAL` address space (registers). * This function is particularly useful in GPU kernels for sharing processed data between threads in the same block. * The `row_major` parameter is specifically designed for AMD GPUs when using a prefetching pattern from DRAM to SRAM via registers. **Constraints:** * Destination tensor must be in SHARED address space. * Source tensor must be in LOCAL address space. 
* For optimal performance, the thread layout should match the memory access patterns of the tensors. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the operation. This determines how the workload is distributed among threads. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `thread_layout.size()` will be disabled and not participate in the copy operation. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. * ​row\_major ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to use row-major ordering for the copy operation. This is particularly relevant when prefetching from DRAM to SRAM via registers on AMD GPUs. Defaults to False. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in local memory (registers).
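A minimal sketch of the register-to-shared transfer, following the style of the other examples in this module. The tensor shapes and the 8x8 thread layout are illustrative assumptions:

```mojo
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_local_to_shared

fn kernel():
    var smem = LayoutTensor[
        DType.float32,
        Layout.row_major(32, 32),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()
    var reg = LayoutTensor[
        DType.float32,
        Layout.row_major(4, 4),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation().fill(1)

    # Each thread in the 8x8 arrangement writes its register fragment
    # into the shared tile for inter-thread communication.
    copy_local_to_shared[thread_layout = Layout.row_major(8, 8)](smem, reg)
```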
--- ## copy_sram_to_dram
`copy_sram_to_dram[thread_layout: Layout, swizzle: Optional[Swizzle] = None, num_threads: Int = thread_layout.size(), block_dim_count: Int = 1, binary_op: Optional[binary_op_type] = None](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from SRAM (shared memory) to DRAM (global memory). This function performs a synchronous memory transfer from SRAM (shared memory) to DRAM (global memory) using the specified thread layout for workload distribution. It supports optional swizzling for optimized memory access patterns and binary operations for combining data during the transfer. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Supports vectorized loads and stores for better memory throughput. * Can use swizzling to optimize memory access patterns. * Supports binary operations to combine data during transfer (e.g., for reduction operations). Notes: * The source tensor must be in `SHARED` address space (SRAM). * The destination tensor must be in `GENERIC` or `GLOBAL` address space (DRAM). * Supports FP32 to half-precision downcast during copy if needed. * Handles masked tensors with proper bounds checking. * This function is synchronous, meaning all threads must complete their copy operations before proceeding. **Constraints:** * Source tensor must be in SHARED address space with a static layout. * Destination tensor must be in GENERIC or GLOBAL address space. * For type conversion, only FP32 to half-precision is supported. 
* For vectorized copy with type conversion, both tensors must have element layouts matching the SIMD width of the destination type. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for both source and destination. This determines how the workload is distributed among threads. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzling function to rearrange the source indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total number of threads in the thread block. Threads beyond `thread_layout.size()` will be disabled and not participate in the copy operation. * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the thread block. * ​binary\_op ([`Optional`](/mojo/std/collections/optional/Optional)): Optional binary operation to apply during the copy, combining source data with existing destination data. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in global or generic memory (DRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in shared memory (SRAM).
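A minimal sketch of the shared-to-global write-back, assuming illustrative shapes and a 256-thread (8x32) distribution; a real kernel would pick layouts that match its tile dimensions:

```mojo
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_sram_to_dram

fn kernel(
    dst: LayoutTensor[DType.float32, Layout.row_major(32, 32), MutAnyOrigin],
):
    var smem = LayoutTensor[
        DType.float32,
        Layout.row_major(32, 32),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # ... threads cooperatively fill smem ...

    # Each of the 8x32 threads copies its slice of the shared tile to DRAM.
    copy_sram_to_dram[thread_layout = Layout.row_major(8, 32)](dst, smem)
```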
--- ## copy_sram_to_local
`copy_sram_to_local[src_warp_layout: Layout, axis: Optional[Int] = None](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from SRAM (shared memory) to local memory. This function performs a synchronous memory transfer from SRAM (shared memory) to local memory (registers) using the specified thread layout for workload distribution. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Optimized for transferring data from shared memory to registers. * Supports optional axis-specific distribution for specialized access patterns. **Constraints:** * The source tensor must be in SHARED address space (SRAM). * The destination tensor must be in LOCAL address space (registers). * Both tensors must have the same data type. **Parameters:** * ​src\_warp\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the source tensor. This determines how the workload is distributed among threads. * ​axis ([`Optional`](/mojo/std/collections/optional/Optional)): Optional parameter specifying which axis to distribute along. When provided, distribution happens along the specified axis. When None (default), distribution uses the standard layout pattern. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in local memory (registers). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in shared memory (SRAM).
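A minimal sketch of pulling a shared-memory tile into per-thread registers. The shapes and the 32x1 warp layout are illustrative assumptions:

```mojo
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_sram_to_local

fn kernel():
    var smem = LayoutTensor[
        DType.float32,
        Layout.row_major(32, 4),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()
    var reg = LayoutTensor[
        DType.float32,
        Layout.row_major(1, 4),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation()

    # 32 threads each load one row of the shared tile into registers.
    copy_sram_to_local[src_warp_layout = Layout.row_major(32, 1)](reg, smem)
```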
--- ## cp_async_k_major
`cp_async_k_major[dtype: DType, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Asynchronously copy data from DRAM to SRAM using TMA (Tensor Memory Accelerator) with K-major layout. This function performs an asynchronous copy operation from global memory (DRAM) to shared memory (SRAM) using NVIDIA's Tensor Memory Accelerator (TMA) hardware. It optimizes for K-major memory access patterns, which is particularly beneficial for certain tensor operations like matrix multiplications where the inner dimension (K) is accessed contiguously. The function automatically determines the optimal tile size and thread distribution based on the tensor shapes and hardware capabilities, leveraging TMA's efficient memory transfer mechanisms. Performance: * Uses TMA hardware acceleration for optimal memory transfer performance. * Optimizes for K-major access patterns, which can significantly improve performance for certain tensor operations like matrix multiplications. * Performs asynchronous transfers, allowing computation to overlap with memory operations. * Automatically determines optimal tile sizes based on tensor dimensions. * Uses hardware-accelerated swizzling to reduce shared memory bank conflicts. Notes: * This function requires NVIDIA GPUs with TMA support (compute capability 9.0+). * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. 
* This function is asynchronous, so you must call [`async_copy_wait_all()`](/mojo/std/gpu/memory/memory/async_copy_wait_all/) or [`async_copy_wait_group()`](/mojo/std/gpu/memory/memory/async_copy_wait_group/) to ensure the copy has completed before using the data. * K-major layout is particularly beneficial for matrix multiplication operations where the inner dimension (K) is accessed contiguously. **Constraints:** * Requires NVIDIA GPUs with TMA support (compute capability 9.0+). * Source tensor must be in GENERIC or GLOBAL address space. * Destination tensor must be in SHARED address space. * Both tensors must have the same data type. * Source and destination tensors must be 2D. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): The cache eviction policy to use. Default is `CacheEviction.EVICT_NORMAL`. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM).
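A minimal sketch of the issue-then-wait pattern described above, assuming `gpu.memory` as the import path for `async_copy_wait_all` (per the links above) and illustrative tile shapes:

```mojo
from gpu.memory import async_copy_wait_all
from layout import Layout, LayoutTensor
from layout.layout_tensor import cp_async_k_major

fn kernel(
    src: LayoutTensor[
        DType.bfloat16, Layout.row_major(128, 64), MutAnyOrigin
    ],
):
    var smem = LayoutTensor[
        DType.bfloat16,
        Layout.row_major(128, 64),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # Issue the asynchronous TMA transfer, then wait for completion
    # before any thread reads the shared tile.
    cp_async_k_major[DType.bfloat16](smem, src)
    async_copy_wait_all()
```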
--- ## layout_tensor
Provides the `LayoutTensor` type for representing multidimensional data. ## `comptime` values ### `binary_op_type` `comptime binary_op_type = fn[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[dtype, width]` Type alias for binary operations on SIMD vectors. This type represents a function that takes two SIMD vectors of the same type and width and returns a SIMD vector of the same type and width. Args: dtype: The data type of the SIMD vector elements. width: The width of the SIMD vector. lhs: Left-hand side SIMD vector operand. rhs: Right-hand side SIMD vector operand. Returns: A SIMD vector containing the result of the binary operation. ### `OpaquePointer` `comptime OpaquePointer = LegacyUnsafePointer[NoneType]` Legacy OpaquePointer migration helper. ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` Legacy UnsafePointer migration helper. ## Structs * [​`LayoutTensor`](./LayoutTensor): A high-performance tensor with explicit memory layout and hardware-optimized access patterns. * [​`LayoutTensorIter`](./LayoutTensorIter): Iterator for traversing a memory buffer with a specific layout. * [​`ThreadScope`](./ThreadScope): Represents the scope of thread operations in GPU programming. ## Functions * [​`copy_dram_to_local`](./copy_dram_to_local): Efficiently copy data from global memory (DRAM) to registers for AMD GPUs. * [​`copy_dram_to_sram`](./copy_dram_to_sram): Synchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. * [​`copy_dram_to_sram_async`](./copy_dram_to_sram_async): Asynchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. * [​`copy_local_to_dram`](./copy_local_to_dram): Efficiently copy data from registers (LOCAL) to global memory (DRAM). * [​`copy_local_to_local`](./copy_local_to_local): Synchronously copy data between local memory (register) tensors with type conversion. 
* [​`copy_local_to_shared`](./copy_local_to_shared): Synchronously copy data from local memory (registers) to SRAM (shared memory). * [​`copy_sram_to_dram`](./copy_sram_to_dram): Synchronously copy data from SRAM (shared memory) to DRAM (global memory). * [​`copy_sram_to_local`](./copy_sram_to_local): Synchronously copy data from SRAM (shared memory) to local memory. * [​`cp_async_k_major`](./cp_async_k_major): Asynchronously copy data from DRAM to SRAM using TMA (Tensor Memory Accelerator) with K-major layout. * [​`stack_allocation_like`](./stack_allocation_like): Create a stack-allocated tensor with the same layout as an existing tensor.
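As an illustration of the `binary_op_type` alias used by `copy_sram_to_dram`, a function with the matching signature might look like this (the name `add_op` is hypothetical):

```mojo
# A binary operation matching `binary_op_type`: element-wise addition,
# suitable as the `binary_op` parameter of `copy_sram_to_dram` to
# accumulate into the destination during the transfer.
fn add_op[
    dtype: DType, width: Int
](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[dtype, width]:
    return lhs + rhs
```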
--- ## stack_allocation_like
`stack_allocation_like[layout: Layout, dtype: DType, *, address_space: AddressSpace, target_address_space: AddressSpace = AddressSpace.GENERIC](in_tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=target_address_space, masked=masked]` Create a stack-allocated tensor with the same layout as an existing tensor. This function creates a new tensor on the stack with the same layout, data type, and masking properties as the input tensor, but potentially with a different address space. This is useful for creating temporary tensors that match the structure of existing tensors. Example: ```mojo from layout import LayoutTensor, Layout from layout.layout_tensor import stack_allocation_like var global_tensor = LayoutTensor[ DType.float32, Layout([10, 10]), MutAnyOrigin, address_space=AddressSpace.GLOBAL ].stack_allocation() var shared_tensor = stack_allocation_like[ target_address_space=AddressSpace.SHARED ](global_tensor) ``` Performance: * Creates a tensor on the stack, which is typically faster to allocate and access than heap-allocated memory. * Stack allocations have automatic lifetime management, reducing memory management overhead. * Stack size is limited, so be cautious with large tensor allocations. Notes: * The new tensor will have the same layout, data type, and masking properties as the input tensor. * The address space can be changed, which is useful for moving data between different memory regions (e.g., from global to shared memory). * Stack allocations are automatically freed when they go out of scope. * The function uses the stack\_allocation method of the result tensor type. **Parameters:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the tensor to allocate. 
* ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The address space of the input tensor. * ​target\_address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): The address space for the new tensor. Defaults to GENERIC. **Args:** * ​in\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to match the layout of. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor allocated on the stack with the same layout as the input tensor.
--- ## math
Implements math methods that work on layout tensors. ## Functions * [​`max`](./max): Computes maximum reduction along specified axis. * [​`mean`](./mean): Computes the mean value of the elements in a buffer. * [​`outer_product_acc`](./outer_product_acc): Updates result tensor with the outer product of two vectors. * [​`sum`](./sum): Computes sum reduction along specified axis. * [​`variance`](./variance): Computes the variance value of the elements in a buffer.
--- ## max (Math)
`max[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], outp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes maximum reduction along specified axis. Reduces the input tensor by taking maximum elements along the specified axis and stores the result in the output tensor. **Constraints:** All tensors must have statically known shapes. `outp.rank` must equal `inp.rank - 1`. Non-reduction dimensions must match between `inp` and `outp`. Currently only supports rank-2 inputs. **Parameters:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis to take maximum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to reduce. * ​outp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor to store maximum results. `max[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, _reduce_res_row_major_shape(axis, layout), MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Computes maximum reduction along specified axis, returning a new tensor. Reduces the input tensor by taking maximum elements along the specified axis and returns a new tensor with the results. **Constraints:** All tensors must have statically known shapes. Result will have rank equal to `inp.rank` - 1. Non-reduction dimensions in the result match the input. Currently only supports rank-2 inputs. 
**Parameters:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis to take maximum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to reduce. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the maximum values along the specified axis. `max[dtype: DType, layout: Layout](x: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], y: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].MutableAnyType` Computes element-wise maximum of two tensors. Returns a new tensor containing the element-wise maximum between the input tensors. **Constraints:** Input tensors must have statically known shapes and matching layouts. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the input tensors. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the input tensors. **Args:** * ​x ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): First input tensor. * ​y ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Second input tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the element-wise maximum.
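A small sketch in the style of the `sum` example below, assuming the same `InlineArray` construction works here:

```mojo
from layout import LayoutTensor, Layout
from layout.math import max

data: InlineArray[Int32, 6] = [0, 1, 2, 3, 4, 5]
tensor = LayoutTensor[DType.int32, Layout.row_major(2, 3)](data)
# Column-wise maxima of [[0, 1, 2], [3, 4, 5]], i.e. [3, 4, 5].
print(max[0](tensor))
```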
--- ## mean
`mean(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> Scalar[dtype]` Computes the mean value of the elements in a buffer. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The buffer of elements for which the mean is computed. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The mean value of the elements in the given buffer. **Raises:** May raise on GPU targets when a device error occurs. `mean[reduce_axis: Int](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes the mean across reduce\_axis of a LayoutTensor. **Parameters:** * ​reduce\_axis ([`Int`](/mojo/std/builtin/int/Int)): The axis to reduce across. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. **Raises:** May raise on GPU targets when a device error occurs. `mean(src: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]) -> Scalar[dtype]` Computes the mean value of the elements in a buffer. **Args:** * ​src ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The buffer of elements for which the mean is computed. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The mean value of the elements in the given buffer. **Raises:** May raise on GPU targets when a device error occurs.
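A small sketch of the whole-buffer overload, following the construction used by the `sum` example in this module (shapes and values are illustrative):

```mojo
from layout import LayoutTensor, Layout
from layout.math import mean

data: InlineArray[Float32, 4] = [1, 2, 3, 4]
tensor = LayoutTensor[DType.float32, Layout.row_major(2, 2)](data)
# (1 + 2 + 3 + 4) / 4 = 2.5
print(mean(tensor))
```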
--- ## outer_product_acc
`outer_product_acc(res: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lhs: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], rhs: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Updates result tensor with the outer product of two vectors. Computes `res += outer(lhs, rhs)` where `lhs` and `rhs` are vectors and `res` is a matrix. **Constraints:** All tensors must have statically known shapes. `res` must be rank 2. `lhs` and `rhs` must be rank 1. `res.shape[0] == lhs.shape[0]` and `res.shape[1] == rhs.shape[0]`. **Args:** * ​res ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The result matrix to accumulate into, shape (M, N). * ​lhs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The left-hand side vector, shape (M,). * ​rhs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The right-hand side vector, shape (N,).
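A small sketch of the accumulation, assuming `Layout.row_major` also accepts a single dimension for the rank-1 vectors (an assumption; the rank-1 construction is not shown in this reference):

```mojo
from layout import LayoutTensor, Layout
from layout.math import outer_product_acc

res_data: InlineArray[Float32, 6] = [0, 0, 0, 0, 0, 0]
lhs_data: InlineArray[Float32, 2] = [1, 2]
rhs_data: InlineArray[Float32, 3] = [3, 4, 5]
res = LayoutTensor[DType.float32, Layout.row_major(2, 3)](res_data)
lhs = LayoutTensor[DType.float32, Layout.row_major(2)](lhs_data)
rhs = LayoutTensor[DType.float32, Layout.row_major(3)](rhs_data)

outer_product_acc(res, lhs, rhs)
# res += outer(lhs, rhs):
#   row 0: 1 * [3, 4, 5] = [3, 4, 5]
#   row 1: 2 * [3, 4, 5] = [6, 8, 10]
```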
--- ## sum (Math)
`sum[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], outp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes sum reduction along specified axis. Reduces the input tensor by summing elements along the specified axis and stores the result in the output tensor. Example: ```mojo from layout import LayoutTensor, Layout from layout.math import sum data: InlineArray[Int32, 6] = [0, 1, 2, 3, 4, 5] tensor = LayoutTensor[DType.int32, Layout.row_major(2, 3)](data) print(tensor) print("-----") print(sum[0](tensor)) ``` Output: ```plaintext 0 1 2 3 4 5 ----- 3 5 7 ``` **Constraints:** All tensors must have statically known shapes. `outp.rank` must equal `inp.rank - 1`. Non-reduction dimensions must match between inp and outp. Currently only supports rank-2 inputs. **Parameters:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis to sum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to sum. * ​outp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor to store sum results. `sum[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, _reduce_res_row_major_shape(axis, layout), MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Computes sum reduction along specified axis, returning a new tensor. 
Reduces the input tensor by summing elements along the specified axis and returns a new tensor with the results. **Constraints:** All tensors must have statically known shapes. Result will have rank equal to `inp.rank` - 1. Non-reduction dimensions in the result match the input. Currently only supports rank-2 inputs. **Parameters:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis to sum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to sum. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the sum values along the specified axis.
--- ## variance
`variance(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], correction: Int = 1) -> Scalar[dtype]` Computes the variance value of the elements in a buffer. ``` variance(x) = sum((x - E(x))^2) / (size - correction) ``` **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The buffer. * ​correction ([`Int`](/mojo/std/builtin/int/Int)): Normalize variance by size - correction (Default=1). **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The variance value of the elements in a buffer. **Raises:** May raise on GPU targets when a device error occurs. `variance(src: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], correction: Int = 1) -> Scalar[dtype]` Computes the variance value of the elements in a buffer. ``` variance(x) = sum((x - E(x))^2) / (size - correction) ``` **Args:** * ​src ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The buffer. * ​correction ([`Int`](/mojo/std/builtin/int/Int)): Normalize variance by size - correction (Default=1). **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The variance value of the elements in a buffer. **Raises:** May raise on GPU targets when a device error occurs.
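A small worked sketch of the formula above (the rank-1 `Layout.row_major(4)` construction is an assumption):

```mojo
from layout import LayoutTensor, Layout
from layout.math import variance

data: InlineArray[Float32, 4] = [1, 2, 3, 4]
tensor = LayoutTensor[DType.float32, Layout.row_major(4)](data)
# E(x) = 2.5, sum((x - E(x))^2) = 2.25 + 0.25 + 0.25 + 2.25 = 5
print(variance(tensor))                # 5 / (4 - 1), sample variance
print(variance(tensor, correction=0))  # 5 / 4, population variance
```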
--- ## RuntimeLayout
`@register_passable(trivial)` `struct RuntimeLayout[layout: Layout, /, *, element_type: DType = DType.int64, linear_idx_type: DType = DType.int64]` A runtime-configurable layout that uses `RuntimeTuple` for storage. This struct provides a layout implementation that can be modified at runtime, unlike the static [`Layout`](/mojo/kernels/layout/layout/Layout) type. It uses [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple) for shape and stride storage. The layout must have statically known dimensions at compile time, but the actual shape and stride values can be modified during execution. ## Parameters * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static `Layout` type to base this runtime layout on. * ​element\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of each dimension element. Must be signed. * ​linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of the linear index into memory returned by `crd2idx`. Must be signed. ## Fields * ​shape (`RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type].ShapeType`): The shape of the layout as a runtime tuple. Stores the size of each dimension using the specified `element_type`. Must match the static layout's shape dimensions. * ​stride (`RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type].StrideType`): The stride of the layout as a runtime tuple. Stores the stride (step size) for each dimension using `linear_idx_type`, since strides can be large values. Must match the static layout's stride dimensions. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ShapeType` `comptime ShapeType = RuntimeTuple[layout.shape, element_type=element_type]` Type alias for the runtime shape tuple. ### `StrideType` `comptime StrideType = RuntimeTuple[layout.stride, element_type=linear_idx_type]` Type alias for the runtime stride tuple. ## Methods ### `__init__` `__init__() -> Self` Initialize a `RuntimeLayout` with default values. Creates a new `RuntimeLayout` instance with default shape and stride values. Requires that the static layout has known dimensions at compile time. **Constraints:** The static layout that this runtime layout is based on must have all dimensions known. `__init__(shape: RuntimeTuple[layout.shape, element_type=element_type], stride: RuntimeTuple[layout.stride, element_type=linear_idx_type]) -> Self` Initialize a `RuntimeLayout` with specified shape and stride. **Args:** * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): A `RuntimeTuple` containing the dimensions of each axis. * ​stride ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): A `RuntimeTuple` containing the stride values for each axis. 
### `__call__` `__call__(self, idx: Int) -> Scalar[linear_idx_type]` Convert a single index to a flat linear index. **Args:** * ​idx ([`Int`](/mojo/std/builtin/int/Int)): The one-dimensional index to convert. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The corresponding flat linear index in the layout. `__call__[t: IntTuple](self, idx: RuntimeTuple[t, element_type=element_type]) -> Scalar[linear_idx_type]` Convert a multi-dimensional index to a flat linear index. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` type for the index. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): A `RuntimeTuple` containing the multi-dimensional coordinates. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The corresponding flat linear index in the layout. ### `idx2crd` `idx2crd[t: IntTuple](self, idx: RuntimeTuple[t, element_type=element_type]) -> RuntimeTuple[idx2crd(t, layout.shape, layout.stride), element_type=element_type]` Converts a linear index to logical coordinates. This is the inverse operation of the `__call__` method, mapping from a memory index back to the corresponding logical coordinates. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` type for the index. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The linear index to convert. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): The logical coordinates corresponding to the given index. ### `size` `size(self) -> Int` Calculate the total number of elements in the layout. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The product of all dimensions in the shape, representing the total number of elements that can be addressed by this layout. ### `bound_check_required` `bound_check_required(self) -> Bool` Determine if bounds checking is required for this layout. 
**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if any dimension in the shape differs from the static layout's shape, False otherwise. ### `cast` `cast[_element_type: DType, /, *, target_linear_idx_type: DType = linear_idx_type](self) -> RuntimeLayout[layout, element_type=_element_type, linear_idx_type=target_linear_idx_type]` Cast the layout to use a different element bitwidth. **Parameters:** * ​\_element\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The target data type. * ​target\_linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The target linear idx type. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A new `RuntimeLayout` with the shape cast to the specified type. ### `__str__` `__str__(self) -> String` Convert the layout to a string representation. **Returns:** `String`: A string representation of the layout. ### `row_major` `static row_major[rank: Int, //](shape: IndexList[rank, element_type=element_type]) -> Self` Create a row-major layout from the given shape. In row-major layout, elements with adjacent rightmost indices are adjacent in memory. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the layout. **Args:** * ​shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): An `IndexList` containing the dimensions of each axis. **Returns:** `Self`: A `RuntimeLayout` with row-major stride ordering. ### `col_major` `static col_major[rank: Int, //](shape: IndexList[rank, element_type=element_type]) -> Self` Create a column-major layout from the given shape. In column-major layout, elements with adjacent leftmost indices are adjacent in memory. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions in the layout. **Args:** * ​shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): An `IndexList` containing the dimensions of each axis. **Returns:** `Self`: A `RuntimeLayout` with column-major stride ordering. 
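The two stride orderings can be sketched in plain Python (the function names are illustrative, not the Mojo API): in row-major order the rightmost dimension is contiguous, so each stride is the product of all dimensions to its right; in column-major order the leftmost dimension is contiguous.

```python
def row_major_strides(shape):
    # Rightmost dimension is contiguous: stride[i] = product(shape[i+1:]).
    strides, acc = [0] * len(shape), 1
    for i in reversed(range(len(shape))):
        strides[i] = acc
        acc *= shape[i]
    return strides

def col_major_strides(shape):
    # Leftmost dimension is contiguous: stride[i] = product(shape[:i]).
    strides, acc = [0] * len(shape), 1
    for i in range(len(shape)):
        strides[i] = acc
        acc *= shape[i]
    return strides

print(row_major_strides([2, 3, 4]))  # [12, 4, 1]
print(col_major_strides([2, 3, 4]))  # [1, 2, 6]
```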
### `write_to` `write_to(self, mut writer: T)` Write a string representation of the layout to a writer. **Args:** * ​writer (`T`): The `Writer` object to write the layout representation to. ### `sublayout` `sublayout[i: Int](self) -> RuntimeLayout[layout[i], element_type=element_type, linear_idx_type=linear_idx_type]` Extract a nested sublayout at the specified index. **Parameters:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the nested layout to extract. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A `RuntimeLayout` representing the nested layout at index i. ### `dim` `dim(self, i: Int) -> Int` Get the size of the dimension at the specified index. **Args:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the dimension to retrieve. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the dimension at index `i`. ### `__len__` `static __len__() -> Int` Get the number of dimensions in the layout. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of dimensions (rank) of the layout.
--- ## coalesce (Runtime_layout)
`coalesce[l: Layout, keep_rank: Bool = False](layout: RuntimeLayout[l, element_type=element_type, linear_idx_type=linear_idx_type]) -> RuntimeLayout[coalesce(l, keep_rank), element_type=element_type, linear_idx_type=linear_idx_type]` Coalesce adjacent dimensions in a runtime layout when possible. This optimizes the layout by merging adjacent dimensions when their relationship allows it, potentially reducing the number of dimensions. **Parameters:** * ​l ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static layout type to coalesce. * ​keep\_rank ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to maintain the original rank (currently unsupported). **Args:** * ​layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The input `RuntimeLayout` to coalesce. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A new `RuntimeLayout` with coalesced dimensions.
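The merging condition can be sketched in Python as follows. This assumes the usual CuTe-style rule (a sketch of the general idea, not a transcription of the Mojo kernel): a mode folds into its predecessor when its stride equals the predecessor's shape times stride, i.e. the two modes address one contiguous run.

```python
def coalesce(shape, stride):
    """Merge adjacent (shape, stride) modes when they are contiguous."""
    out_shape, out_stride = [shape[0]], [stride[0]]
    for s, d in zip(shape[1:], stride[1:]):
        if s == 1:
            continue                      # size-1 modes carry no information
        if out_shape[-1] == 1:
            out_shape[-1], out_stride[-1] = s, d
        elif d == out_shape[-1] * out_stride[-1]:
            out_shape[-1] *= s            # contiguous: fold into previous mode
        else:
            out_shape.append(s)
            out_stride.append(d)
    return out_shape, out_stride

print(coalesce([2, 4], [1, 2]))  # ([8], [1]) -- fully contiguous
print(coalesce([2, 4], [1, 4]))  # ([2, 4], [1, 4]) -- gap, cannot merge
```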
--- ## runtime_layout
Provides the `RuntimeLayout` type and functions for working with it. You can use `RuntimeLayout` to define a layout where the dimensions are not known at compile time. You can import these APIs from `layout.runtime_layout`. ```mojo from layout.runtime_layout import RuntimeLayout, make_layout ``` ## Structs * [​`RuntimeLayout`](./RuntimeLayout): A runtime-configurable layout that uses `RuntimeTuple` for storage. ## Functions * [​`coalesce`](./coalesce): Coalesce adjacent dimensions in a runtime layout when possible. * [​`make_layout`](./make_layout): Combine two runtime layouts into a single composite layout.
--- ## make_layout (Runtime_layout)
`make_layout[l1: Layout, l2: Layout, /, *, linear_idx_type: DType = DType.uint64](a: RuntimeLayout[l1, element_type=element_type, linear_idx_type=linear_idx_type], b: RuntimeLayout[l2, element_type=element_type, linear_idx_type=linear_idx_type]) -> RuntimeLayout[make_layout(l1, l2), element_type=element_type, linear_idx_type=linear_idx_type]` Combine two runtime layouts into a single composite layout. This creates a new layout by concatenating the dimensions and strides of the input layouts. **Parameters:** * ​l1 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static layout type of `a`. * ​l2 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static layout type of `b`. * ​linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer type of all indices calculated by the returned runtime layout. **Args:** * ​a ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The first `RuntimeLayout` to combine. * ​b ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The second `RuntimeLayout` to combine. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A new `RuntimeLayout` with dimensions from both input layouts.
--- ## RuntimeTuple
`@register_passable(trivial)` `struct RuntimeTuple[S: IntTuple = -1, /, *, element_type: DType = DType.int64]` A struct representing tuple-like data with compile-time and runtime elements. RuntimeTuple combines static (compile-time) and dynamic (runtime) handling of tuple-like data structures, typically used for tensor shapes, indices, and coordinates in high-performance computing contexts. This struct is optimized for parallel execution and hardware acceleration, allowing efficient manipulation of multi-dimensional data. It supports both known compile-time values and runtime-determined values. ## Parameters * ​S ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): `IntTuple` with compile-time known values (or `UNKNOWN_VALUE` for runtime values). * ​element\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Integer type of the underlying elements. ## Fields * ​value (`IndexList[RuntimeTuple[S, element_type=element_type].scalar_length, element_type=element_type]`): Storage for the actual tuple values, implemented as an IndexList with the appropriate size and element type. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Intable`](/mojo/std/builtin/int/Intable), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Sized`](/mojo/std/builtin/len/Sized), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `scalar_length` `comptime scalar_length = len[IntTuple](flatten(S))` The total number of scalar elements in this RuntimeTuple after flattening nested tuples. ## Methods ### `__init__` `__init__() -> Self` Initialize a `RuntimeTuple` with default values. For dimensions with known compile-time values in S, uses those values. For unknown dimensions, initializes them to UNKNOWN\_VALUE. `__init__(*values: Int) -> Self` Initialize a `RuntimeTuple` with the provided values. **Args:** * ​\*values ([`Int`](/mojo/std/builtin/int/Int)): Variadic number of integer values to initialize the tuple with. `@implicit` `__init__[l: Int](values: IndexList[l, element_type=element_type]) -> Self` Initialize a `RuntimeTuple` from an `IndexList`. **Parameters:** * ​l ([`Int`](/mojo/std/builtin/int/Int)): Compile-time length of the input `IndexList`. **Args:** * ​values ([`IndexList`](/mojo/std/utils/index_/IndexList)): `IndexList` to initialize from. Must have same length as the `RuntimeTuple`. The values will be cast to the appropriate element type if needed. 
### `__getitem__` `__getitem__[i: Int](self) -> RuntimeTuple[S[i], element_type=element_type]` Retrieves the element at the specified index in the tuple. This method provides array-like indexing for RuntimeTuple, allowing access to individual elements or sub-tuples. It handles the internal offset calculation to access the correct elements in the flattened storage array. **Parameters:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to retrieve. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the element or sub-tuple at the specified index. ### `__setitem__` `__setitem__[i: Int](mut self, val: Scalar[element_type])` Sets the value of the element at the specified index in the tuple. This method enables array-like assignment for RuntimeTuple elements, handling the internal offset calculation to modify the correct element in the flattened storage array. **Parameters:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The index of the element to modify. **Args:** * ​val ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The new value to assign to the element. ### `offset_until` `static offset_until[i: Int]() -> Int` Calculates the offset in the flattened value array for a given tuple index. This method computes the sum of lengths of all flattened subtuple elements that come before the specified index, which is used for indexing into the internal storage. **Parameters:** * ​i ([`Int`](/mojo/std/builtin/int/Int)): The tuple index to calculate the offset for. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The offset in the flattened array where the i-th element begins. ### `get_int` `get_int(self) -> Scalar[element_type]` Returns the integer value of this RuntimeTuple. For tuples with a known compile-time value, returns that value. For tuples with a runtime value, returns the first element of the internal storage array. 
**Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The integer value of this RuntimeTuple. ### `__str__` `__str__(self) -> String` Converts the RuntimeTuple to its string representation. This method provides a human-readable string representation of the tuple, which is useful for debugging and logging. **Returns:** `String`: A string representation of the `RuntimeTuple`. ### `concat` `concat[R: IntTuple](self, rhs: RuntimeTuple[R, element_type=element_type]) -> RuntimeTuple[concat(S, R), element_type=element_type]` Concatenates two `RuntimeTuple`s together. This method combines the current `RuntimeTuple` with another one, preserving both compile-time and runtime values. It handles the complexity of merging the underlying storage arrays while maintaining the proper semantic structure. **Parameters:** * ​R ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` type parameter of the right-hand side RuntimeTuple. **Args:** * ​rhs ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The `RuntimeTuple` to concatenate to the end of this one. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing all elements from both tuples in sequence. ### `flatten` `flatten(self) -> RuntimeTuple[flatten(S), element_type=element_type]` Flattens a potentially nested `RuntimeTuple` into a single-level tuple. This method converts a hierarchical structure of tuples into a flat representation, preserving all values but removing the nested structure. This is useful for operations that need to treat all elements uniformly. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing all elements in a flat (non-nested) structure. ### `write_to` `write_to(self, mut writer: T)` Writes the RuntimeTuple to a Writer object. This method is used by the string conversion system to generate a string representation of the RuntimeTuple. 
It handles both scalar values and nested tuple structures, producing a properly formatted output. **Args:** * ​writer (`T`): The Writer object to write the string representation to. ### `__len__` `__len__(self) -> Int` Returns the length (number of top-level elements) of the `RuntimeTuple`. This method provides the standard Python-like len() functionality, giving the number of elements at the top level of the tuple structure. For nested tuples, this returns the number of first-level entries, not the total number of scalar values. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of top-level elements in the tuple. ### `cast` `cast[dtype: DType](self) -> RuntimeTuple[S, element_type=dtype]` Casts the RuntimeTuple to use a different numeric type. This method creates a new RuntimeTuple with the same structure and values but using a different underlying numeric type for storage. This is useful for changing precision or signedness of the data. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The target DType to cast the elements to. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` with elements cast to the specified type. ### `__int__` `__int__(self) -> Int` Converts the RuntimeTuple to an integer value. This method enables implicit conversion of a RuntimeTuple to an integer, but is constrained to only work on scalar tuples (those that contain a single value). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The integer value of the tuple.
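The `flatten` semantics described above can be sketched for plain nested Python tuples (illustrative only; `RuntimeTuple` additionally preserves compile-time knowledge about each element):

```python
def flatten(t):
    """Recursively flatten a nested tuple into a flat tuple of scalars."""
    if not isinstance(t, tuple):
        return (t,)
    out = ()
    for e in t:
        out += flatten(e)
    return out

print(flatten((2, (3, 4), ((5,), 6))))  # (2, 3, 4, 5, 6)
```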
--- ## coalesce_nested_tuple
`coalesce_nested_tuple[t: IntTuple, out_t: IntTuple = _int_tuple_product_flatten[t]()](tuple: RuntimeTuple[t, element_type=element_type]) -> RuntimeTuple[out_t]` Coalesces a nested `RuntimeTuple` into a single-level `RuntimeTuple`, by multiplying all the values together. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The underlying compile-time `IntTuple` backing the `RuntimeTuple`. * ​out\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The flattened compile-time `IntTuple`. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The RuntimeTuple to convert. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the products of each top-level tuple, in a flat structure.
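The effect on plain nested Python tuples can be sketched as follows (illustrative only): each top-level element is replaced by the product of its flattened values.

```python
def coalesce_nested_tuple(t):
    """Replace each top-level element with the product of its flattened values."""
    def product(e):
        if isinstance(e, tuple):
            out = 1
            for x in e:
                out *= product(x)
            return out
        return e
    return tuple(product(e) for e in t)

print(coalesce_nested_tuple((2, (3, 4), 5)))  # (2, 12, 5)
```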
--- ## concat
`concat(var lhs: IntTuple, rhs: IntTuple) -> IntTuple` Concatenates two `IntTuple` instances into a single `IntTuple`. This function appends all elements from the right-hand side tuple to the left-hand side tuple, creating a new combined tuple. The operation preserves the hierarchical structure of both tuples. **Args:** * ​lhs ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The left-hand side `IntTuple` that will be modified (var parameter). * ​rhs ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The right-hand side `IntTuple` whose elements will be appended. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing all elements from both tuples in sequence.
--- ## crd2idx (Runtime_tuple)
`crd2idx[crd_t: IntTuple, shape_t: IntTuple, stride_t: IntTuple, out_type: DType = DType.uint64](crd: RuntimeTuple[crd_t, element_type=element_type], shape: RuntimeTuple[shape_t, element_type=element_type], stride: RuntimeTuple[stride_t, element_type=element_type]) -> Scalar[out_type]` Converts multi-dimensional coordinates to a linear index. This function is the inverse of idx2crd, transforming a set of coordinates into a flat index based on the provided shape and stride information. This is essential for mapping multi-dimensional tensor elements to linear memory. **Parameters:** * ​crd\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the coordinates. * ​shape\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the shape. * ​stride\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the stride. * ​out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The output data type for the index (default: uint64). **Args:** * ​crd ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The coordinates to convert. * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The shape of the multi-dimensional array. * ​stride ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The stride values for each dimension. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): A scalar value representing the linear index corresponding to the given coordinates.
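For flat (non-nested) coordinates, the mapping is a simple dot product of coordinates and strides; a minimal Python sketch (illustrative, not the Mojo implementation, which also handles nested tuples):

```python
def crd2idx(crd, shape, stride):
    """Map multi-dimensional coordinates to a flat linear index:
    idx = sum(crd[i] * stride[i]), with each crd[i] < shape[i]."""
    assert all(0 <= c < s for c, s in zip(crd, shape))
    return sum(c * d for c, d in zip(crd, stride))

# Row-major 3x4 layout has strides (4, 1).
print(crd2idx((1, 2), (3, 4), (4, 1)))  # 6
```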
--- ## idx2crd (Runtime_tuple)
`idx2crd[idx_t: IntTuple, shape_t: IntTuple, stride_t: IntTuple](idx: RuntimeTuple[idx_t, element_type=element_type], shape: RuntimeTuple[shape_t, element_type=element_type], stride: RuntimeTuple[stride_t, element_type=element_type]) -> RuntimeTuple[idx2crd(idx_t, shape_t, stride_t), element_type=element_type]` Converts a linear index to multi-dimensional coordinates. This function transforms a flat index into coordinate values based on the provided shape and stride information. This is essential for mapping linear memory accesses to multi-dimensional tensor elements. **Constraints:** The index must be a scalar value (not a tuple). **Parameters:** * ​idx\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the index. * ​shape\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the shape. * ​stride\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the stride. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The linear index to convert. * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The shape of the multi-dimensional array. * ​stride ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The stride values for each dimension. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A `RuntimeTuple` containing the multi-dimensional coordinates. `idx2crd[idx_t: IntTuple, shape_t: IntTuple](idx: RuntimeTuple[idx_t, element_type=element_type], shape: RuntimeTuple[shape_t, element_type=element_type]) -> RuntimeTuple[idx2crd(idx_t, shape_t, prefix_product(shape_t)), element_type=element_type]` Converts a linear index to multi-dimensional coordinates using shape-derived strides. This is a convenience overload of `idx2crd` that automatically calculates the stride values from the shape using `prefix_product`. This is the common case for row-major storage order tensors. 
**Parameters:** * ​idx\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the index. * ​shape\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the shape. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The linear index to convert. * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The shape of the multi-dimensional array. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A `RuntimeTuple` containing the multi-dimensional coordinates calculated using automatically derived strides from the shape.
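For flat shapes and strides, the inverse mapping follows the familiar divide-and-modulo pattern; a minimal Python sketch (illustrative only, assuming a non-overlapping layout):

```python
def idx2crd(idx, shape, stride):
    """Inverse of crd2idx for a non-overlapping layout:
    crd[i] = (idx // stride[i]) % shape[i]."""
    return tuple((idx // d) % s for s, d in zip(shape, stride))

# Row-major 3x4 layout has strides (4, 1): index 6 maps back to (1, 2).
print(idx2crd(6, (3, 4), (4, 1)))  # (1, 2)
```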
--- ## runtime_tuple
Provides the `RuntimeTuple` data structure and related utility functions for handling tuple-like data with both compile-time and runtime elements. `RuntimeTuple` is designed for high-performance tensor operations, supporting efficient manipulation of multi-dimensional data structures like shapes, indices, and coordinates. Key features: * Hybrid compile-time/runtime value handling * Optimized for parallel execution and hardware acceleration * Support for nested tuple structures * Efficient conversion between linear indices and multi-dimensional coordinates * Specialized operations for tensor shape calculations The module includes functions for tuple manipulation (concatenation, flattening), coordinate transformations (`idx2crd`, `crd2idx`), and specialized tensor operations like shape division and prefix products. ## Structs * [​`RuntimeTuple`](./RuntimeTuple): A struct representing tuple-like data with compile-time and runtime elements. RuntimeTuple combines static (compile-time) and dynamic (runtime) handling of tuple-like data structures, typically used for tensor shapes, indices, and coordinates in high-performance computing contexts. This struct is optimized for parallel execution and hardware acceleration, allowing efficient manipulation of multi-dimensional data. It supports both known compile-time values and runtime-determined values. ## Functions * [​`coalesce_nested_tuple`](./coalesce_nested_tuple): Coalesces a nested `RuntimeTuple` into a single-level `RuntimeTuple`, by multiplying all the values together. * [​`concat`](./concat): Concatenates two `IntTuple` instances into a single `IntTuple`. * [​`crd2idx`](./crd2idx): Converts multi-dimensional coordinates to a linear index. * [​`idx2crd`](./idx2crd): Converts a linear index to multi-dimensional coordinates. This function transforms a flat index into coordinate values based on the provided shape and stride information. 
This is essential for mapping linear memory accesses to multi-dimensional tensor elements. * [​`is_int`](./is_int): Determines if a `RuntimeTuple` represents a scalar integer value. * [​`is_tuple`](./is_tuple): Determines if a `RuntimeTuple` represents a tuple rather than a scalar value. * [​`prefix_product`](./prefix_product): Computes the prefix products of elements in the `RuntimeTuple`. * [​`product`](./product): Computes the product of all elements in the `RuntimeTuple`. * [​`shape_div`](./shape_div): Performs specialized shape division between `RuntimeTuple`s. * [​`signum`](./signum): Returns the sign of an integer value. * [​`to_index_list`](./to_index_list): Converts a RuntimeTuple to an IndexList with the same values.
--- ## is_int (Runtime_tuple)
`is_int[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> Bool` Determines if a `RuntimeTuple` represents a scalar integer value. This function checks if the `RuntimeTuple` holds a single scalar value rather than a tuple structure with multiple elements. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The `RuntimeTuple` to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `RuntimeTuple` represents a scalar integer, False otherwise.
--- ## is_tuple (Runtime_tuple)
`is_tuple[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> Bool` Determines if a `RuntimeTuple` represents a tuple rather than a scalar value. This function checks the structure of the underlying IntTuple to determine if it represents a tuple with multiple elements or a single scalar value. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The `RuntimeTuple` to check. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the `RuntimeTuple` represents a tuple, False if it represents a scalar.
--- ## prefix_product (Runtime_tuple)
`prefix_product[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> RuntimeTuple[prefix_product(t)]` Computes the prefix products of elements in the `RuntimeTuple`. This function calculates the running product of elements, where each output element is the product of all previous elements in the input. This is commonly used in tensor computations to calculate stride values. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the input RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The input `RuntimeTuple`. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the prefix products of the input elements.
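For a flat tuple, this is an exclusive running product; a minimal Python sketch (illustrative only, assuming the exclusive-scan convention where `out[i]` is the product of all elements before index `i`):

```python
def prefix_product(tup):
    """Exclusive running product: out[i] = product(tup[:i])."""
    out, acc = [], 1
    for s in tup:
        out.append(acc)
        acc *= s
    return out

# Applied to a shape, this derives strides where the first mode is unit-stride.
print(prefix_product([2, 3, 4]))  # [1, 2, 6]
```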
--- ## product (Runtime_tuple)
`product[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> Int` Computes the product of all elements in the `RuntimeTuple`. This function multiplies all scalar values in the tuple, including those in nested tuples after flattening. This is commonly used to calculate the total size of a tensor from its shape. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the input RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The input `RuntimeTuple`. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The product of all scalar elements in the tuple.
--- ## shape_div (Runtime_tuple)
`shape_div[a_t: IntTuple, b_t: IntTuple](a: RuntimeTuple[a_t, element_type=element_type], b: RuntimeTuple[b_t, element_type=element_type]) -> RuntimeTuple[shape_div(a_t, b_t)]` Performs specialized shape division between `RuntimeTuple`s. This function implements a special division operation specifically designed for tensor shape calculations. Unlike standard division, it handles special cases: 1. If shapes are directly divisible (a % b == 0), returns a standard division (a // b) 2. If shapes are inversely divisible (b % a == 0), returns the signed reciprocal 3. If shapes are incompatible, aborts with an error This operation is essential for transformations between tensor layouts and computing broadcasting semantics. **Parameters:** * ​a\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the first operand. * ​b\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the second operand. **Args:** * ​a ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The dividend `RuntimeTuple`. * ​b ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The divisor `RuntimeTuple`. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the result of the shape division.
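The three scalar cases can be sketched in Python (an illustration of the rule as described above, not the Mojo kernel). Note that for integers the "signed reciprocal" truncates to ±1, since `|a/b| < 1` whenever `b` strictly divides `a` the other way:

```python
def signum(a):
    return (a > 0) - (a < 0)

def shape_div(a, b):
    if b != 0 and a % b == 0:
        return a // b                  # case 1: directly divisible
    if a != 0 and b % a == 0:
        return signum(a) * signum(b)   # case 2: signed reciprocal, +/-1
    raise ValueError(f"shapes {a} and {b} are not divisible either way")

print(shape_div(8, 2))   # 4
print(shape_div(2, 8))   # 1
print(shape_div(-8, 2))  # -4
```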
--- ## signum (Runtime_tuple)
`signum(a: Int) -> Int` Returns the sign of an integer value. This helper function determines whether a number is positive, negative, or zero, returning 1 for positive, -1 for negative, and 0 for zero. **Args:** * ​a ([`Int`](/mojo/std/builtin/int/Int)): The integer value to determine the sign of. **Returns:** [`Int`](/mojo/std/builtin/int/Int): 1 if a > 0, -1 if a < 0, 0 if a == 0.
--- ## to_index_list (Runtime_tuple)
`to_index_list[rank: Int, t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> IndexList[rank]` Converts a RuntimeTuple to an IndexList with the same values. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the resulting IndexList. * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple template parameter of the RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The RuntimeTuple to convert. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): An IndexList filled with the values of the RuntimeTuple.
--- ## ComposedLayout
`struct ComposedLayout[LayoutA: LayoutTrait, LayoutB: LayoutTrait, offset: Optional[Int] = 0]` Layout composed of two layouts applied sequentially. Combines two layouts. Output of the first (`LayoutA`) is input to the second (`LayoutB`), with optional offset in between. ## Parameters * ​LayoutA ([`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait)): The first layout to apply. * ​LayoutB ([`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait)): The second layout to apply. * ​offset ([`Optional`](/mojo/std/collections/optional/Optional)): Optional offset between layouts (default: 0). ## Fields * ​layout\_a (`LayoutA`): The first layout to apply. * ​layout\_b (`LayoutB`): The second layout to apply. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = LayoutB.__del__is_trivial if LayoutA.__del__is_trivial else LayoutA.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = LayoutB.__moveinit__is_trivial if LayoutA.__moveinit__is_trivial else LayoutA.__moveinit__is_trivial` ### `has_shape` `comptime has_shape = LayoutA.has_shape if LayoutA.has_shape else LayoutB.has_shape` True if either layout has a shape. ## Methods ### `__init__` `__init__(out self, var layout_a: LayoutA, var layout_b: LayoutB)` Initialize ComposedLayout with two layouts. **Args:** * ​layout\_a (`LayoutA`): The first layout. * ​layout\_b (`LayoutB`): The second layout. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy constructor for ComposedLayout. **Args:** * ​other (`Self`): The ComposedLayout to copy from. 
### `__call__` `__call__(self, idx: IntTuple) -> Int` Apply composed layout to an index. Applies `LayoutA`, then adds offset, then applies `LayoutB`. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The index to transform. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The transformed index. `__call__(self, idx: IntTuple, offset_val: Int) -> Int` Apply composed layout with runtime offset. Applies `LayoutA`, then adds runtime `offset_val`, then `LayoutB`. Static offset must not be set when using runtime offset. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The index to transform. * ​offset\_val ([`Int`](/mojo/std/builtin/int/Int)): Runtime offset to apply. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The transformed index. ### `size` `size(self) -> Int` Get the size of the composed layout. Returns the size of the first layout (`LayoutA`). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the first layout. ### `cosize` `cosize(self) -> Int` Get the cosize of the composed layout. Returns the cosize of the second layout (`LayoutB`). **Returns:** [`Int`](/mojo/std/builtin/int/Int): The cosize of the second layout.
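The call order above can be illustrated with plain Python functions standing in for the two layouts (an illustrative sketch, not the Mojo API; `row_major` and `identity` are made-up stand-ins):

```python
def composed_call(layout_a, layout_b, idx, offset=0):
    # Apply LayoutA, add the (static or runtime) offset, then apply LayoutB.
    return layout_b(layout_a(idx) + offset)

# A row-major 4x8 "layout" feeding an identity second stage:
row_major = lambda rc: rc[0] * 8 + rc[1]
identity = lambda i: i

composed_call(row_major, identity, (2, 3))            # 2*8 + 3 = 19
composed_call(row_major, identity, (2, 3), offset=4)  # 19 + 4 = 23
```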
--- ## Swizzle
`@register_passable(trivial)` `struct Swizzle` Swizzle functor for memory access pattern optimization. Implements a swizzling pattern to reduce bank conflicts in shared memory accesses. It XORs specific bits of memory indices based on configurable parameters. Swizzle operation: Given index `i`, and Swizzle\[bits, base, shift]: 1. Extract `bits` number of bits from `i` starting from position `base + max(shift, 0)`. Let's call this `YYY`. 2. Extract `bits` number of bits from `i` starting from position `base - min(shift, 0)`. Let's call this `ZZZ`. 3. Result is `i ^ (YYY shifted by 'shift' positions)`. Example (Swizzle\[2, 0, 3]): Input index bits: `xxxxxxxxxxxYYxZZ` Output index bits: `xxxxxxxxxxxYYxAA` where `AA = ZZ ^ YY` (with base 0, `ZZ` sits in the lowest two bits, and `YY` sits three positions above it). Attributes: bits (Int): Number of bits in the mask (YYY). base (Int): Number of least significant bits to keep constant. shift (Int): Shift distance for the mask (positive: right, negative: left). yyy\_mask (Int): Mask for the bits to be shifted (YYY). zzz\_mask (Int): Mask for the target bits (ZZZ). ## Fields * ​bits (`Int`): Number of bits in the mask. * ​base (`Int`): Number of least significant bits to keep constant. * ​shift (`Int`): Distance to shift the mask (positive: right, negative: left). * ​yyy\_mask (`Int`): Mask for the bits to be shifted. * ​zzz\_mask (`Int`): Mask for the target bits.
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `has_shape` `comptime has_shape = False` Indicates if layout has shape. Swizzle always False. ## Methods ### `__init__` `__init__(bits: Int, base: Int, shift: Int) -> Self` Initialize a Swizzle object. Configures the swizzle operation based on bits, base, and shift parameters. **Args:** * ​bits ([`Int`](/mojo/std/builtin/int/Int)): Number of bits in the mask. * ​base ([`Int`](/mojo/std/builtin/int/Int)): Least significant bits to keep constant. * ​shift ([`Int`](/mojo/std/builtin/int/Int)): Distance to shift the mask. ### `__call__` `__call__(self, index: IntTuple) -> Int` Apply swizzle to an IntTuple index. Unwraps the IntTuple and applies the swizzle to the integer value. **Args:** * ​index ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple index to swizzle. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The swizzled index value. `__call__(self, offset: Int) -> Int` Apply swizzle to an integer offset. Performs the swizzle operation on an integer offset to rearrange memory access patterns. **Args:** * ​offset ([`Int`](/mojo/std/builtin/int/Int)): The integer offset to swizzle. 
**Returns:** [`Int`](/mojo/std/builtin/int/Int): The swizzled offset value. `__call__(self, offset: Scalar[dtype]) -> Scalar[dtype]` Apply swizzle to a scalar offset. Scalar version of the swizzle operation. Applies swizzle to a scalar offset. **Args:** * ​offset ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar offset to swizzle. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The swizzled scalar value. ### `size` `size(self) -> Int` Get the size of the swizzle pattern. Calculates the size of the memory region affected by the swizzle pattern. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The size of the swizzle pattern. ### `cosize` `cosize(self) -> Int` Get the cosize of the swizzle pattern. Cosize is the same as size for swizzle layouts, representing the output size. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The cosize of the swizzle pattern (same as size). ### `write_to` `write_to(self, mut writer: T)` Write the swizzle parameters to a writer. Outputs the swizzle parameters (bits, base, shift) in a tuple format. **Args:** * ​writer (`T`): The writer to write to. ### `__str__` `__str__(self) -> String` Convert the swizzle to a string representation. **Returns:** `String`: String representation of the swizzle parameters.
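The three-step recipe above fits in a few lines of Python (a hypothetical model of the index transform, not the Mojo implementation):

```python
def swizzle(i: int, bits: int, base: int, shift: int) -> int:
    """XOR the YYY bit field of index `i` into its ZZZ field."""
    ones = (1 << bits) - 1
    yyy_mask = ones << (base + max(shift, 0))   # source bits (YYY)
    yyy = i & yyy_mask
    # Positive shift moves YYY right onto ZZZ; negative moves it left.
    moved = yyy >> shift if shift >= 0 else yyy << -shift
    return i ^ moved

# Swizzle[2, 0, 3]: YY at bits 3-4 is XORed into ZZ at bits 0-1.
swizzle(0b11000, bits=2, base=0, shift=3)  # 0b11011
```

Because the YYY field itself is left untouched, applying the same swizzle twice restores the original index: the transform is its own inverse, which is what makes it safe for both the store and load sides of a shared-memory tile.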
--- ## eval_composed
`eval_composed[composed_layout: ComposedLayout[Layout, Swizzle]](idx: Scalar[DType.uint], offset: Scalar[DType.uint] = 0) -> UInt` Evaluate a composed layout with swizzle. Evaluates a `ComposedLayout[Layout, Swizzle]`. Applies the base layout, adds an optional offset, and then applies the swizzle. **Parameters:** * ​composed\_layout ([`ComposedLayout`](/mojo/kernels/layout/swizzle/ComposedLayout)): The composed layout to evaluate, consisting of a base Layout and a Swizzle transformation. **Args:** * ​idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The input index to transform. * ​offset ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Optional offset to apply between layouts (default: 0). **Returns:** `UInt`: The transformed index after applying both layouts.
--- ## swizzle (Swizzle)
Defines swizzle layouts for optimizing memory access patterns. This module is designed for use in shared memory, especially in GPU kernels, to reduce bank conflicts. It provides tools to create and apply swizzle transformations to memory indices. Swizzling rearranges memory access order to distribute accesses across different memory banks. This mitigates bank contention and improves memory access efficiency. Module components: * `Swizzle` struct: Represents a swizzle transformation with configurable bits, base, and shift parameters. * Helper functions: `make_ldmatrix_swizzle`, `make_swizzle` create predefined swizzle patterns. These are optimized for scenarios like `ldmatrix` instructions and general 2D memory access. * `ComposedLayout` struct: Combines a base layout with a swizzle layout for complex memory access optimizations. Primary use case: GPU kernel development where shared memory bank conflicts can degrade performance. Applying swizzle layouts optimizes memory access patterns for higher throughput. ## Structs * [​`ComposedLayout`](./ComposedLayout): Layout composed of two layouts applied sequentially. * [​`Swizzle`](./Swizzle): Swizzle functor for memory access pattern optimization. ## Functions * [​`eval_composed`](./eval_composed): Evaluate a composed layout with swizzle. * [​`make_ldmatrix_swizzle`](./make_ldmatrix_swizzle): Make swizzle to avoid bank conflict for ldmatrix ops. * [​`make_swizzle`](./make_swizzle): Create a 2D swizzle to avoid bank conflicts. * [​`shiftl`](./shiftl): Shift left or right based on sign of shift amount. * [​`shiftr`](./shiftr): Shift right or left based on sign of shift amount.
--- ## make_ldmatrix_swizzle
`make_ldmatrix_swizzle[dtype: DType, row_size: Int, log2_vector_width: Int = 0]() -> Swizzle` Makes a swizzle to avoid bank conflicts for `ldmatrix` ops. Creates a swizzle pattern optimized for `ldmatrix` operations. Minimizes bank conflicts in shared memory for these operations. Calculates swizzle parameters based on data type and row size. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements. * ​row\_size ([`Int`](/mojo/std/builtin/int/Int)): Size of each row in elements. * ​log2\_vector\_width ([`Int`](/mojo/std/builtin/int/Int)): Log2 of the vector width (default: 0). **Returns:** `Swizzle`: A `Swizzle` object configured for `ldmatrix`.
--- ## make_swizzle
`make_swizzle[num_rows: Int, row_size: Int, access_size: Int]() -> Swizzle` Create a 2D swizzle to avoid bank conflicts. Generates a swizzle pattern for 2D memory layout to minimize bank conflicts in shared memory access. **Parameters:** * ​num\_rows ([`Int`](/mojo/std/builtin/int/Int)): Number of rows in the minimum access pattern. * ​row\_size ([`Int`](/mojo/std/builtin/int/Int)): Size of each row in elements. * ​access\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of elements accessed at once. **Returns:** `Swizzle`: A `Swizzle` object for 2D memory access. `make_swizzle[dtype: DType, mode: TensorMapSwizzle]() -> Swizzle` Create swizzle based on predefined swizzle modes. Returns a swizzle pattern based on standard modes (32B, 64B, 128B, none), adjusted for data type. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the elements. * ​mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzle mode to use (TensorMapSwizzle enum). **Returns:** `Swizzle`: A `Swizzle` object configured by the specified mode.
--- ## shiftl
`shiftl(a: Int, s: Int) -> Int` Shift left or right based on sign of shift amount. Performs a left shift if `s` is positive, or a right shift if `s` is negative. **Args:** * ​a ([`Int`](/mojo/std/builtin/int/Int)): The integer value to shift. * ​s ([`Int`](/mojo/std/builtin/int/Int)): The shift amount. Positive for left, negative for right. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The shifted integer value. `shiftl(a: Scalar[dtype], s: Scalar[dtype]) -> Scalar[dtype]` Shift left/right based on sign of shift for scalars. Scalar version of `shiftl`. Left shift if `s` is positive, right shift if `s` is negative. **Args:** * ​a ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to shift. * ​s ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar shift amount. Positive for left, negative for right. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The shifted scalar value.
--- ## shiftr
`shiftr(a: Int, s: Int) -> Int` Shift right or left based on sign of shift amount. Performs a right shift if `s` is positive, or a left shift if `s` is negative. **Args:** * ​a ([`Int`](/mojo/std/builtin/int/Int)): The integer value to shift. * ​s ([`Int`](/mojo/std/builtin/int/Int)): The shift amount. Positive for right, negative for left. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The shifted integer value. `shiftr(a: Scalar[dtype], s: Scalar[dtype]) -> Scalar[dtype]` Shift right/left based on sign of shift for scalars. Scalar version of `shiftr`. Right shift if `s` is positive, left shift if `s` is negative. **Args:** * ​a ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar value to shift. * ​s ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The scalar shift amount. Positive for right, negative for left. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The shifted scalar value.
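Both helpers reduce to a sign test on the shift amount; a minimal Python sketch of the integer overloads (illustrative, not the Mojo source):

```python
def shiftl(a: int, s: int) -> int:
    # Left shift for positive s, right shift for negative s.
    return a << s if s >= 0 else a >> -s

def shiftr(a: int, s: int) -> int:
    # Right shift for positive s, left shift for negative s.
    return a >> s if s >= 0 else a << -s

shiftl(1, 4)    # 16
shiftl(16, -2)  # 4
shiftr(16, 2)   # 4
shiftr(1, -4)   # 16
```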
--- ## TensorCore
`struct TensorCore[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool = False]` TensorCore provides an abstraction for GPU tensor core hardware to perform optimized matrix operations. This struct encapsulates the functionality required to efficiently map matrix operations to Tensor Cores on NVIDIA and AMD GPUs. It handles loading matrix fragments, performing matrix multiply-accumulate operations, and storing results with hardware-specific optimizations. Note: Different shapes and data types are supported depending on the GPU hardware. For NVIDIA GPUs: * float32: 16x8x8 or 16x8x4 * half-precision: 16x8x16 * float8: 16x8x32 For AMD GPUs: * float32: 16x16x4 * half-precision: 16x16x16 or 32x32x8 ## Parameters * ​out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type for output/accumulation operations. * ​in\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type for input matrix elements. * ​shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape parameters for the matrix operation in the form \[M, N, K] where MxN is the output shape and K is the inner dimension. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to transpose the B matrix before multiplication. Defaults to False. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_reg_type` `comptime a_reg_type = SIMD[in_type, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](0), shape.__getitem__[3, DType.int64, Int](2)]()]` SIMD type for the A operand registers. ### `b_reg_type` `comptime b_reg_type = SIMD[in_type, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](2), shape.__getitem__[3, DType.int64, Int](1)]()]` SIMD type for the B operand registers. ### `c_reg_tile_type` `comptime c_reg_tile_type = LayoutTensor[out_type, Layout.col_major(1, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](0), shape.__getitem__[3, DType.int64, Int](1)]()), MutAnyOrigin, address_space=AddressSpace.LOCAL]` LayoutTensor type for the C register tile. ### `c_reg_type` `comptime c_reg_type = SIMD[out_type, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](0), shape.__getitem__[3, DType.int64, Int](1)]()]` SIMD type for the C/accumulator operand registers. ### `supported_fp32` `comptime supported_fp32 = (shape == IndexList[3, DType.int64](16, 8, 8, Tuple[]())) if is_nvidia_gpu() else (shape == IndexList[3, DType.int64](16, 16, 4, Tuple[]())) if (in_type == DType.float32)._mlir_value else (in_type == DType.float32)` Whether float32 is supported for this tensor core configuration. 
### `supported_fp64` `comptime supported_fp64 = Tuple[IndexList[3], IndexList[3], IndexList[3], IndexList[3]](VariadicPack[True, MutExternalOrigin, True, Movable, IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape_8x8x4, shape_16x8x4, shape_16x8x8, shape_16x8x16)).__contains__[IndexList[3], IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape) if (out_type == DType.float64) if (in_type == DType.float64)._mlir_value else (in_type == DType.float64) else (out_type == DType.float64) if (in_type == DType.float64)._mlir_value else (in_type == DType.float64) if is_nvidia_gpu() else False` Whether float64 is supported for this tensor core configuration. ### `supported_fp8` `comptime supported_fp8 = (shape == shape_16x8x32) if Tuple[DType, DType](VariadicPack[True, MutExternalOrigin, True, Movable, DType, DType](DType.float8_e4m3fn, DType.float8_e5m2)).__contains__[DType, DType, DType](in_type) else Tuple[DType, DType](VariadicPack[True, MutExternalOrigin, True, Movable, DType, DType](DType.float8_e4m3fn, DType.float8_e5m2)).__contains__[DType, DType, DType](in_type) if is_nvidia_gpu() else (shape == shape_16x16x32) if Tuple[DType, DType](VariadicPack[True, MutExternalOrigin, True, Movable, DType, DType](get_amd_fp8_dtype(), get_amd_bf8_dtype())).__contains__[DType, DType, DType](in_type) else Tuple[DType, DType](VariadicPack[True, MutExternalOrigin, True, Movable, DType, DType](get_amd_fp8_dtype(), get_amd_bf8_dtype())).__contains__[DType, DType, DType](in_type)` Whether float8 is supported for this tensor core configuration. 
### `supported_half` `comptime supported_half = (shape == shape_16x8x16) if is_nvidia_gpu() else Tuple[IndexList[3], IndexList[3], IndexList[3], IndexList[3]](VariadicPack[True, MutExternalOrigin, True, Movable, IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape_16x16x16, shape_16x16x32, shape_32x32x8, shape_32x32x16)).__contains__[IndexList[3], IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape) if in_type.is_half_float() else in_type.is_half_float()` Whether half-precision float is supported for this configuration. ## Methods ### `__init__` `__init__(out self)` Initialize a new TensorCore instance. ### `get_shapes` `static get_shapes[_out_type: DType, _in_type: DType]() -> List[IndexList[3]]` Get supported shapes for given data types. Returns a list of valid shapes for the specified output and input data types. Note: The returned shapes are hardware-dependent. Different shapes are supported for different combinations of input and output types. **Parameters:** * ​\_out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The output/accumulation data type. * ​\_in\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The input matrix data type. **Returns:** [`List`](/mojo/std/collections/list/List): List\[IndexList\[3]]: Valid shapes for the matrix operations given the specified types. ### `load_a` `load_a[swizzle: Optional[Swizzle] = None](self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_a_reg_tile_layout[layout, shape](), MutAnyOrigin, address_space=AddressSpace.LOCAL]` Load the A matrix fragments. Loads matrix A from memory into a LayoutTensor suitable for tensor core operations. **Parameters:** * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzle pattern for optimal memory access (AMD only). 
**Args:** * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source matrix A data. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The loaded matrix fragments as a `LayoutTensor`. `load_a[swizzle: Optional[Swizzle] = None](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_tile_coord_k: Scalar[DType.uint] = 0)` Load A matrix fragments from shared memory. Optimized version for loading A matrix fragments from shared memory. **Parameters:** * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional memory access pattern to optimize memory bandwidth. **Args:** * ​warp\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source data in shared memory. * ​fragments ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor for fragments. * ​mma\_tile\_coord\_k ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The K coordinate of the MMA tile. Defaults to 0. ### `load_b` `load_b[swizzle: Optional[Swizzle] = None](self, b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_b_reg_tile_layout[layout, shape, transpose_b](), MutAnyOrigin, address_space=AddressSpace.LOCAL]` Load the B matrix fragments. Loads matrix B from memory into a `LayoutTensor` suitable for tensor core operations. The function handles different hardware architectures and memory access patterns.
Note: If transpose\_b is `True`, the B matrix will be transposed during loading. This is more efficient than transposing the matrix in memory. **Parameters:** * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzle pattern for optimal memory access (AMD only). Will cause an error if used with NVIDIA GPUs. **Args:** * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source matrix B data. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The loaded matrix fragments as a `LayoutTensor`. `load_b[swizzle: Optional[Swizzle] = None](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_tile_coord_k: Scalar[DType.uint] = 0, warp_tile_coord_n: Scalar[DType.uint] = 0)` Load B matrix fragments from shared memory into registers for tensor core operations. This function loads matrix B fragments from a warp tile in shared memory into register fragments for use in tensor core matrix multiply operations. It handles hardware-specific optimizations for both NVIDIA and AMD GPUs. Note: The `warp_tile` must be in shared memory. For NVIDIA GPUs, `swizzle` must be `None`. For AMD GPUs, providing an appropriate `swizzle` pattern can improve performance. **Parameters:** * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional memory access pattern for AMD GPUs to optimize memory bandwidth. Must be None when running on NVIDIA GPUs. For NVIDIA GPUs, swizzle is always on. 
**Args:** * ​warp\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source `LayoutTensor` in shared memory containing the B matrix data. * ​fragments ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination `LayoutTensor` to store the loaded matrix fragments. * ​mma\_tile\_coord\_k ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K-dimension coordinate within the warp tile. Defaults to 0. * ​warp\_tile\_coord\_n ([`Scalar`](/mojo/std/builtin/simd/#scalar)): N-dimension coordinate within the warp tile. Defaults to 0. `load_b(self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scales: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_tile_coord_k: Scalar[DType.uint] = 0)` Load quantized B matrix fragments from shared memory with dequantization. This function loads int4 quantized matrix B fragments from shared memory, dequantizes them using the provided scales, and stores the result in register fragments for tensor core operations. Notes: * The `warp_tile` must be in shared memory. * The `fragments` and `scales` must be in local memory. * This function only supports half-precision data types (bfloat16, float16). * The quantized data is stored as int4 values packed into int32 elements. * Each thread processes multiple fragments by unpacking and dequantizing the int4 values. 
**Args:** * ​warp\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source `LayoutTensor` in shared memory containing the quantized B matrix data. * ​fragments ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination `LayoutTensor` to store the dequantized matrix fragments. * ​scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): `LayoutTensor` containing the scaling factors for dequantization. * ​mma\_tile\_coord\_k ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K-dimension coordinate within the warp tile. Defaults to 0. ### `load_c` `load_c(self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TensorCore[out_type, in_type, shape, transpose_b].c_reg_tile_type` Load the C matrix fragments. Loads matrix C from memory into a `LayoutTensor` suitable for tensor core operations. The function handles different hardware architectures and memory access patterns. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source matrix C data. **Returns:** `TensorCore`: The loaded matrix fragments as a `LayoutTensor`. ### `store_d` `store_d(self, d_dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Store matrix D to destination memory. Stores the result matrix D from tensor core computation to the destination memory. **Args:** * ​d\_dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor to store the result. 
* ​d\_src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor containing the computed result. ### `mma_op` `mma_op(self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TensorCore[out_type, in_type, shape, transpose_b].c_reg_tile_type` Perform matrix multiply-accumulate operation (MMA). Executes `D = A * B + C` using tensor cores. **Args:** * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The A matrix input. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The B matrix input. * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The C matrix input for accumulation. **Returns:** `TensorCore`: `Self.c_reg_tile_type`: The result of the MMA operation. 
### `mma` `mma(self, a_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform matrix multiply-accumulate operation using tensor cores. Executes C = A \* B + C using tensor cores, where A, B, and C are matrix fragments stored in register memory. This function handles the mapping of fragments to hardware tensor core operations. Notes: * All fragments must be properly loaded using the corresponding load functions. * The function assumes fragments are vectorized layout tensors with dimensions num\_vectors x 1. * The c\_frag shape\[0] must equal num\_m\_mmas \* num\_n\_mmas. * The result is accumulated in-place in c\_frag. **Args:** * ​a\_frag ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix A fragments as a `LayoutTensor`. * ​b\_frag ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix B fragments as a `LayoutTensor`. * ​c\_frag ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix C fragments as a `LayoutTensor` for both input and output.
--- ## TiledTensorCore
`struct TiledTensorCore[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool = False]` TiledTensorCore provides a wrapper around TensorCore to support multiple MMAs along the K dimension. Enables larger K dimension operations by decomposing them into multiple smaller MMA operations. Currently only being used for AMD GPUs to enable 16x16x32 operations using two 16x16x16 MMAs. ## Parameters * ​out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type for output/accumulation operations. * ​in\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type for input matrix elements. * ​shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape parameters for individual MMA operations \[M, N, K]. * ​group\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA operations along the K dimension. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to transpose the b matrix. Defaults to False. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `mma_op` `comptime mma_op = TensorCore[out_type, in_type, shape, transpose_b]()` The underlying TensorCore instance for MMA operations. 
## Methods ### `mma` `static mma[swap_a_b: Bool = False](a_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform multiple matrix multiply-accumulate operations along the K dimension. Executes group\_size MMA operations, processing slices of the K dimension and accumulating results in c\_reg\_tile. **Parameters:** * ​swap\_a\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to swap a and b operands. Defaults to False. **Args:** * ​a\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix a fragments \[num\_m\_mmas, group\_size \* a\_frag\_size]. * ​b\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix b fragments \[num\_n\_mmas, group\_size \* b\_frag\_size]. * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Accumulation matrix c fragments, modified in-place.
--- ## get_fragment_size
`get_fragment_size[mma_shape: IndexList[3]]() -> IndexList[3]` Calculates the fragment size per thread for a given MMA shape. For tensor core operations, each thread in a warp handles a portion of the computation. This function determines how many elements each thread needs to process for the A, B, and C/D matrices based on the MMA shape. **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): An `IndexList[3]` containing the MMA dimensions \[M, N, K]. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): An `IndexList[3]` containing the fragment sizes per thread for matrices A, B, and C/D respectively, calculated as: `[M*K/WARP_SIZE, N*K/WARP_SIZE, M*N/WARP_SIZE]`.
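As a concrete check of the formula (a sketch, assuming the NVIDIA 16x8x16 shape and a 32-thread warp):

```mojo
from layout.tensor_core import get_fragment_size, shape_16x8x16

# [M*K, N*K, M*N] / WARP_SIZE = [16*16, 8*16, 16*8] / 32 = [8, 4, 4]
# i.e. 8 elements of A, 4 of B, and 4 of C/D per thread.
comptime frag_size = get_fragment_size[shape_16x8x16]()
```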
--- ## get_mma_shape
`get_mma_shape[input_type: DType, accum_type: DType, shape_id: Int = 0]() -> IndexList[3]` Returns the appropriate matrix multiply-accumulate (MMA) shape for tensor core operations. Selects the optimal MMA shape based on the GPU architecture, input data type, accumulation data type, and optional shape identifier. This function handles different configurations for both NVIDIA and AMD GPUs. **Parameters:** * ​input\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the input matrices (A and B). * ​accum\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type used for accumulation (C and D). * ​shape\_id ([`Int`](/mojo/std/builtin/int/Int)): Optional identifier to select between multiple valid shapes (default: 0). **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): An `IndexList[3]` containing the MMA dimensions in the format `[M, N, K]`, where `MxN` is the output matrix size and `K` is the reduction dimension.
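A minimal usage sketch; the architecture-specific results named in the comment are indicative only, since the exact shape depends on the compilation target:

```mojo
from layout.tensor_core import get_mma_shape

# Resolved at compile time for the target GPU. For bfloat16 inputs with
# float32 accumulation this typically selects a shape such as 16x8x16 on
# NVIDIA or 16x16x16 on AMD (the exact choice is architecture-dependent).
comptime mma_shape = get_mma_shape[DType.bfloat16, DType.float32]()
```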
--- ## tensor_core
Tensor Core Module for High-Performance Matrix Operations Provides abstractions for using GPU Tensor Cores to perform optimized matrix operations. It supports both NVIDIA and AMD GPU architectures with hardware-specific optimizations. ## Key Components: * `TensorCore`: Core struct that encapsulates tensor core operations with support for various data types and matrix shapes. It handles loading matrix fragments, performing matrix multiply-accumulate operations, and storing results. * Matrix Fragment Management: Functions for loading and storing matrix fragments to/from shared memory with hardware-specific optimizations. * Matrix Multiply-Accumulate (MMA): Optimized implementations of matrix multiplication operations using tensor cores. ## Supported Operations: * Matrix loading with various layouts and swizzling patterns * Matrix multiply-accumulate (D = A \* B + C) * Matrix storing with hardware-specific optimizations ## Supported Data Types: * NVIDIA: float32, bfloat16, float16, float8\_e4m3fn, float8\_e5m2 * AMD: float32, bfloat16, float16 ## Supported Matrix Shapes: * NVIDIA: 16x8x8, 16x8x4, 16x8x16, 8x8x4, 16x8x32 * AMD: 16x16x4, 16x16x16, 32x32x8 ## `comptime` values ### `shape_16x16x16` `comptime shape_16x16x16 = IndexList[3, DType.int64](16, 16, 16, Tuple[]())` AMDGPU tensor core shape 16x16x16. ### `shape_16x16x32` `comptime shape_16x16x32 = IndexList[3, DType.int64](16, 16, 32, Tuple[]())` AMDGPU tensor core shape 16x16x32. ### `shape_16x16x4` `comptime shape_16x16x4 = IndexList[3, DType.int64](16, 16, 4, Tuple[]())` AMDGPU tensor core shape 16x16x4. ### `shape_16x8x16` `comptime shape_16x8x16 = IndexList[3, DType.int64](16, 8, 16, Tuple[]())` Tensor core shape 16x8x16. ### `shape_16x8x32` `comptime shape_16x8x32 = IndexList[3, DType.int64](16, 8, 32, Tuple[]())` Tensor core shape 16x8x32. ### `shape_16x8x4` `comptime shape_16x8x4 = IndexList[3, DType.int64](16, 8, 4, Tuple[]())` Tensor core shape 16x8x4. 
### `shape_16x8x8` `comptime shape_16x8x8 = IndexList[3, DType.int64](16, 8, 8, Tuple[]())` Tensor core shape 16x8x8. ### `shape_32x32x16` `comptime shape_32x32x16 = IndexList[3, DType.int64](32, 32, 16, Tuple[]())` AMDGPU tensor core shape 32x32x16. ### `shape_32x32x8` `comptime shape_32x32x8 = IndexList[3, DType.int64](32, 32, 8, Tuple[]())` AMDGPU tensor core shape 32x32x8. ### `shape_8x8x4` `comptime shape_8x8x4 = IndexList[3, DType.int64](8, 8, 4, Tuple[]())` Tensor core shape 8x8x4. ### `shape_null` `comptime shape_null = IndexList[3, DType.int64](0, 0, 0, Tuple[]())` Null tensor core shape (0x0x0). ## Structs * [​`TensorCore`](./TensorCore): TensorCore provides an abstraction for GPU tensor core hardware to perform optimized matrix operations. * [​`TiledTensorCore`](./TiledTensorCore): TiledTensorCore provides a wrapper around TensorCore to support multiple MMAs along the K dimension. ## Functions * [​`get_fragment_size`](./get_fragment_size): Calculates the fragment size per thread for a given MMA shape. * [​`get_mma_shape`](./get_mma_shape): Returns the appropriate matrix multiply-accumulate (MMA) shape for tensor core operations. * [​`load_b_nt`](./load_b_nt): Loads the b operand tile for AMD tensor core MFMA from (N, K) storage. * [​`load_b_tr`](./load_b_tr): Loads the b operand tile for AMD tensor core MFMA instructions using transposed memory access. * [​`num_matrix_reg`](./num_matrix_reg): Calculates the number of matrix registers required per thread.
--- ## load_b_nt
`load_b_nt[mma_shape: IndexList[3], swizzle: Optional[Swizzle] = Optional[Swizzle]()](tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, 8]` Loads the b operand tile for AMD tensor core MFMA from (N, K) storage. This function supports double-rate MFMA shapes (32x32x16, 16x16x32) with bfloat16 input. Unlike load\_b\_tr which expects (K, N) storage, this function works with (N, K) storage which is common when transpose\_b=True and B is stored row-major. The input tile (shape = (mma\_shape\[1], mma\_shape\[2])) is split along the K dimension into two halves of shape (MMA\_N, MMA\_K//2). Each half is loaded using `_load_tr16_b64_warp`, which performs a transposed (column-major) load from shared memory. The hardware transpose effectively converts the (N, K) storage to (K, N) format needed by MMA. Example: For 16x16x32 MMA with B stored as (N, K) = (16, 32) in LDS: ```mojo # B tile in LDS: shape (16, 32) = (MMA_N, MMA_K) var b_tile = smem_b.tile[16, 32](n_idx, k_idx) var b_reg = load_b_nt[IndexList[3](16, 16, 32)](b_tile) # b_reg now contains 8 bf16 values ready for MFMA ``` **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The MMA instruction tile shape (only 32x32x16 or 16x16x32 supported). * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzle pattern for bank-conflict-free LDS access. **Args:** * ​tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A `LayoutTensor`, residing in shared memory, with shape (mma\_shape\[1], mma\_shape\[2]) and dtype `DType.bfloat16`. This is (N, K) storage order. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): SIMD\[tile.dtype, 8]: Concatenated transposed SIMD loads from both halves of the tile.
--- ## load_b_tr
`load_b_tr[mma_shape: IndexList[3], swizzle: Optional[Swizzle] = Optional[Swizzle]()](tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, 8]` Loads the b operand tile for AMD tensor core MFMA instructions using transposed memory access. This function supports double-rate MFMA shapes (32x32x16, 16x16x32) with bfloat16 input. The input tile (shape = (mma\_shape\[2], mma\_shape\[1])) is split along the K dimension into two halves of shape (MMA\_K//2, MMA\_N). Each half is loaded using `_load_tr16_b64_warp`, which performs a transposed (column-major) load from shared memory. The resulting two 4-element SIMD vectors are concatenated into a single `SIMD[tile.dtype, 8]` vector. **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The MMA instruction tile shape (only 32x32x16 or 16x16x32 supported). * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzle pattern for bank-conflict-free LDS access. **Args:** * ​tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A `LayoutTensor`, residing in shared memory, with shape (mma\_shape\[2], mma\_shape\[1]) and dtype `DType.bfloat16`. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): SIMD\[tile.dtype, 8]: Concatenated transposed SIMD loads from both halves of the tile.
--- ## num_matrix_reg
`num_matrix_reg[dim_1: Int, dim_2: Int]() -> Int` Calculates the number of matrix registers required per thread. Determines how many registers each thread in a warp needs to store a matrix of the given dimensions. This is calculated by dividing the total number of elements (dim\_1 \* dim\_2) by the warp size, as the matrix is distributed across all threads in the warp. **Parameters:** * ​dim\_1 ([`Int`](/mojo/std/builtin/int/Int)): First dimension of the matrix. * ​dim\_2 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension of the matrix. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of matrix registers needed per thread.
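For example (a sketch, assuming a 32-thread warp):

```mojo
from layout.tensor_core import num_matrix_reg

# A 16x8 accumulator tile distributed over a 32-thread warp:
# 16 * 8 / 32 = 4 registers per thread.
comptime c_regs = num_matrix_reg[16, 8]()
```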
--- ## TensorCoreAsync
`struct TensorCoreAsync[c_type: DType, a_type: DType, b_type: DType, mma_shape: IndexList[3], /, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, transpose_b: Bool = False]` High-performance asynchronous tensor core operations for matrix multiplication. This struct provides methods for utilizing NVIDIA's Tensor Cores for asynchronous matrix multiplication operations, with support for various data types and swizzling configurations. ## Parameters * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the output matrix C. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the input matrix A. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the input matrix B. * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): Dimensions for the matrix multiply-accumulate (MMA) operation as \[M, N, K]. * ​a\_swizzle ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling mode for matrix A (default: SWIZZLE\_NONE). * ​b\_swizzle ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling mode for matrix B (default: SWIZZLE\_NONE). * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to transpose matrix B (default: False). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` Initialize the `TensorCoreAsync` instance. Ensures that the provided MMA shape is supported. Note: Fails to compile if `mma_shape` is not supported. 
### `wgmma` `static wgmma[num_warp_groups: Int = 1, scale_c: Int = 1, scale_a: Int = 1, scale_b: Int = 1, num_k_iters: Optional[Int] = None](a_smem_tile: LayoutTensor[a_type, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], wg_idx: Int = 0)` Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA). This method handles the case where both A and B matrices are in shared memory. **Parameters:** * ​num\_warp\_groups ([`Int`](/mojo/std/builtin/int/Int)): Number of warp groups to distribute work across (default: 1). * ​scale\_c ([`Int`](/mojo/std/builtin/int/Int)): Scale factor for matrix C. Valid values are 1 or 0 (default: 1). * ​scale\_a ([`Int`](/mojo/std/builtin/int/Int)): Scale factor for matrix A. Valid values are 1 or -1 (default: 1). * ​scale\_b ([`Int`](/mojo/std/builtin/int/Int)): Scale factor for matrix B. Valid values are 1 or -1 (default: 1). * ​num\_k\_iters ([`Optional`](/mojo/std/collections/optional/Optional)): Number of iterations for the K dimension. This is useful to save computation when we pad shared memory. (default: None which is just `a_smem_layout[1].size() // mma_shape[2]`). **Args:** * ​a\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix A in shared memory. * ​b\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix B in shared memory. 
* ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C in register memory. * ​wg\_idx ([`Int`](/mojo/std/builtin/int/Int)): Warp group index for multi-warp group scenarios (default: 0). `static wgmma(a_frag_tile: LayoutTensor[a_type, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA). This overloaded method handles the case where matrix A is in register memory and matrix B is in shared memory. **Args:** * ​a\_frag\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix A in register memory. * ​b\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix B in shared memory. * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C in register memory. ### `arrive` `static arrive()` Ensures memory consistency by creating a fence for WGMMA operations. This method should be called before committing a group to ensure all shared memory accesses are properly aligned and visible. ### `commit_group` `static commit_group()` Commits the current warp group for execution. This synchronizes the warp group and commits all pending WGMMA operations that have been previously issued. 
### `wait_group` `static wait_group[group: Int = 0]()` Waits for the completion of a specific warp group's operations. This method blocks until all WGMMA operations from the specified group are complete. **Parameters:** * ​group ([`Int`](/mojo/std/builtin/int/Int)): The group ID to wait for (default: 0).
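The methods above are intended to be issued as a sequence: fence, multiply, commit, wait. A minimal sketch of one iteration follows; `a_smem`, `b_smem` (shared memory) and `c_reg` (registers) are assumed names for tiles prepared by the surrounding kernel, and the 64x64x16 WGMMA shape is illustrative.

```mojo
from layout.tensor_core_async import TensorCoreAsync
from utils.index import IndexList

comptime MmaOp = TensorCoreAsync[
    DType.float32,             # c_type
    DType.bfloat16,            # a_type
    DType.bfloat16,            # b_type
    IndexList[3](64, 64, 16),  # illustrative WGMMA mma_shape
]
MmaOp.arrive()        # fence: shared-memory writes become visible to WGMMA
MmaOp.wgmma(a_smem, b_smem, c_reg)
MmaOp.commit_group()  # commit all WGMMA ops issued so far as one group
MmaOp.wait_group()    # block until that group completes
```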
--- ## tensor_core_async
Tensor Core Async Module This module provides high-performance abstractions for utilizing NVIDIA's Tensor Cores to perform asynchronous matrix multiplication operations. It implements optimized memory layouts and access patterns for efficient tensor core computations. Key components: * Layout creation functions for K-major and MN-major memory arrangements * Swizzling support for improved memory access patterns * WGMMA (Warp Group Matrix Multiply-Accumulate) descriptor generation * TensorCoreAsync struct with methods for asynchronous matrix multiplication The module supports various data types, matrix dimensions, and memory configurations, enabling efficient implementation of deep learning primitives and other tensor operations that can leverage hardware acceleration. Performance features: * Asynchronous execution model to overlap computation and memory access * Support for different swizzling modes to optimize memory bandwidth * Efficient register and shared memory utilization * Support for multi-warp group execution This implementation is specifically optimized for NVIDIA GPUs with Tensor Core support. ## `comptime` values ### `WGMMA_K_BYTES` `comptime WGMMA_K_BYTES = 32` Size of WGMMA K dimension in bytes. ## Structs * [​`TensorCoreAsync`](./TensorCoreAsync): High-performance asynchronous tensor core operations for matrix multiplication. ## Functions * [​`select_k_atom`](./select_k_atom): Creates a core matrix layout for tensor core operations. * [​`st_matrix_m_atom`](./st_matrix_m_atom): Creates a layout for M-major `st_matrix` atom in the context of WGMMA C matrix. * [​`st_matrix_m_layout`](./st_matrix_m_layout): Creates a layout for M-major `st_matrix` in the context of WGMMA C matrix. This is meant to be used with swapAB, since the C matrix must be transposed during the write phase, and must be used in conjunction with the st\_matrix transposed modifier. 
* [​`st_matrix_n_atom`](./st_matrix_n_atom): Creates a layout for N-major `st_matrix` atom in the context of WGMMA C matrix. * [​`st_matrix_n_layout`](./st_matrix_n_layout): Creates a layout for N-major `st_matrix` in the context of WGMMA C matrix. * [​`tile_layout_k_major`](./tile_layout_k_major): Creates a K-major layout for tensor core operations. * [​`tile_layout_mn_major`](./tile_layout_mn_major): Creates an MN-major layout for tensor core operations. * [​`tile_sf_layout_k_major`](./tile_sf_layout_k_major): Creates a K-major layout for tensor core scale factors. * [​`tile_to_descriptor`](./tile_to_descriptor): Transforms a layout into a WGMMA descriptor-compatible layout. * [​`warpgroup_fence`](./warpgroup_fence): Code motion fence to ensure the registers of the WGMMA instruction do not get touched by anything. * [​`wgmma_c_layout`](./wgmma_c_layout): Generates three layouts for mapping WGMMA C matrix coordinates. * [​`wgmma_c_thread_layout`](./wgmma_c_thread_layout): Returns the thread layout component for WGMMA C matrix. * [​`wgmma_output_layout`](./wgmma_output_layout): Returns the output layout component for WGMMA C matrix.
--- ## select_k_atom
`select_k_atom[dtype: DType, swizzle_mode: TensorMapSwizzle]() -> Layout` Creates a core matrix layout for tensor core operations. Constructs the fundamental atomic layout for tensor core operations based on the specified data type and swizzle mode. This layout represents the minimal dense matrix structure that can be efficiently processed by tensor cores. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type of the tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern swizzling mode. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A core matrix layout optimized for tensor core operations.
--- ## st_matrix_m_atom
`st_matrix_m_atom[num_stmatrix: Int, num_consumer: Int]() -> Layout` Creates a layout for M-major `st_matrix` atom in the context of WGMMA C matrix. The domain of this layout is the warp group local thread index. Thus, the layout takes \[0, 128) as input and returns an offset for a logical array with an element size of 128-bit. Assume num\_consumer = 2 and num\_stmatrix = 2; then a single atom for one warp looks like the following, where each cell contains a thread\_idx and each thread holds the address of the next 128-bit fragment:

```
|  0 |  8 |
|  1 |  9 |
|  2 | 10 |
| .. | .. |
|  7 | 15 |
| 16 | 24 |
| 17 | 25 |
| 18 | 26 |
| .. | .. |
| 23 | 31 |
```

All 4 warps in the warp group are then laid out next to each other:

```
| w1 | w2 | w3 | w4 |
```

**Parameters:** * ​num\_stmatrix ([`Int`](/mojo/std/builtin/int/Int)): Number of N-dimension tiles in the C matrix. * ​num\_consumer ([`Int`](/mojo/std/builtin/int/Int)): Number of consumers. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout that maps warp group local thread index to an offset for a logical array with an element size of 128-bit.
--- ## st_matrix_m_layout
`st_matrix_m_layout[c_type: DType, WG_BM: Int, num_m_mmas: Int, num_consumer: Int]() -> Layout` Creates a layout for M-major `st_matrix` in the context of WGMMA C matrix. This is meant to be used with swapAB, since the C matrix must be transposed during the write phase, and must be used in conjunction with the st\_matrix transposed modifier. The layout modes are: the M-dimension tiling size `WG_BM // 16`, the number of MMA tiles `num_m_mmas` in the N-dimension, and the number of consumers `num_consumer`. The output is an offset for a logical array with the element type `c_type`. **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the C matrix. * ​WG\_BM ([`Int`](/mojo/std/builtin/int/Int)): Size of the M dimension of the C matrix tile in shared memory. * ​num\_m\_mmas ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA tiles in the M dimension. * ​num\_consumer ([`Int`](/mojo/std/builtin/int/Int)): Number of consumers. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout that maps warp group local thread index to an offset for a logical array with the element type `c_type`.
--- ## st_matrix_n_atom
`st_matrix_n_atom[num_stmatrix: Int]() -> Layout` Creates a layout for N-major `st_matrix` atom in the context of WGMMA C matrix. The domain of this layout is the warp group local thread index. Thus, the layout takes \[0, 128) as input and returns an offset for a logical array with an element size of 128-bit. **Parameters:** * ​num\_stmatrix ([`Int`](/mojo/std/builtin/int/Int)): Number of N-dimension tiles in the C matrix. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout that maps warp group local thread index to an offset for a logical array with an element size of 128-bit.
--- ## st_matrix_n_layout
`st_matrix_n_layout[c_type: DType, WG_BN: Int, num_m_mmas: Int, num_consumer: Int]() -> Layout` Creates a layout for N-major `st_matrix` in the context of WGMMA C matrix. The layout modes are: the warp group local thread index, the N-dimension tiling size `WG_BN // 16`, the number of MMA tiles `num_m_mmas` in the M-dimension, and the number of consumers `num_consumer`. The output is an offset for a logical array with the element type `c_type`. **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the C matrix. * ​WG\_BN ([`Int`](/mojo/std/builtin/int/Int)): Size of the N dimension of the C matrix tile in shared memory. * ​num\_m\_mmas ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA tiles in the M dimension. * ​num\_consumer ([`Int`](/mojo/std/builtin/int/Int)): Number of consumers. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout that maps warp group local thread index to an offset for a logical array with the element type `c_type`.
--- ## tile_layout_k_major
`tile_layout_k_major[dtype: DType, BM: Int, BK: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]() -> Layout` Creates a K-major layout for tensor core operations. Constructs a layout optimized for K-major access patterns in tensor core operations, with optional swizzling for improved memory access patterns. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type of the tensor. * ​BM ([`Int`](/mojo/std/builtin/int/Int)): Size of the M dimension in the tile. * ​BK ([`Int`](/mojo/std/builtin/int/Int)): Size of the K dimension in the tile. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern swizzling mode (default: SWIZZLE\_NONE). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A K-major layout configured for the specified dimensions and swizzle mode.
--- ## tile_layout_mn_major
`tile_layout_mn_major[dtype: DType, mn_dim: Int, k_dim: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]() -> Layout` Creates an MN-major layout for tensor core operations. Constructs a unit layout optimized for MN-major access patterns in shared memory, with optional swizzling for improved memory access patterns. Note: This returns the "unit" layout; the actual shared memory layout can be a multiple of this unit. Currently only supports SWIZZLE\_NONE and SWIZZLE\_128B modes. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type of the tensor. * ​mn\_dim ([`Int`](/mojo/std/builtin/int/Int)): Size of the MN dimension. * ​k\_dim ([`Int`](/mojo/std/builtin/int/Int)): Size of the K dimension. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern swizzling mode (default: SWIZZLE\_NONE). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - An MN-major layout configured for the specified dimensions and swizzle mode.
--- ## tile_sf_layout_k_major
`tile_sf_layout_k_major[BM: Int, BK: Int, SF_SCALE_SIZE: Int]() -> Layout` Creates a K-major layout for tensor core scale factors. Constructs a layout for K-major access patterns for scale factors. **Parameters:** * ​BM ([`Int`](/mojo/std/builtin/int/Int)): Size of the M dimension in the tile. * ​BK ([`Int`](/mojo/std/builtin/int/Int)): Size of the K dimension in the tile. * ​SF\_SCALE\_SIZE ([`Int`](/mojo/std/builtin/int/Int)): Number of elements in a scale factor vector. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A K-major layout configured for the specified dimensions and scale factor size.
--- ## tile_to_descriptor
`tile_to_descriptor[dtype: DType, layout: Layout, is_k_major: Bool = True]() -> Layout` Transforms a layout into a WGMMA descriptor-compatible layout. Converts a standard layout into a form that can be used with WGMMA descriptors, handling both K-major and MN-major layouts differently. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type of the tensor. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Input layout to transform. * ​is\_k\_major ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the layout is K-major (True) or MN-major (False). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A transformed layout compatible with WGMMA descriptors.
--- ## warpgroup_fence
`warpgroup_fence[accum_type: DType, accum_layout: Layout, //](accum: LayoutTensor[accum_type, accum_layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Code motion fence to ensure the registers of the WGMMA instruction do not get touched by anything. This has no impact on kernel correctness. It serves purely as an NVVM code motion barrier, preventing other operations from modifying the WGMMA instruction's registers during execution of the WGMMA instruction batch. **Parameters:** * ​accum\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type of the tensor. * ​accum\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Register layout of the accumulator. **Args:** * ​accum ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A LayoutTensor with the accum\_type and accum\_layout.
--- ## wgmma_c_layout
`wgmma_c_layout[mma_m: Int, mma_n: Int, C: Layout]() -> List[Layout]` Generates three layouts for mapping WGMMA C matrix coordinates. This function creates three layout mappings that are essential for working with WGMMA (Warp Group Matrix Multiply-Accumulate) operations: 1. A projection layout that maps linearized indices to row coordinates (i) 2. A projection layout that maps linearized indices to column coordinates (j) 3. A composite layout that maps thread and vector coordinates to linearized indices across multiple MMA tiles These layouts are particularly useful for operations like attention masking and matrix multiplication epilogues, where register values need to be mapped to the coordinate system of the C matrix. Note: This function enforces constraints on the WGMMA dimensions and ensures the C matrix dimensions are compatible with the WGMMA instruction size. **Parameters:** * ​mma\_m ([`Int`](/mojo/std/builtin/int/Int)): The M dimension (rows) of a single WGMMA instruction, must be 64. * ​mma\_n ([`Int`](/mojo/std/builtin/int/Int)): The N dimension (columns) of a single WGMMA instruction, must be multiple of 8. * ​C ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the C matrix within a thread block. **Returns:** [`List`](/mojo/std/collections/list/List): `List[Layout]` - A list containing three layouts: 1. proj\_i: Maps linearized indices to row coordinates 2. proj\_j: Maps linearized indices to column coordinates 3. TV\_tile\_to\_idx: Maps thread/vector/tile coordinates to linearized indices
--- ## wgmma_c_thread_layout
`wgmma_c_thread_layout[C: Layout]() -> Layout` Returns the thread layout component for WGMMA C matrix. Generates the first mode of the WGMMA C layout, which maps thread coordinates to linearized indices in the output matrix. **Parameters:** * ​C ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the C matrix. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout mapping thread coordinates to linearized indices.
--- ## wgmma_output_layout
`wgmma_output_layout[mma_n: Int, C: Layout]() -> Layout` Returns the output layout component for WGMMA C matrix. Generates the second mode of the WGMMA C layout, which maps output vector coordinates to linearized indices in the output matrix. **Parameters:** * ​mma\_n ([`Int`](/mojo/std/builtin/int/Int)): The N dimension of the WGMMA instruction. * ​C ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the C matrix. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout mapping output vector coordinates to linearized indices.
--- ## PipelineState
`@register_passable(trivial)` `struct PipelineState[num_stages: Int]` Manages state for a multi-stage pipeline with circular buffer semantics. PipelineState provides a mechanism for tracking the current stage in a multi-stage pipeline, particularly useful for double or triple buffering in GPU tensor operations. It maintains an index that cycles through the available stages, a phase bit that toggles when the index wraps around, and a monotonically increasing count. This struct is commonly used with TMA operations to coordinate the use of multiple buffers in a pipeline fashion, allowing for overlapping computation and data transfer. ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): The number of stages in the pipeline (e.g., 2 for double buffering, 3 for triple buffering). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Initialize a PipelineState with default values. Creates a new PipelineState with index 0, phase 0, and count 0. `__init__(index: Int, phase: Int, count: Int) -> Self` Initialize a PipelineState with specific values. Creates a new PipelineState with the specified index, phase, and count. **Args:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): The initial stage index. 
* ​phase ([`Int`](/mojo/std/builtin/int/Int)): The initial phase value (0 or 1). * ​count ([`Int`](/mojo/std/builtin/int/Int)): The initial count value. ### `index` `index(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32): The current index value, which ranges from 0 to num\_stages-1. ### `phase` `phase(self) -> UInt32` Get the current phase bit. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32): The current phase value (0 or 1), which toggles when the index wraps around. ### `step` `step(mut self)` Advance the pipeline state to the next stage. Increments the index and count. When the index reaches num\_stages, it wraps around to 0 and toggles the phase bit. This function is used to move to the next buffer in a multi-buffer pipeline, implementing circular buffer semantics. ### `next` `next(mut self) -> Self` Advance the pipeline state to the next stage and return the new state. This function is used to move to the next buffer in a multi-buffer pipeline, implementing circular buffer semantics. **Returns:** `Self`: The new pipeline state after advancing to the next stage. ### `__enter__` `__enter__(var self) -> Self` Enter the context manager. **Returns:** `Self`: The pipeline state instance for use in a `with` statement.
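The circular-buffer semantics above (the index wraps to 0 when it reaches `num_stages`, the phase bit toggles on each wrap, and the count increases monotonically) can be modeled in a few lines of Python. This is an illustrative stand-in for the Mojo struct, not its actual implementation:

```python
class PipelineState:
    """Python model of the circular-buffer semantics described above."""

    def __init__(self, num_stages, index=0, phase=0, count=0):
        self.num_stages = num_stages
        self.index = index   # current stage, in [0, num_stages)
        self.phase = phase   # parity bit, toggles on each wrap-around
        self.count = count   # monotonically increasing step counter

    def step(self):
        """Advance to the next stage, wrapping and toggling the phase."""
        self.count += 1
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1  # toggle phase when the index wraps

# Double buffering: two stages, so the phase flips every second step.
state = PipelineState(num_stages=2)
for _ in range(3):
    state.step()
# After 3 steps: index = 1, phase = 1, count = 3
```

In a double-buffered TMA pipeline, the phase bit is what a consumer passes to `SharedMemBarrier.wait()` so that it distinguishes the current pass over a buffer from the previous one.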
--- ## RaggedTMA3DTile
`struct RaggedTMA3DTile[dtype: DType, swizzle_mode: TensorMapSwizzle, BM: Int, BN: Int]` Creates a TMA descriptor for loading/storing 2D tiles from 3D arrays with a ragged leading dimension, indexing into the middle dimension. When using these loads, it is essential that at least `BM * stride` space has been allocated in front of the gmem pointer; otherwise, `CUDA_ERROR_ILLEGAL_ADDRESS` may result. ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access. * ​BM ([`Int`](/mojo/std/builtin/int/Int)): The number of rows of the corresponding 2D shared memory tile. * ​BN ([`Int`](/mojo/std/builtin/int/Int)): The number of columns of the corresponding 2D shared memory tile. ## Fields * ​descriptor (`TMADescriptor`): The TMA descriptor that will be used to store the ragged tensor. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = RaggedTMA3DTile[dtype, swizzle_mode, BM, BN]` The device-side type representation. ### `layout` `comptime layout = tile_layout_k_major[dtype, BM, BN, swizzle_mode]()` The unswizzled-smem layout copied to/from by this tma op. 
### `swizzle_granularity` `comptime swizzle_granularity = (swizzle_mode.bytes() // size_of[dtype]())` The number of columns that must be copied at a time due to the swizzle size. ## Methods ### `__init__` `@implicit` `__init__(out self, descriptor: TMADescriptor)` Initializes a new TMATensorTile with the provided TMA descriptor. **Args:** * ​descriptor ([`TMADescriptor`](/mojo/std/gpu/host/nvidia/tma/TMADescriptor)): The TMA descriptor that defines the memory access pattern. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy initializes this `RaggedTMA3DTile` from another instance. **Args:** * ​other (`Self`): The other `RaggedTMA3DTile` instance to copy from. ### `get_type_name` `static get_type_name() -> String` Returns a string representation of the RaggedTMA3DTile type. **Returns:** `String`: A string containing the type name with all template parameters. ### `create` `static create[*, depth: Int = BN](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], origin], *, rows: Int, middle_dim: Int) -> Self` Create a RaggedTMA3DTile. **Parameters:** * ​depth ([`Int`](/mojo/std/builtin/int/Int)): The size of the inner-most, contiguous, dimension. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The device context used to create the TMA descriptors. * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): The global memory pointer. * ​rows ([`Int`](/mojo/std/builtin/int/Int)): The size of the ragged dimension. * ​middle\_dim ([`Int`](/mojo/std/builtin/int/Int)): The size of the middle dimension. **Returns:** `Self`: A RaggedTMA3DTile corresponding to the gmem. **Raises:** If TMA descriptor creation fails. 
### `async_copy_to` `async_copy_to[cta_group: Int = 1](self, dst: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, *, ragged_idx: UInt32, dynamic_dim: UInt32, middle_idx: UInt32)` Copy from the `RaggedTMA3DTile` source to the smem destination. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. **Args:** * ​dst ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): The destination shared memory pointer to which we copy memory. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​ragged\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Index into the ragged dimension. * ​dynamic\_dim ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Number of rows to copy. * ​middle\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Index into the middle (generally head) dimension. ### `async_copy_from` `async_copy_from[eviction_policy: CacheEviction = CacheEviction.EVICT_FIRST](self, src: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], *, ragged_idx: UInt32, dynamic_dim: UInt32, middle_idx: UInt32)` Copy from the smem source to the `RaggedTMA3DTile` destination. **Parameters:** * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_FIRST. **Args:** * ​src ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): The source shared memory pointer from which we copy memory. * ​ragged\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Index into the ragged dimension. * ​dynamic\_dim ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Number of rows to copy. 
* ​middle\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Index into the middle (generally head) dimension. ### `prefetch_descriptor` `prefetch_descriptor(self)` Prefetches the TMA descriptor into cache.
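To make the `swizzle_granularity` formula above concrete (`swizzle_mode.bytes() // size_of[dtype]()`), here is a small Python sketch of the same arithmetic. The byte widths below assume the standard TMA swizzle sizes (32, 64, and 128 bytes) and common element sizes; the helper name is hypothetical:

```python
def swizzle_granularity(swizzle_bytes: int, dtype_bytes: int) -> int:
    """Columns that must be copied together, mirroring the formula
    swizzle_mode.bytes() // size_of[dtype]()."""
    return swizzle_bytes // dtype_bytes

# 128-byte swizzle with 2-byte elements (e.g. bfloat16): 64 columns.
print(swizzle_granularity(128, 2))  # -> 64
# 64-byte swizzle with 4-byte elements (e.g. float32): 16 columns.
print(swizzle_granularity(64, 4))   # -> 16
```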
--- ## RaggedTensorMap
`struct RaggedTensorMap[descriptor_rank: Int, //, dtype: DType, descriptor_shape: IndexList[descriptor_rank], remaining_global_dim_rank: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]` Creates a TMA descriptor that can handle stores with varying lengths. This struct is mainly used for MHA, where sequence lengths may vary between samples. This struct only supports one dimension being ragged. The contiguous dimension (where the stride is 1) cannot be ragged. ## Parameters * ​descriptor\_rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the descriptor shape (inferred). * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor. * ​descriptor\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the shared memory descriptor. * ​remaining\_global\_dim\_rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the remaining global tensor dimensions. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. Defaults to SWIZZLE\_NONE. ## Fields * ​descriptor (`TMADescriptor`): The TMA descriptor that will be used to store the ragged tensor. * ​max\_length (`Int`): The maximum length present in the sequences of the ragged tensor. * ​global\_shape (`IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]`): The shape of the global tensor. * ​global\_stride (`IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]`): The stride of the global tensor. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]` The TensorMapDescriptorArray type. ### `global_rank` `comptime global_rank = (remaining_global_dim_rank + 3)` The rank of the global tensor. ### `ragged_descriptor_shape` `comptime ragged_descriptor_shape = RaggedTensorMap._descriptor_shape[descriptor_rank, dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]()` The shape of the descriptor that will tile and load from shared -> global memory. ## Methods ### `__init__` `__init__(out self, ctx: DeviceContext, global_ptr: UnsafePointer[Scalar[dtype], origin], max_length: Int, ragged_stride: Int, batch_size: Int, global_last_dim: Int, remaining_global_dims: IndexList[remaining_global_dim_rank], remaining_global_stride: IndexList[remaining_global_dim_rank])` Initializes a TensorMapDescriptorArray with descriptors for all power-of-2 lengths. This constructor creates a complete set of TMA descriptors, one for each power of 2 from 1 up to max\_descriptor\_length. Each descriptor is configured to handle a different first dimension size (1, 2, 4, 8, ..., max\_descriptor\_length) while maintaining the same remaining tile shape specified by desc\_remaining\_tile\_shape. 
**Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The device context used to create the TMA descriptors. * ​global\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): The source tensor in global memory that will be accessed using the descriptors. * ​max\_length ([`Int`](/mojo/std/builtin/int/Int)): The maximum length present in the sequences of the ragged tensor. * ​ragged\_stride ([`Int`](/mojo/std/builtin/int/Int)): The stride of the ragged dimension in the global tensor. * ​batch\_size ([`Int`](/mojo/std/builtin/int/Int)): The total number of sequences in the ragged tensor. * ​global\_last\_dim ([`Int`](/mojo/std/builtin/int/Int)): The last dimension of the global tensor. * ​remaining\_global\_dims ([`IndexList`](/mojo/std/utils/index_/IndexList)): The dimensions of the remaining global tensor. * ​remaining\_global\_stride ([`IndexList`](/mojo/std/utils/index_/IndexList)): The stride of the remaining global tensor. **Raises:** If the operation fails. ### `get_type_name` `static get_type_name() -> String` Returns a string representation of the TensorMapDescriptorArray type. **Returns:** `String`: A string containing the type name with all template parameters. ### `store_ragged_tile` `store_ragged_tile[rank: Int, //, using_max_descriptor_size: Bool = False](self, coordinates: IndexList[rank], preceding_cumulative_length: Int, store_length: Int, mut tile_iterator: LayoutTensorIter[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])` Stores a ragged tile from shared memory to global memory. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the coordinates. * ​using\_max\_descriptor\_size ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, optimizes the store around the max descriptor size. 
**Args:** * ​coordinates ([`IndexList`](/mojo/std/utils/index_/IndexList)): The starting coordinates of all dimensions except the ragged dimension. * ​preceding\_cumulative\_length ([`Int`](/mojo/std/builtin/int/Int)): The cumulative length of the preceding sequences. * ​store\_length ([`Int`](/mojo/std/builtin/int/Int)): The length of the current sequence to be stored. * ​tile\_iterator ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The iterator over the tile in shared memory. ### `prefetch_descriptor` `prefetch_descriptor(self)` Prefetches the TMA descriptor into cache.
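Because the constructor builds one descriptor per power of 2 up to the maximum length, a store of arbitrary length can be covered by the binary decomposition of that length, using each power-of-2 descriptor at most once. The Python sketch below illustrates only this decomposition (a hypothetical helper, not the actual `store_ragged_tile` logic):

```python
def power_of_2_chunks(length: int) -> list[int]:
    """Cover `length` rows with power-of-2 chunk sizes, largest first,
    mirroring the one-descriptor-per-power-of-2 scheme."""
    chunks = []
    bit = 1 << (length.bit_length() - 1) if length > 0 else 0
    while bit:
        if length & bit:  # this power of 2 appears in the binary expansion
            chunks.append(bit)
        bit >>= 1
    return chunks

# A sequence of 13 rows is stored as chunks of 8 + 4 + 1 rows.
print(power_of_2_chunks(13))  # -> [8, 4, 1]
```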
--- ## SharedMemBarrier
`@register_passable(trivial)` `struct SharedMemBarrier` A hardware-accelerated synchronization primitive for GPU shared memory operations. This struct provides a barrier mechanism optimized for coordinating thread execution and memory transfers in GPU kernels, particularly for Tensor Memory Accelerator (TMA) operations. It enables efficient synchronization between threads and memory operations by leveraging hardware-specific barrier instructions. Key features: * Thread synchronization across thread blocks * Memory transfer completion tracking * Hardware-accelerated barrier operations * Support for phased synchronization This barrier is particularly useful for ensuring that shared memory operations complete before dependent computations begin, which is critical for maintaining data consistency in high-performance GPU kernels. ## Fields * ​mbar (`Int64`): Shared memory location used for the barrier state. This field stores an 8-byte aligned shared memory location that maintains the state of the barrier. The memory must be in shared address space to be accessible by all threads in a block. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `init` `init[o: MutOrigin](ref[o, AddressSpace._value._mlir_value] self, num_threads: Int32 = 1)` Initialize the barrier state with the expected number of threads. 
Sets up the barrier to expect arrivals from the specified number of threads before it can be satisfied. This is essential for coordinating thread synchronization in GPU kernels. **Parameters:** * ​o ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of self. **Args:** * ​num\_threads ([`Int32`](/mojo/std/builtin/simd/#int32)): Number of threads that must arrive at the barrier before it is satisfied. Defaults to 1. ### `expect_bytes` `expect_bytes[o: MutOrigin](ref[o, AddressSpace._value._mlir_value] self, bytes: Int32)` Configure the barrier to expect a specific number of bytes to be transferred. Used with TMA operations to indicate the expected size of data transfer. The barrier will be satisfied when the specified number of bytes has been transferred, enabling efficient coordination of memory operations. **Parameters:** * ​o ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of self. **Args:** * ​bytes ([`Int32`](/mojo/std/builtin/simd/#int32)): Number of bytes expected to be transferred. ### `expect_bytes_relaxed` `expect_bytes_relaxed[o: MutOrigin](ref[o, AddressSpace._value._mlir_value] self, bytes: Int32) -> UInt64` Configure the barrier to expect a specific number of bytes to be transferred. Used with TMA operations to indicate the expected size of data transfer. The barrier will be satisfied when the specified number of bytes has been transferred, enabling efficient coordination of memory operations. **Parameters:** * ​o ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of self. **Args:** * ​bytes ([`Int32`](/mojo/std/builtin/simd/#int32)): Number of bytes expected to be transferred. **Returns:** `UInt64`: The state. ### `arrive_and_expect_bytes` `arrive_and_expect_bytes[o: MutOrigin](ref[o, AddressSpace._value._mlir_value] self, bytes: Int32, cta_id: UInt32, pred: UInt32)` Configure the barrier to expect a specific number of bytes to be transferred at a remote CTA. 
Used with TMA operations to indicate the expected size of data transfer. The barrier will be satisfied when the specified number of bytes has been transferred at the specified CTA in the cluster. **Parameters:** * ​o ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of self. **Args:** * ​bytes ([`Int32`](/mojo/std/builtin/simd/#int32)): Number of bytes expected to be transferred. * ​cta\_id ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The CTA ID in a cluster to configure an arrival. * ​pred ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Predication on the arrival configuration instruction. Use UInt32 to match `selp.u32` in ptx. ### `wait` `wait[ticks: Optional[UInt32] = None](ref[AddressSpace._value._mlir_value] self, phase: UInt32 = 0)` Wait until the barrier is satisfied. Blocks the calling thread until the barrier is satisfied, either by the expected number of threads arriving or the expected data transfer completing. This method implements an efficient spin-wait mechanism optimized for GPU execution. Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions. **Parameters:** * ​ticks ([`Optional`](/mojo/std/collections/optional/Optional)): The number of ticks to wait before timing out in nanoseconds. Defaults to None. **Args:** * ​phase ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The phase value to check against. Defaults to 0. ### `wait_acquire` `wait_acquire[scope: Scope](ref[AddressSpace._value._mlir_value] self, phase: UInt32 = 0)` Acquire and wait until the barrier is satisfied. Blocks the calling thread until the barrier is satisfied, either by the expected number of threads arriving or the expected data transfer completing. This method implements an efficient spin-wait mechanism optimized for GPU execution. Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions. 
**Parameters:** * ​scope ([`Scope`](/mojo/std/gpu/intrinsics/Scope)): The scope of the barrier. **Args:** * ​phase ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The phase value to check against. Defaults to 0. ### `wait_relaxed` `wait_relaxed[scope: Scope](ref[AddressSpace._value._mlir_value] self, phase: UInt32 = 0)` Wait until the barrier is satisfied with relaxed ordering. Blocks the calling thread until the barrier is satisfied, either by the expected number of threads arriving or the expected data transfer completing. This method implements an efficient spin-wait mechanism optimized for GPU execution. Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions. **Parameters:** * ​scope ([`Scope`](/mojo/std/gpu/intrinsics/Scope)): The scope of the barrier. **Args:** * ​phase ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The phase value to check against. Defaults to 0. ### `try_wait` `try_wait(ref[AddressSpace._value._mlir_value] self, phase: UInt32 = 0) -> Bool` Non-blocking check if barrier phase is complete. Performs a single non-blocking check to see if the barrier has completed the specified phase. Returns immediately with the result without spinning. This is useful for implementing the try-acquire pattern where you want to overlap barrier checking with other useful work. Example: ```mojo # Try-acquire pattern for pipelined execution var ready = barrier.try_wait(phase) # Do other work while potentially waiting do_useful_work() # Now wait conditionally if not ready: barrier.wait(phase) ``` **Args:** * ​phase ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The phase parity (0 or 1) to check for. Defaults to 0. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the barrier phase is complete, False otherwise. 
### `unsafe_ptr` `unsafe_ptr[origin: Origin[mut=mut]](ref[origin, AddressSpace._value._mlir_value] self) -> UnsafePointer[Int64, origin, address_space=AddressSpace.SHARED]` Get an unsafe pointer to the barrier's memory location. Provides low-level access to the shared memory location storing the barrier state. This method is primarily used internally by other barrier operations that need direct access to the underlying memory. **Parameters:** * ​origin ([`Origin`](/mojo/std/builtin/type_aliases/Origin)): Origin of self. **Returns:** [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer): An unsafe pointer to the barrier's memory location in shared memory, properly typed and aligned for barrier operations. ### `arrive_cluster` `arrive_cluster(ref[AddressSpace._value._mlir_value] self, cta_id: UInt32, count: UInt32 = 1)` Signal arrival at the barrier from a specific CTA (Cooperative Thread Array) in a cluster. This method is used in multi-CTA scenarios to coordinate barrier arrivals across different CTAs within a cluster. It enables efficient synchronization across thread blocks in clustered execution models. **Args:** * ​cta\_id ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The ID of the CTA (Cooperative Thread Array) that is arriving. * ​count ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The number of arrivals to signal. Defaults to 1. ### `arrive` `arrive[o: MutOrigin](ref[o, AddressSpace._value._mlir_value] self) -> Int` Signal arrival at the barrier and return the arrival count. This method increments the arrival count at the barrier and returns the updated count. It's used to track how many threads have reached the synchronization point. **Parameters:** * ​o ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of self. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The updated arrival count after this thread's arrival.
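The phase-parity protocol that `wait` and `try_wait` rely on can be modeled in plain Python: each time the expected number of arrivals completes, the barrier re-arms and flips its phase bit, and a waiter checking parity `p` completes once the barrier's phase differs from `p`. This is a behavioral sketch of the semantics described above, not the hardware mbarrier:

```python
class MbarrierModel:
    """Behavioral model of a phase-parity barrier (not the hardware mbarrier)."""

    def __init__(self, num_threads: int):
        self.expected = num_threads
        self.pending = num_threads  # arrivals still outstanding this phase
        self.phase = 0              # parity bit, flips when the barrier completes

    def arrive(self):
        """Signal one arrival; completing the phase re-arms the barrier."""
        self.pending -= 1
        if self.pending == 0:
            self.pending = self.expected
            self.phase ^= 1

    def try_wait(self, phase: int) -> bool:
        # Complete once the barrier's parity differs from the phase we wait on.
        return self.phase != phase

bar = MbarrierModel(num_threads=2)
print(bar.try_wait(0))  # -> False: nobody has arrived yet
bar.arrive()
bar.arrive()            # second arrival completes phase 0
print(bar.try_wait(0))  # -> True
```

This is why a pipelined kernel carries a phase bit per stage (see `PipelineState`): waiting on the stale parity value would otherwise complete immediately on the next pass over the same buffer.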
--- ## TMATensorTile
`struct TMATensorTile[dtype: DType, layout: Layout, desc_layout: Layout = layout, is_k_major: Bool = True]` A hardware-accelerated tensor memory access (TMA) tile for efficient asynchronous data movement. The TMATensorTile struct provides a high-performance interface for asynchronous data transfers between global memory and shared memory in GPU tensor operations. It encapsulates a TMA descriptor that defines the memory access pattern and provides methods for various asynchronous operations. Performance: * Hardware-accelerated memory transfers using TMA instructions * Supports prefetching of descriptors for latency hiding * Enforces 128-byte alignment requirements for optimal memory access ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the tile in shared memory, typically specified as row\_major. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the descriptor, which can be different from the shared memory layout to accommodate hardware requirements like WGMMA. Defaults to `layout`. * ​is\_k\_major ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the shared memory is k-major. Defaults to True. ## Fields * ​descriptor (`TMADescriptor`): The TMA descriptor that defines the memory access pattern. This field stores the hardware descriptor that encodes information about: * The source tensor's memory layout and dimensions * The tile shape and access pattern * Swizzling configuration for optimal memory access The descriptor is used by the GPU's Tensor Memory Accelerator hardware to efficiently transfer data between global and shared memory. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = TMATensorTile[dtype, layout, desc_layout, is_k_major]` The device-side type representation. ## Methods ### `__init__` `@implicit` `__init__(out self, descriptor: TMADescriptor)` Initializes a new TMATensorTile with the provided TMA descriptor. **Args:** * ​descriptor ([`TMADescriptor`](/mojo/std/gpu/host/nvidia/tma/TMADescriptor)): The TMA descriptor that defines the memory access pattern. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy initializes this `TMATensorTile` from another instance. **Args:** * ​other (`Self`): The other `TMATensorTile` instance to copy from. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. **Returns:** `String`: This type's name. ### `prefetch_descriptor` `prefetch_descriptor(self)` Prefetches the TMA descriptor into cache to reduce latency. This method helps hide memory access latency by prefetching the descriptor before it's needed for actual data transfers. 
### `async_copy` `async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int])` Schedules an asynchronous copy from global memory to shared memory at specified coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 2D coordinates in the source tensor from which to copy data. 
`async_copy[rank: Int, //, cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: StaticTuple[UInt32, rank])` Schedules an asynchronous copy from global memory to shared memory for N-dimensional tensors. This is a generic dispatcher that selects the appropriate rank-specific async copy method based on the tensor rank. It provides a unified interface for initiating TMA transfers across 2D, 3D, 4D, and 5D tensors using `StaticTuple` coordinates. **Constraints:** * The rank must be 2, 3, 4, or 5. * The destination tensor must be 128-byte aligned in shared memory. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If set to 2, only the leader CTA needs to be notified upon completion. Defaults to 1. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple)): The N-dimensional coordinates in the source tensor from which to copy data, provided as a `StaticTuple` of `UInt32` values. 
### `async_copy_3d` `async_copy_3d[eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int])` Schedules an asynchronous copy from global memory to shared memory at specified 3D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 3D tensors. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Parameters:** * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 3D coordinates in the source tensor from which to copy data. 
### `async_copy_4d` `async_copy_4d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int])` Schedules an asynchronous copy from global memory to shared memory at specified 4D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 4D tensors. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 4D coordinates in the source tensor from which to copy data. 
`async_copy_4d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int])` Schedules an asynchronous copy from global memory to shared memory at specified 4D coordinates. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory where data will be copied. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 4D coordinates in the source tensor from which to copy data. ### `async_copy_5d` `async_copy_5d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int, Int])` Schedules an asynchronous copy from global memory to shared memory at specified 5D coordinates. 
This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 5D tensors. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 5D coordinates in the source tensor from which to copy data. `async_copy_5d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int, Int])` Schedules an asynchronous copy from global memory to shared memory at specified 5D coordinates. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).
**Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT\_NORMAL. **Args:** * ​dst ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory where data will be copied. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 5D coordinates in the source tensor from which to copy data. ### `async_store` `async_store[rank: Int, //, cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: StaticTuple[UInt32, rank])` Schedules an asynchronous store from shared memory to global memory for N-dimensional tensors. This is a generic dispatcher that selects the appropriate rank-specific async store method based on the tensor rank. It provides a unified interface for initiating TMA store operations across 2D, 3D, 4D, and 5D tensors using `StaticTuple` coordinates. **Constraints:** * The rank must be 2, 3, 4, or 5. * The source tensor must be 128-byte aligned in shared memory. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group configuration for the store operation. Defaults to 1. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied to global memory. Must be 128-byte aligned. 
* ​coords ([`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple)): The N-dimensional coordinates in the destination global tensor where data will be stored, provided as a `StaticTuple` of `UInt32` values. `async_store(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Schedules an asynchronous store from shared memory to global memory. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to global memory at the specified coordinates. **Constraints:** The source tensor must be 128-byte aligned in shared memory. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 2D coordinates in the destination tensor where data will be stored. `async_store(self, src: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], coords: Tuple[UInt, UInt])` Schedules an asynchronous store from shared memory to global memory. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). **Args:** * ​src ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory from which data will be copied. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 2D coordinates in the destination tensor where data will be stored.
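Stores are completed with the `commit_group`/`wait_group` pair documented below rather than a shared-memory barrier. A minimal sketch of that flow, assuming `tma_tile` and `smem_tile` were set up earlier and that `row`/`col` are illustrative tile coordinates:

```mojo
# Hypothetical kernel fragment: write a shared-memory tile back to
# global memory and drain the transfer before reusing the tile.
if thread_idx.x == 0:
    tma_tile.async_store(smem_tile, (row, col))
    # Bundle all prior uncommitted TMA stores into one group...
    tma_tile.commit_group()
    # ...and wait until zero groups remain pending.
    tma_tile.wait_group[0]()
```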
### `async_multicast_load` `async_multicast_load[cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous multicast load from global memory to multiple shared memory locations. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to multiple destination locations in shared memory across different CTAs (Cooperative Thread Arrays) as specified by the multicast mask. **Constraints:** The destination tensor must be 128-byte aligned in shared memory. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 2D coordinates in the source tensor from which to copy data. * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): A bit mask specifying which CTAs should receive the data.
`async_multicast_load[cta_group: Int = 1](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous 2D multicast load from global to shared memory. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. **Args:** * ​dst ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory where data will be copied. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 2D coordinates in the source tensor from which to copy. * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Bit mask specifying which CTAs should receive the data. ### `async_multicast_load_3d` `async_multicast_load_3d[cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous 3D multicast load from global memory to multiple shared memory locations. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to multiple destination locations in shared memory across different CTAs (Cooperative Thread Arrays) as specified by the multicast mask. 
**Constraints:** The destination tensor must be 128-byte aligned in shared memory. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 3D coordinates in the source tensor from which to copy data. * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): A bit mask specifying which CTAs should receive the data. `async_multicast_load_3d[cta_group: Int = 1](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous 3D multicast load from global to shared memory. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): If issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. **Args:** * ​dst ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory where data will be copied. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 3D coordinates in the source tensor from which to copy.
* ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Bit mask specifying which CTAs should receive the data. ### `async_multicast_load_partitioned` `async_multicast_load_partitioned[tma_rows: Int, tma_load_size: Int](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, rank: Scalar[DType.uint], coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Performs a partitioned multicast load where each rank loads a distinct slice of data. This method is designed for clustered execution where different ranks (CTAs) load different, contiguous slices of the source tensor. Each rank's slice is offset by `rank * tma_rows` in the second dimension and stored at offset `rank * tma_load_size` in shared memory. Note: This is typically used in matrix multiplication kernels where the input matrices are partitioned across multiple CTAs for parallel processing. **Parameters:** * ​tma\_rows ([`Int`](/mojo/std/builtin/int/Int)): The number of rows each rank is responsible for loading. * ​tma\_load\_size ([`Int`](/mojo/std/builtin/int/Int)): The size in elements of each rank's slice in shared memory. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​rank ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The rank ID (0-based) that determines which slice to load. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The base 2D coordinates in the source tensor from which to copy data. The second coordinate will be offset by `rank * tma_rows`. 
* ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): A bit mask specifying which CTAs should receive the data. ### `async_store_3d` `async_store_3d(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at specified 3D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 3D tensors. **Constraints:** * The source tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 3D coordinates in the destination tensor where data will be stored. `async_store_3d(self, src: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], coords: Tuple[UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at 3D coordinates. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). **Args:** * ​src ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory from which data will be copied. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 3D coordinates in the destination tensor. 
### `async_store_4d` `async_store_4d(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at specified 4D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 4D tensors. **Constraints:** * The source tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 4D coordinates in the destination tensor where data will be stored. ### `async_store_5d` `async_store_5d(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt, UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at specified 5D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 5D tensors. **Constraints:** * The source tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. 
* ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 5D coordinates in the destination tensor where data will be stored. ### `async_reduce` `async_reduce[reduction_kind: ReduceOp](self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Schedules an asynchronous reduction operation from shared memory to global memory. This method initiates a hardware-accelerated asynchronous reduction operation that combines data from shared memory with data in global memory using the specified reduction operation. The reduction is performed element-wise at the specified coordinates in the global tensor. **Constraints:** The source tensor must be 128-byte aligned in shared memory. **Parameters:** * ​reduction\_kind ([`ReduceOp`](/mojo/std/gpu/memory/memory/ReduceOp)): The type of reduction operation to perform (e.g., ADD, MIN, MAX). This determines how values are combined during the reduction. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory containing the data to be reduced. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): The 2D coordinates in the destination tensor where the reduction will be applied. ### `commit_group` `commit_group(self)` Commits all prior initiated but uncommitted TMA instructions into a group. This function behaves the same as `cp_async_bulk_commit_group`, which creates a synchronization point for bulk TMA transfers. ### `wait_group` `wait_group[n: Int = 0](self)` Waits for completion of outstanding asynchronous copies until at most a specified number of groups remain pending. This function behaves the same as `cp_async_bulk_wait_group`, which causes the executing thread to wait until at most `n` of the most recent TMA copy groups are still pending.
**Parameters:** * ​n ([`Int`](/mojo/std/builtin/int/Int)): The number of pending groups left. ### `smem_tensormap_init` `smem_tensormap_init(self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, origin, address_space=AddressSpace.SHARED])` Initializes a TMA descriptor in shared memory from this tensor tile's descriptor. This method copies the TMA descriptor from global memory to shared memory, allowing for faster access during kernel execution. The descriptor is copied in 16-byte chunks using asynchronous copy operations for efficiency. Note: * Only one thread should call this method to avoid race conditions * The descriptor is copied in 8 chunks of 16 bytes each (total 128 bytes) **Args:** * ​smem\_tma\_descriptor\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the location in shared memory where the descriptor will be stored. Must be properly aligned. ### `replace_tensormap_global_address_in_gmem` `replace_tensormap_global_address_in_gmem[_dtype: DType](self, src_ptr: UnsafePointer[Scalar[_dtype], origin])` Replaces the global memory address in the TMA descriptor stored in global memory. This method allows dynamically changing the source tensor for TMA operations without recreating the entire descriptor, which is useful for reusing descriptors with different data sources. The operation modifies the descriptor in global memory directly. Note: A memory fence may be required after this operation to ensure visibility of the changes to other threads. **Parameters:** * ​\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the new source tensor. **Args:** * ​src\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): The new source tensor whose address will replace the current one in the descriptor. Must have compatible layout with the original tensor. ### `tensormap_fence_acquire` `tensormap_fence_acquire(self)` Establishes a memory fence for TMA operations with acquire semantics. 
This method ensures proper ordering of memory operations by creating a barrier that prevents subsequent TMA operations from executing before prior operations have completed. It is particularly important when reading from a descriptor that might have been modified by other threads or processes. The acquire semantics ensure that all memory operations after this fence will observe any modifications made to the descriptor before the fence. Notes: * The entire warp must call this function as the instruction is warp-aligned. * Typically used in pairs with `tensormap_fence_release` for proper synchronization. ### `tensormap_fence_release` `tensormap_fence_release(self)` Establishes a memory fence for TMA operations with release semantics. This method ensures proper ordering of memory operations by creating a barrier that ensures all prior memory operations are visible before subsequent operations can proceed. It is particularly important when modifying a TMA descriptor in global memory that might be read by other threads or processes. The release semantics ensure that all memory operations before this fence will be visible to any thread that observes operations after the fence. Notes: * Typically used after modifying a tensormap descriptor in global memory. * Often paired with `tensormap_fence_acquire` for proper synchronization. ### `replace_tensormap_global_address_in_shared_mem` `replace_tensormap_global_address_in_shared_mem[_dtype: DType](self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, origin, address_space=AddressSpace.SHARED], src_ptr: UnsafePointer[Scalar[_dtype], origin])` Replaces the global memory address in the TMA descriptor stored in shared memory. This method allows dynamically changing the source tensor for TMA operations without recreating the entire descriptor, which is useful for reusing descriptors with different data sources. The operation modifies a descriptor that has been previously copied to shared memory. 
Notes: * Only one thread should call this method to avoid race conditions. * A memory fence may be required after this operation to ensure visibility of the changes to other threads. * Typically used with descriptors previously initialized with `smem_tensormap_init`. **Parameters:** * ​\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the new source tensor. **Args:** * ​smem\_tma\_descriptor\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the TMA descriptor in shared memory that will be modified. * ​src\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): The new source tensor whose address will replace the current one in the descriptor. ### `tensormap_cp_fence_release` `tensormap_cp_fence_release(self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, origin, address_space=AddressSpace.SHARED])` Establishes a memory fence for TMA operations with release semantics for shared memory descriptors. This method ensures proper ordering of memory operations by creating a barrier that ensures all prior memory operations are visible before subsequent operations can proceed. It is specifically designed for synchronizing between global memory and shared memory TMA descriptors. The release semantics ensure that all memory operations before this fence will be visible to any thread that observes operations after the fence. Notes: * The entire warp must call this function as the instruction is warp-aligned * Typically used after modifying a tensormap descriptor in shared memory * More specialized than the general `tensormap_fence_release` for cross-memory space synchronization **Args:** * ​smem\_tma\_descriptor\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the TMA descriptor in shared memory that is being synchronized with the global memory descriptor. 
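Taken together, the descriptor-management methods above support a common reuse pattern: copy the descriptor into shared memory once, patch it for a new source tensor, then fence before any TMA operation reads it. A minimal sketch of that sequence, assuming `tma_tile`, `smem_desc_ptr`, and `new_src_ptr` were set up elsewhere and eliding any extra synchronization a real kernel may need after `smem_tensormap_init`:

```mojo
# Hypothetical kernel fragment: retarget a TMA descriptor in shared memory.
if thread_idx.x == 0:
    # Copy the descriptor from global to shared memory (single thread only).
    tma_tile.smem_tensormap_init(smem_desc_ptr)
    # Patch the global address so the same descriptor reads new data.
    tma_tile.replace_tensormap_global_address_in_shared_mem[dtype](
        smem_desc_ptr, new_src_ptr
    )
    # Release fence: make the modified descriptor visible.
    tma_tile.tensormap_cp_fence_release(smem_desc_ptr)
# Acquire fence is warp-aligned: the entire warp must execute it before
# issuing TMA operations that consume the modified descriptor.
tma_tile.tensormap_fence_acquire()
```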
### `replace_tensormap_global_dim_strides_in_shared_mem` `replace_tensormap_global_dim_strides_in_shared_mem[_dtype: DType, only_update_dim_0: Bool, /, *, rank: Int](self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, origin, address_space=AddressSpace.SHARED], gmem_dims: IndexList[rank], gmem_strides: IndexList[rank])` Replaces dimensions and strides in a TMA descriptor stored in shared memory. Note: This function is only supported for CUDA versions >= 12.5. This function allows dynamically modifying the dimensions and strides of a TMA descriptor that has been previously initialized in shared memory. If only the first dimension (dim 0) is updated, then updating strides can be skipped. Notes: * Only one thread should call this method to avoid race conditions. * A memory fence may be required after this operation to ensure visibility of the changes to other threads. **Parameters:** * ​\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the new source tensor. * ​only\_update\_dim\_0 ([`Bool`](/mojo/std/builtin/bool/Bool)): If true, only the first dimension (dim 0) is updated, without updating strides. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the tensor. **Args:** * ​smem\_tma\_descriptor\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the TMA descriptor in shared memory that will be modified. * ​gmem\_dims ([`IndexList`](/mojo/std/utils/index_/IndexList)): The global dimensions of the tensor to be updated. * ​gmem\_strides ([`IndexList`](/mojo/std/utils/index_/IndexList)): The global strides of the tensor to be updated. `replace_tensormap_global_dim_strides_in_shared_mem[_dtype: DType, tensor_rank: Int, dim_idx: Int](self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, origin, address_space=AddressSpace.SHARED], dim_value: UInt32, dim_stride: Optional[UInt64] = None)` Replaces dimensions and strides in a TMA descriptor stored in shared memory.
Note: This function is only supported for CUDA versions >= 12.5. This function allows dynamically modifying the dimensions and strides of a TMA descriptor that has been previously initialized in shared memory. If only the first dimension is updated, then updating strides can be skipped. Notes: * Only one thread should call this method to avoid race conditions. * A memory fence may be required after this operation to ensure visibility of the changes to other threads. **Parameters:** * ​\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the source tensor in GMEM. * ​tensor\_rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the source tensor in GMEM. * ​dim\_idx ([`Int`](/mojo/std/builtin/int/Int)): The index of the dimension to be updated in the TMA descriptor with the provided dimension and stride values at runtime. **Args:** * ​smem\_tma\_descriptor\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the TMA descriptor in shared memory that will be modified. * ​dim\_value ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The new dimension value to be set. * ​dim\_stride ([`Optional`](/mojo/std/collections/optional/Optional)): The new stride value to be set.
--- ## TMATensorTileArray
`@register_passable(trivial)` `struct TMATensorTileArray[num_of_tensormaps: Int, dtype: DType, cta_tile_layout: Layout, desc_layout: Layout]` An array of TMA descriptors. ## Parameters * ​num\_of\_tensormaps ([`Int`](/mojo/std/builtin/int/Int)): The number of TMA descriptors (tensor maps). * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​cta\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the tile in shared memory, typically specified as row\_major. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the descriptor, which can be different from the shared memory layout to accommodate hardware requirements like WGMMA. ## Fields * ​tensormaps\_ptr (`UnsafePointer[UInt8, MutAnyOrigin]`): A static tuple of pointers to TMA descriptors. This field stores an array of pointers to `TMATensorTile` instances, where each pointer references a TMA descriptor in device memory. The array has a fixed size determined by the num\_of\_tensormaps parameter. The TMA descriptors are used by the GPU hardware to efficiently transfer data between global and shared memory with specific memory access patterns defined by the layouts.
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `descriptor_bytes` `comptime descriptor_bytes = 128` Size of the TMA descriptor in bytes. This is a constant value that represents the size of the TMA descriptor in bytes. It is used to calculate the offset of the TMA descriptor in the device memory. ### `device_type` `comptime device_type = TMATensorTileArray[num_of_tensormaps, dtype, cta_tile_layout, desc_layout]` The device-side type representation. ## Methods ### `__init__` `__init__(tensormaps_device: DeviceBuffer[DType.uint8]) -> Self` Initializes a new TMATensorTileArray. **Args:** * ​tensormaps\_device ([`DeviceBuffer`](/mojo/std/gpu/host/device_context/DeviceBuffer)): Device buffer to store TMA descriptors. ### `__getitem__` `__getitem__(self, index: Int) -> UnsafePointer[TMATensorTile[dtype, cta_tile_layout, desc_layout], MutAnyOrigin]` Retrieve a TMA descriptor. **Args:** * ​index ([`Int`](/mojo/std/builtin/int/Int)): Index of the TMA descriptor. **Returns:** [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer): `UnsafePointer` to the `TMATensorTile` at the specified index. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. **Returns:** `String`: This type's name.
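A sketch of how this type might be used, based only on the constructor and `__getitem__` above; `ctx` (a device context), `tile_layout`, `desc_layout`, and the element type are illustrative assumptions, and the buffer is sized from the documented `descriptor_bytes = 128` constant:

```mojo
# Hypothetical flow: back an array of 4 TMA descriptors with device memory.
comptime n = 4
var buf = ctx.enqueue_create_buffer[DType.uint8](n * 128)  # descriptor_bytes each
var tiles = TMATensorTileArray[n, DType.bfloat16, tile_layout, desc_layout](buf)

# Inside a kernel, fetch the i-th descriptor as an UnsafePointer[TMATensorTile[...]]:
var tile_ptr = tiles[i]
```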
--- ## TMATensorTileIm2col
`struct TMATensorTileIm2col[dtype: DType, layout: Layout, desc_layout: Layout = layout]` TMA tensor tile with im2col coordinate transformation for convolution. This struct enables hardware-accelerated im2col transformation during TMA loads, used for implicit GEMM convolution. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs coordinate transformation on-the-fly. The coordinate system uses GEMM-style 2D coordinates: * coords\[0]: K coordinate (indexes into R \* S \* C reduction dimension) * coords\[1]: M coordinate (indexes into batch \* H\_out \* W\_out spatial) Internally: * K is decomposed into (c, r, s) where `K = r*S*C + s*C + c` (filter-first, channel-last for NHWC) * M is decomposed into (n, h, w) where `M = n*H_out*W_out + h*W_out + w` * 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction. ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of tensor elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the tile in shared memory. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the descriptor (may differ for WGMMA compatibility). ## Fields * ​descriptor (`TMADescriptor`): The TMA descriptor encoding im2col transformation parameters. * ​out\_height (`UInt32`): Output height (H\_out) for M coordinate decomposition. * ​out\_width (`UInt32`): Output width (W\_out) for M coordinate decomposition. * ​filter\_h (`UInt32`): Filter height (R) for K coordinate decomposition. * ​filter\_w (`UInt32`): Filter width (S) for K coordinate decomposition. * ​in\_channels (`UInt32`): Input channels (C) for K coordinate decomposition. * ​lower\_corner\_h (`Int32`): Lower corner offset for height (H dimension) - matches CUTLASS ArithmeticTupleIterator pattern. * ​lower\_corner\_w (`Int32`): Lower corner offset for width (W dimension) - matches CUTLASS ArithmeticTupleIterator pattern. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`DevicePassable`](/mojo/std/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = TMATensorTileIm2col[dtype, layout, desc_layout]` The device-side type representation. ## Methods ### `__init__` `__init__(out self, descriptor: TMADescriptor, out_height: UInt32, out_width: UInt32, filter_h: UInt32, filter_w: UInt32, in_channels: UInt32, lower_corner_h: Int32 = 0, lower_corner_w: Int32 = 0)` Initializes with the provided TMA im2col descriptor and dimensions. **Args:** * ​descriptor ([`TMADescriptor`](/mojo/std/gpu/host/nvidia/tma/TMADescriptor)): The TMA descriptor that encodes im2col transformation. * ​out\_height ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Output height (H\_out) for M coordinate decomposition. * ​out\_width ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Output width (W\_out) for M coordinate decomposition. * ​filter\_h ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Filter height (R) for K coordinate decomposition. * ​filter\_w ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Filter width (S) for K coordinate decomposition. * ​in\_channels ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Input channels (C) for K coordinate decomposition. * ​lower\_corner\_h ([`Int32`](/mojo/std/builtin/simd/#int32)): Lower corner offset for H dimension (matches CUTLASS pattern). * ​lower\_corner\_w ([`Int32`](/mojo/std/builtin/simd/#int32)): Lower corner offset for W dimension (matches CUTLASS pattern). 
### `__copyinit__` `__copyinit__(out self, other: Self)` Copy initializes from another instance. **Args:** * ​other (`Self`): The other instance to copy from. ### `get_type_name` `static get_type_name() -> String` Gets this type's name for error messages. **Returns:** `String`: This type's name. ### `prefetch_descriptor` `prefetch_descriptor(self)` Prefetches the TMA descriptor into cache. ### `async_copy` `async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])` Schedules an asynchronous im2col TMA load. Uses 2D GEMM-style coordinates: * coords\[0]: K coordinate (indexes into C \* R \* S reduction dimension) * coords\[1]: M coordinate (indexes into batch \* H\_out \* W\_out spatial) Internally: * K is decomposed into (c, r, s) where `K = c*R*S + r*S + s` * M is decomposed into (n, h, w) where `M = n*H_out*W_out + h*W_out + w` * 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction. Note: SM100/Blackwell im2col TMA with padding (negative corners) requires the cta\_group::2 PTX format, so cta\_group must be set to 2 in that case. This is consistent with CUTLASS, which only provides SM100\_TMA\_2SM\_LOAD\_IM2COL (no cta\_group::1 variant for im2col). **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size for TMA operations. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy for the TMA load. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in shared memory. 
* ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): GEMM coordinates (k\_coord, m\_coord). `async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])` Schedules an asynchronous im2col TMA load. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). Uses 2D GEMM-style coordinates: * coords\[0]: K coordinate (indexes into C \* R \* S reduction dimension) * coords\[1]: M coordinate (indexes into batch \* H\_out \* W\_out spatial) Internally: * K is decomposed into (c, r, s) where `K = c*R*S + r*S + s` * M is decomposed into (n, h, w) where `M = n*H_out*W_out + h*W_out + w` * 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction. Note: Uses cta\_group=1 (SM90-style TMA) for single-CTA clusters. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size for TMA operations. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy for the TMA load. **Args:** * ​dst ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory where data will be copied. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): GEMM coordinates (k\_coord, m\_coord). 
### `async_multicast_load` `async_multicast_load[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous im2col TMA load with multicast. Uses 2D GEMM-style coordinates: * coords\[0]: K coordinate (indexes into C \* R \* S reduction dimension) * coords\[1]: M coordinate (indexes into batch \* H\_out \* W\_out spatial) Internally: * K is decomposed into (c, r, s) where `K = c*R*S + r*S + s` * M is decomposed into (n, h, w) where `M = n*H_out*W_out + h*W_out + w` * 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction with multicast. Note: SM100/Blackwell im2col TMA with padding (negative corners) requires the cta\_group::2 PTX format, so cta\_group must be set to 2 in that case. This is consistent with CUTLASS, which only provides SM100\_TMA\_2SM\_LOAD\_IM2COL\_MULTICAST (no cta\_group::1 variant). **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size for TMA operations. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy for the TMA load. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in shared memory. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): GEMM coordinates (k\_coord, m\_coord). * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Bitmask specifying target CTAs for multicast. 
`async_multicast_load[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous im2col TMA load with multicast. TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment). Uses 2D GEMM-style coordinates: * coords\[0]: K coordinate (indexes into C \* R \* S reduction dimension) * coords\[1]: M coordinate (indexes into batch \* H\_out \* W\_out spatial) Internally: * K is decomposed into (c, r, s) where `K = c*R*S + r*S + s` * M is decomposed into (n, h, w) where `M = n*H_out*W_out + h*W_out + w` * 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction with multicast. Note: Uses cta\_group=1 (SM90-style TMA) for single-CTA clusters. **Parameters:** * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size for TMA operations. * ​eviction\_policy ([`CacheEviction`](/mojo/std/gpu/memory/memory/CacheEviction)): Cache eviction policy for the TMA load. **Args:** * ​dst ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): TileTensor in shared memory where data will be copied. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): GEMM coordinates (k\_coord, m\_coord). * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Bitmask specifying target CTAs for multicast.
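The 2D-to-convolution coordinate mapping above can be illustrated with plain arithmetic. The sketch below (Python, for illustration only, not library code) uses the channel-fastest NHWC ordering given in the struct overview, `K = r*S*C + s*C + c`, together with `M = n*H_out*W_out + h*W_out + w`:

```python
# Illustration (not the library implementation): decompose GEMM-style
# (K, M) coordinates into convolution coordinates, using the ordering
# K = r*S*C + s*C + c (channel-fastest, NHWC) and
# M = n*H_out*W_out + h*W_out + w.

def decompose_k(k: int, C: int, S: int) -> tuple[int, int, int]:
    """Split reduction index K into (c, r, s) channel/filter coordinates."""
    c = k % C            # channel varies fastest (NHWC)
    s = (k // C) % S     # filter width offset
    r = k // (C * S)     # filter height offset
    return c, r, s

def decompose_m(m: int, H_out: int, W_out: int) -> tuple[int, int, int]:
    """Split spatial index M into (n, h, w) batch/output coordinates."""
    w = m % W_out
    h = (m // W_out) % H_out
    n = m // (H_out * W_out)
    return n, h, w

# Round-trip check: recomposing must reproduce the original indices.
C, S = 64, 3
c, r, s = decompose_k(200, C, S)
assert r * S * C + s * C + c == 200
```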
--- ## create_split_tma
`create_split_tma[rank: Int, dtype: DType, //, smem_shape: IndexList[rank], gmem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], origin], runtime_dim0: Int, out res: TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)])` Creates a TMA tensor tile assuming that the first dimension in global memory has `UNKNOWN_VALUE`. This function creates a `TMATensorTile` that optionally splits the last dimension of the tensor into multiples of swizzle granularity. This functionality is currently disabled because it was not found to improve performance. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions of the tensor. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​smem\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the tile in shared memory. * ​gmem\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the global memory tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the global memory tensor data. * ​runtime\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): The runtime size of the first dimension of the global tensor. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): The resulting TMA tensor tile with split layout. **Raises:** If TMA descriptor creation fails. 
`create_split_tma[rank: Int, dtype: DType, //, smem_shape: IndexList[rank], gmem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin], runtime_dim0: Int, runtime_dim1: Int, out res: TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)])` Creates a TMA tensor tile assuming that the first two dimensions in global memory have `UNKNOWN_VALUE`. This function creates a `TMATensorTile` that optionally splits the last dimension of the tensor into multiples of swizzle granularity. This functionality is currently disabled because it was not found to improve performance. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The number of dimensions of the tensor. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​smem\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the tile in shared memory. * ​gmem\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the global memory tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to the global memory tensor data. * ​runtime\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): The runtime size of the first dimension of the global tensor. * ​runtime\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): The runtime size of the second dimension of the global tensor. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): The resulting TMA tensor tile with split layout. **Raises:** If TMA descriptor creation fails.
--- ## create_tensor_tile
`create_tensor_tile[dtype: DType, rank: Int, //, tile_shape: IndexList[rank], /, k_major_tma: Bool = True, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[rank, DType.int64, Int](0), tile_shape.__getitem__[rank, DType.int64, Int](1)), __desc_layout: Layout = _tma_desc_tile_layout[dtype, rank, tile_shape, swizzle_mode]()](ctx: DeviceContext, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TMATensorTile[dtype, __tile_layout, __desc_layout, k_major_tma]` Creates a `TMATensorTile` with advanced configuration options for 2D, 3D, 4D, or 5D tensors. This overload provides more control over the TMA descriptor creation, allowing specification of data type, rank, and layout orientation. It supports 2D, 3D, 4D, and 5D tensors and provides fine-grained control over the memory access patterns. **Constraints:** * Only supports 2D, 3D, 4D, and 5D tensors (rank must be 2, 3, 4, or 5). * For non-SWIZZLE\_NONE modes, the K dimension size in bytes must be a multiple of the swizzle mode's byte size. * For MN-major layout, only SWIZZLE\_128B is supported. * For 3D, 4D, and 5D tensors, only K-major layout is supported. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​tile\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the tile to be transferred. * ​k\_major\_tma ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the TMA should copy the descriptor into shared memory following a column-major (if `True`) or row-major (if `False`) pattern. 
* ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access optimization. * ​\_\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal parameter for the tile layout in shared memory. * ​\_\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal parameter for the descriptor layout, which may differ from the tile layout to accommodate hardware requirements. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor from which data will be transferred. This defines the global memory layout and must match the specified data type. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured with the specified parameters, ready for use in asynchronous data transfer operations. `create_tensor_tile[dtype: DType, rank: Int, //, tile_shape: IndexList[rank], /, k_major_tma: Bool = True, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[rank, DType.int64, Int](0), tile_shape.__getitem__[rank, DType.int64, Int](1)), __desc_layout: Layout = _tma_desc_tile_layout[dtype, rank, tile_shape, swizzle_mode]()](ctx: DeviceContext, tensor: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]) -> TMATensorTile[dtype, __tile_layout, __desc_layout, k_major_tma]` Creates a `TMATensorTile` from a TileTensor. 
This overload accepts a TileTensor instead of LayoutTensor, enabling use with the new coordinate-based tensor abstraction. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​tile\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the tile to be transferred. * ​k\_major\_tma ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the TMA should use column-major pattern. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. * ​\_\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal parameter for the tile layout. * ​\_\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal parameter for the descriptor layout. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The CUDA device context. * ​tensor ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The source TileTensor. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured for the given tensor.
--- ## create_tensor_tile_im2col
`create_tensor_tile_im2col[dtype: DType, tile_shape: IndexList[2], swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[2, DType.int64, Int](0), tile_shape.__getitem__[2, DType.int64, Int](1)), __desc_layout: Layout = _im2col_desc_tile_layout[dtype, tile_shape, swizzle_mode]()](ctx: DeviceContext, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lower_corner_h: Int, lower_corner_w: Int, upper_corner_h: Int, upper_corner_w: Int, out_height: Int, out_width: Int, filter_h: Int, filter_w: Int) -> TMATensorTileIm2col[dtype, __tile_layout, __desc_layout]` Creates a TMA tensor tile with im2col transformation for 2D convolution. This factory function creates a TMA descriptor that performs hardware im2col transformation during loads. The descriptor encodes the convolution geometry and the TMA hardware computes addresses on-the-fly. For im2col TMA, each transaction loads one output pixel with multiple channels. This follows CUTLASS's approach where: * pixels\_per\_column = 1 (one pixel per TMA transaction) * channels\_per\_pixel = min(K\_tile, swizzle\_width) (contiguous channels) Note: For stride=1, dilation=1 convolution with padding (following CUTLASS convention): * lower\_corner\_h = -pad\_h * lower\_corner\_w = -pad\_w * upper\_corner\_h = pad\_h - (filter\_h - 1) * upper\_corner\_w = pad\_w - (filter\_w - 1) The filter offsets passed to the PTX instruction range from 0 to (filter\_size - 1) and are added to lower\_corner to compute actual input coordinates. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of tensor elements. * ​tile\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): Shape `[M_tile, K_tile]` for the GEMM tile. * M\_tile: Number of output pixels (batch \* H\_out \* W\_out slice). 
* K\_tile: Number of channels (C\_in \* R \* S slice for filter). * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory swizzling pattern. * ​\_\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal layout parameter (full tile shape). * ​\_\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal descriptor layout parameter (TMA box shape). **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The CUDA device context. * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The 4D activation tensor in NHWC layout. * ​lower\_corner\_h ([`Int`](/mojo/std/builtin/int/Int)): Lower corner offset for height (negative for padding). * ​lower\_corner\_w ([`Int`](/mojo/std/builtin/int/Int)): Lower corner offset for width (negative for padding). * ​upper\_corner\_h ([`Int`](/mojo/std/builtin/int/Int)): Upper corner offset for height. * ​upper\_corner\_w ([`Int`](/mojo/std/builtin/int/Int)): Upper corner offset for width. * ​out\_height ([`Int`](/mojo/std/builtin/int/Int)): Output height (H\_out) for M coordinate decomposition. * ​out\_width ([`Int`](/mojo/std/builtin/int/Int)): Output width (W\_out) for M coordinate decomposition. * ​filter\_h ([`Int`](/mojo/std/builtin/int/Int)): Filter height (R) for K coordinate decomposition. * ​filter\_w ([`Int`](/mojo/std/builtin/int/Int)): Filter width (S) for K coordinate decomposition. **Returns:** `TMATensorTileIm2col`: A TMATensorTileIm2col configured for im2col loads. **Raises:** Error if TMA descriptor creation fails.
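For the stride=1, dilation=1 case described above, the corner offsets follow directly from the padding and filter size. A small illustrative helper (Python, not library code; `pad_h`/`pad_w` are the symmetric padding amounts):

```python
# Sketch (illustrative, per the CUTLASS convention stated above) of the
# corner offsets for a stride=1, dilation=1 convolution with symmetric
# padding. These are the values you would pass as the lower_/upper_corner
# arguments.

def im2col_corners(pad_h: int, pad_w: int, filter_h: int, filter_w: int):
    lower_corner_h = -pad_h
    lower_corner_w = -pad_w
    upper_corner_h = pad_h - (filter_h - 1)
    upper_corner_w = pad_w - (filter_w - 1)
    return lower_corner_h, lower_corner_w, upper_corner_h, upper_corner_w

# Example: a 3x3 filter with "same" padding of 1 on each side gives
# lower corners of -1 and upper corners of 1 - 2 = -1. The filter
# offsets (0..filter_size-1) are then added to the lower corner to
# compute actual input coordinates.
```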
--- ## create_tma_tile
`create_tma_tile[*tile_sizes: Int, *, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](ctx: DeviceContext, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TMATensorTile[dtype, Layout.row_major(_to_int_tuple[tile_sizes]())]` Creates a `TMATensorTile` with specified tile dimensions and swizzle mode. This function creates a hardware-accelerated Tensor Memory Access (TMA) descriptor for efficient asynchronous data transfers between global memory and shared memory. It configures the tile dimensions and memory access patterns based on the provided parameters. **Constraints:** * The last dimension's size in bytes must not exceed the swizzle mode's byte limit (32B for SWIZZLE\_32B, 64B for SWIZZLE\_64B, 128B for SWIZZLE\_128B). * Only supports 2D tensors in this overload. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/std/builtin/int/Int)): The dimensions of the tile to be transferred. For 2D tensors, this should be \[height, width]. The dimensions determine the shape of data transferred in each TMA operation. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor from which data will be transferred. This defines the global memory layout and data type. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured with the specified tile dimensions and swizzle mode, ready for use in asynchronous data transfer operations.
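The swizzle constraint above is a simple byte-size check on the tile's last dimension. An illustrative sketch (Python, not library code) of the per-mode limit:

```python
# Illustrative check (not library code) of the swizzle-mode constraint
# described above: the tile's last dimension, in bytes, must not exceed
# the swizzle mode's byte limit.

SWIZZLE_LIMIT_BYTES = {
    "SWIZZLE_32B": 32,
    "SWIZZLE_64B": 64,
    "SWIZZLE_128B": 128,
}

def last_dim_fits(width: int, dtype_size_bytes: int, swizzle_mode: str) -> bool:
    """True if a row of `width` elements fits within the swizzle limit."""
    return width * dtype_size_bytes <= SWIZZLE_LIMIT_BYTES[swizzle_mode]

# A 64-wide tile of 2-byte elements (e.g. bfloat16) occupies 128 bytes
# per row, so it fits SWIZZLE_128B but not SWIZZLE_64B.
```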
--- ## create_tma_tile_template
`create_tma_tile_template[dtype: DType, rank: Int, tile_shape: IndexList[rank], /, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[rank, DType.int64, Int](0), tile_shape.__getitem__[rank, DType.int64, Int](1)), __desc_layout: Layout = _tma_desc_tile_layout[dtype, rank, tile_shape, swizzle_mode]()]() -> TMATensorTile[dtype, __tile_layout, __desc_layout]` Same as `create_tma_tile` except the descriptor is only a placeholder or a template for later replacement. It allows specification of data type, rank, and layout orientation, supports both 2D and 3D tensors, and provides fine-grained control over the memory access patterns. **Constraints:** * Only supports 2D and 3D tensors (rank must be 2 or 3). * For non-SWIZZLE\_NONE modes, the K dimension size in bytes must be a multiple of the swizzle mode's byte size. * For MN-major layout, only SWIZZLE\_128B is supported. * For 3D tensors, only K-major layout is supported. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the tensor elements. * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The dimensionality of the tensor (must be 2 or 3). * ​tile\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The shape of the tile to be transferred. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access optimization. * ​\_\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal parameter for the tile layout in shared memory. * ​\_\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Internal parameter for the descriptor layout, which may differ from the tile layout to accommodate hardware requirements. 
**Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured with the specified parameters, ready for use in asynchronous data transfer operations.
--- ## tma_async
Tensor Memory Accelerator (TMA) Asynchronous Operations Module Provides high-performance abstractions for NVIDIA's Tensor Memory Accelerator (TMA), enabling efficient asynchronous data movement between global and shared memory in GPU kernels. It is designed for use with NVIDIA Hopper architecture and newer GPUs that support TMA instructions. ## Key Components: * `TMATensorTile`: Core struct that encapsulates a TMA descriptor for efficient data transfers between global and shared memory with various access patterns and optimizations. * `SharedMemBarrier`: Synchronization primitive for coordinating asynchronous TMA operations, ensuring data transfers complete before dependent operations begin. * `PipelineState`: Helper struct for managing multi-stage pipeline execution with circular buffer semantics, enabling efficient double or triple buffering techniques. * `create_tma_tile`: Factory functions for creating optimized `TMATensorTile` instances with various configurations for different tensor shapes and memory access patterns. ## `comptime` values ### `SplitLastDimTMATensorTile` `comptime SplitLastDimTMATensorTile[rank: Int, //, dtype: DType, smem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle] = TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)]` A specialized TMA tensor tile type alias that handles layouts where the last dimension is split based on swizzle granularity for optimal memory access patterns. The current behavior is to not actually split the last dimension. #### Parameters * ​rank ([`Int`](/std/builtin/int/Int)): The number of dimensions of the tensor. * ​dtype ([`DType`](/std/builtin/dtype/DType)): The data type of the tensor elements. * ​smem\_shape ([`IndexList`](/std/utils/index_/IndexList)): The shape of the tile in shared memory. The last dimension will be padded if necessary to align with the swizzle granularity. 
* ​swizzle\_mode ([`TensorMapSwizzle`](/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. Determines the granularity at which the last dimension is split or padded. ## Structs * [​`PipelineState`](./PipelineState): Manages state for a multi-stage pipeline with circular buffer semantics. * [​`RaggedTensorMap`](./RaggedTensorMap): Creates a TMA descriptor that can handle stores with varying lengths. This struct is mainly used for MHA, where sequence lengths may vary between samples. * [​`RaggedTMA3DTile`](./RaggedTMA3DTile): Creates a TMA descriptor for loading/storing from ragged 3D arrays with a ragged leading dimension. This loads 2D tiles, indexing into the middle dim. When using these loads, it is essential that at least `BM * stride` space has been allocated in front of the gmem pointer, otherwise `CUDA_ERROR_ILLEGAL_ADDRESS` may result. * [​`SharedMemBarrier`](./SharedMemBarrier): A hardware-accelerated synchronization primitive for GPU shared memory operations. * [​`TMATensorTile`](./TMATensorTile): A hardware-accelerated tensor memory access (TMA) tile for efficient asynchronous data movement. * [​`TMATensorTileArray`](./TMATensorTileArray): An array of TMA descriptors. * [​`TMATensorTileIm2col`](./TMATensorTileIm2col): TMA tensor tile with im2col coordinate transformation for convolution. ## Functions * [​`create_split_tma`](./create_split_tma): Creates a TMA tensor tile assuming that the first dimension in global memory has `UNKNOWN_VALUE`. * [​`create_tensor_tile`](./create_tensor_tile): Creates a `TMATensorTile` with advanced configuration options for 2D, 3D, 4D, or 5D tensors. * [​`create_tensor_tile_im2col`](./create_tensor_tile_im2col): Creates a TMA tensor tile with im2col transformation for 2D convolution. * [​`create_tma_tile`](./create_tma_tile): Creates a `TMATensorTile` with specified tile dimensions and swizzle mode. 
* [​`create_tma_tile_template`](./create_tma_tile_template): Same as create\_tma\_tile except the descriptor is only a placeholder or a template for later replacement.
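The circular-buffer semantics that `PipelineState` provides can be sketched as an index that cycles through the stages plus a phase bit that flips on each wrap-around, as commonly used to match barrier phases in double or triple buffering. The exact Mojo API differs; this Python sketch only illustrates the assumed pattern:

```python
# Sketch (assumed semantics, for illustration only) of multi-stage
# pipeline state with circular-buffer semantics, in the spirit of
# PipelineState: an index cycling through num_stages buffers plus a
# phase bit that flips on every wrap-around.

class PipelineStateSketch:
    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.index = 0   # which buffer stage to use next
        self.phase = 0   # flips each time the index wraps

    def step(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1  # wrap-around toggles the phase bit

# Double buffering: after two steps we are back at stage 0, opposite phase.
state = PipelineStateSketch(2)
state.step()
state.step()
```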
--- ## accumulate
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]`
--- ## dot_at_b
`dot_at_b(c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## dot_at_b_impl
`dot_at_b_impl(c: LayoutTensor[DType.float32, Layout.row_major(16, 16), MutAnyOrigin], a: LayoutTensor[DType.float32, Layout.row_major(16, 16), ImmutAnyOrigin], b: LayoutTensor[DType.float32, Layout.row_major(16, 16), ImmutAnyOrigin])` `dot_at_b_impl(c: LayoutTensor[DType.float16, Layout.row_major(32, 32), MutAnyOrigin], a: LayoutTensor[DType.float16, Layout.row_major(32, 32), ImmutAnyOrigin], b: LayoutTensor[DType.float16, Layout.row_major(32, 32), ImmutAnyOrigin])`
--- ## extrx
`extrx(gpr: Int)` Extracts a row or moves it to x, result in amx0.
--- ## extry
`extry(gpr: Int)` Extracts a row or moves it to y, result in amx0.
--- ## fma
`fma[mode: StringSlice[StaticConstantOrigin], dtype: DType](z_row_index: Int, x_row_index: Int, y_row_index: Int, clear_z: Bool)`
--- ## fma16
`fma16(gpr: Int)` Float16 matrix multiply and add.
--- ## fma32
`fma32(gpr: Int)` Float32 matrix multiply and add.
--- ## fma64
`fma64(gpr: Int)` Float64 matrix multiply and add.
--- ## fms16
`fms16(gpr: Int)` Float16 matrix multiply and subtract.
--- ## fsm32
`fsm32(gpr: Int)` Float32 matrix multiply and subtract.
--- ## fsm64
`fsm64(gpr: Int)` Float64 matrix multiply and subtract.
--- ## genlut
`genlut(gpr: Int)`
--- ## apple_amx_intrinsics
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`dot_at_b`](./dot_at_b): * [​`dot_at_b_impl`](./dot_at_b_impl): * [​`extrx`](./extrx): Extracts a row or moves it to x, result in amx0. * [​`extry`](./extry): Extracts a row or moves it to y, result in amx0. * [​`fma`](./fma): * [​`fma16`](./fma16): Float16 matrix multiply and add. * [​`fma32`](./fma32): Float32 matrix multiply and add. * [​`fma64`](./fma64): Float64 matrix multiply and add. * [​`fms16`](./fms16): Float16 matrix multiply and subtract. * [​`fsm32`](./fsm32): Float32 matrix multiply and subtract. * [​`fsm64`](./fsm64): Float64 matrix multiply and subtract. * [​`genlut`](./genlut): * [​`ldx`](./ldx): * [​`ldy`](./ldy): * [​`ldz`](./ldz): * [​`ldzi`](./ldzi): * [​`load_z`](./load_z): * [​`mac16`](./mac16): SI16 matrix multiply and add. * [​`matfp`](./matfp): Float16 matrix multiply. * [​`max_int__`](./max_int__): UI16 matrix multiply. * [​`read_x`](./read_x): * [​`read_y`](./read_y): * [​`store_x`](./store_x): * [​`store_y`](./store_y): * [​`store_z`](./store_z): * [​`stx`](./stx): * [​`sty`](./sty): * [​`stz`](./stz): * [​`stzi`](./stzi): * [​`transpose_z_to_x_or_y`](./transpose_z_to_x_or_y): * [​`vec_int__`](./vec_int__): Horizontal ui16 multiply `z0[i] += x0[i] * y0[i]`. * [​`vecfp`](./vecfp): Horizontal float16 multiply `z0[i] += x0[i] * y0[i]`.
--- ## ldx
`ldx(gpr: Int)`
--- ## ldy
`ldy(gpr: Int)`
--- ## ldz
`ldz(gpr: Int)`
--- ## ldzi
`ldzi(gpr: Int)`
--- ## load_z
`load_z[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## mac16
`mac16(gpr: Int)` SI16 matrix multiply and add.
--- ## matfp
`matfp(gpr: Int)` Float16 matrix multiply.
--- ## max_int__
`max_int__(gpr: Int)` UI16 matrix multiply.
--- ## read_x
`read_x[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## read_y
`read_y[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## store_x
`store_x[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## store_y
`store_y[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## store_z
`store_z[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## stx
`stx(gpr: Int)`
--- ## sty
`sty(gpr: Int)`
--- ## stz
`stz(gpr: Int)`
--- ## stzi
`stzi(gpr: Int)`
--- ## transpose_z_to_x_or_y
`transpose_z_to_x_or_y[destination: StringSlice[StaticConstantOrigin], dtype: DType](z_col_index: Int, xy_row_index: Int, z_row_suboffset: Int)`
--- ## vec_int__
`vec_int__(gpr: Int)` Horizontal ui16 multiply `z0[i] += x0[i] * y0[i]`.
--- ## vecfp
`vecfp(gpr: Int)` Horizontal float16 multiply `z0[i] += x0[i] * y0[i]`.
--- ## cpu
Provides CPU architecture-specific utility functions. ## Modules * [​`apple_amx_intrinsics`](./apple_amx_intrinsics/): * [​`neon_intrinsics`](./neon_intrinsics/): * [​`vnni_intrinsics`](./vnni_intrinsics/):
--- ## neon_intrinsics
--- ## dot_i16_to_i32_AVX2
`dot_i16_to_i32_AVX2[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, (width * 2)], b: SIMD[b_type, (width * 2)]) -> SIMD[c_type, width]` The dot product of the two words in each int32 element of a and b plus an int32 from src. **Constraints:** Requires AVX2. The size of the output vector must be 4, 8 or 16. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int16 SIMD vector. * ​b ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int16 SIMD vector. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i16_to_i32_x86
`dot_i16_to_i32_x86[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, (width * 2)], b: SIMD[b_type, (width * 2)]) -> SIMD[c_type, width]` The dot product of the two words in each int32 element of a and b plus an int32 from src using VNNI or AVX2. **Constraints:** Requires AVX512\_VNNI or AVX2. The size of the output vector must be 4, 8 or 16. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int16 SIMD vector. * ​b ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int16 SIMD vector. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector of width elements.
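The lane-wise semantics shared by the `dot_i16_to_i32_*` variants can be sketched in plain Python (a reference model of the documented behavior, not the Mojo API): each int32 output lane accumulates the products of the two int16 pairs that share its lane, plus the corresponding int32 from `src`.

```python
def dot_i16_to_i32(src, a, b):
    # src: list of int32 accumulators (length = width).
    # a, b: flat lists of int16 values (length = 2 * width).
    # Each output lane i is src[i] + a[2i]*b[2i] + a[2i+1]*b[2i+1].
    assert len(a) == len(b) == 2 * len(src)
    return [src[i] + a[2 * i] * b[2 * i] + a[2 * i + 1] * b[2 * i + 1]
            for i in range(len(src))]

# Width 4: a and b each carry 8 int16 values, two per output lane.
print(dot_i16_to_i32([1, 0, 0, 0],
                     [1, 2, 3, 4, 5, 6, 7, 8],
                     [1, 1, 1, 1, 1, 1, 1, 1]))  # -> [4, 7, 11, 15]
```

The hardware instruction performs the same pairwise multiply-accumulate in a single step per lane.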
--- ## dot_i8_to_i32_AVX2
`dot_i8_to_i32_AVX2[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src. **Constraints:** Requires AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0,255]. The b argument has range \[-128,127]. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/std/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_saturated_AVX2
`dot_i8_to_i32_saturated_AVX2[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src. **Constraints:** Requires AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0,127], not \[0,255]. The b argument has range \[-128,127]. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/std/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_saturated_x86
`dot_i8_to_i32_saturated_x86[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. **Constraints:** Requires AVX512\_VNNI or AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0,127], not \[0,255]. The b argument has range \[-128,127]. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/std/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_x86
`dot_i8_to_i32_x86[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. **Constraints:** Requires AVX512\_VNNI or AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0,255]. The b argument has range \[-128,127]. **Parameters:** * ​width ([`Int`](/mojo/std/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/std/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/std/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): A SIMD vector of width elements.
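The `dot_i8_to_i32_*` family computes a four-way uint8 x int8 dot product per int32 lane. A plain-Python reference model of that behavior (illustrative only, not the Mojo API; the saturating variants additionally clamp intermediate sums, which is omitted here):

```python
def dot_u8_i8_to_i32(src, a_bytes, b_bytes):
    # src: list of int32 accumulators (length = width).
    # a_bytes: flat list of uint8 values, b_bytes: flat list of int8 values,
    # both of length 4 * width. Each output lane i is
    # src[i] + sum(a_bytes[4i+j] * b_bytes[4i+j] for j in 0..3).
    assert len(a_bytes) == len(b_bytes) == 4 * len(src)
    out = []
    for i in range(len(src)):
        acc = src[i]
        for j in range(4):
            acc += a_bytes[4 * i + j] * b_bytes[4 * i + j]
        out.append(acc)
    return out

# One output lane accumulating four byte products plus src.
print(dot_u8_i8_to_i32([0], [1, 2, 3, 4], [1, 1, 1, -1]))  # -> [2]
```

This is the core of int8 quantized matmul: four byte products are folded into each int32 accumulator per instruction.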
--- ## vnni_intrinsics
## Functions * [​`dot_i16_to_i32_AVX2`](./dot_i16_to_i32_AVX2): The dot product of the two words in each int32 element of a and b plus an int32 from src. * [​`dot_i16_to_i32_x86`](./dot_i16_to_i32_x86): The dot product of the two words in each int32 element of a and b plus an int32 from src using VNNI or AVX2. * [​`dot_i8_to_i32_AVX2`](./dot_i8_to_i32_AVX2): The dot product of the four bytes in each int32 element of a and b plus an int32 from src. * [​`dot_i8_to_i32_saturated_AVX2`](./dot_i8_to_i32_saturated_AVX2): The dot product of the four bytes in each int32 element of a and b plus an int32 from src. * [​`dot_i8_to_i32_saturated_x86`](./dot_i8_to_i32_saturated_x86): The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. * [​`dot_i8_to_i32_x86`](./dot_i8_to_i32_x86): The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. * [​`pmaddubs`](./pmaddubs): * [​`pmaddw`](./pmaddw): * [​`vpdpbusd`](./vpdpbusd): * [​`vpdpbusds`](./vpdpbusds): * [​`vpdpwssd`](./vpdpwssd): * [​`vpdpwssds`](./vpdpwssds):
--- ## pmaddubs
`pmaddubs[width: Int](a: SIMD[DType.int32, width], b: SIMD[DType.int32, width]) -> SIMD[DType.int32, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## pmaddw
`pmaddw[width: Int](a: SIMD[DType.int32, width], b: SIMD[DType.int32, width]) -> SIMD[DType.int32, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## vpdpbusd
`vpdpbusd[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## vpdpbusds
`vpdpbusds[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## vpdpwssd
`vpdpwssd[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, (width * 2)], b: SIMD[b_type, (width * 2)]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## vpdpwssds
`vpdpwssds[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, (width * 2)], b: SIMD[b_type, (width * 2)]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## arch
Provides architecture-specific utility functions. ## Packages * [​`cpu`](./cpu/): Provides CPU architecture-specific utility functions. * [​`sm100`](./sm100/): Provides NVIDIA Blackwell architecture-specific utility functions.
--- ## sm100
Provides NVIDIA Blackwell architecture-specific utility functions. ## Modules * [​`mma`](./mma/):
--- ## Major
`@register_passable(trivial)` `struct Major` ## Fields * ​val (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `K` `comptime K = Major(0)` ### `MN` `comptime MN = Major(1)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## MmaOpSM100_BlockScaled_SS
`@register_passable(trivial)` `struct MmaOpSM100_BlockScaled_SS[c_type: DType, a_type: DType, b_type: DType, sfa_dtype: DType, sfb_dtype: DType, scaling_kind: UMMAKind, block_tile_shape: IndexList[3], mma_shape: IndexList[3], /, *, accum_type: DType = DType.float32, cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = False]` ## Fields * ​idesc (`UMMAInsDescriptor[scaling_kind]`): * ​mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `mma` `mma(self, a: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sfa_smem: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sfb_smem: 
LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tmem: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, init_c: Bool)` MMA input tiles. The layout assumes that coalesce(A) has shape (bm, sw\_k, num\_sw\_k); we currently assume bm = mma\_m. In the future, we can tile it to (mma\_m, sw\_k, num\_sw\_k, num\_mma\_m). The same logic applies to matrix B. `mma(self, a: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], b: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], sfa_smem: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], sfb_smem: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], c_tmem: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, init_c: Bool)` TileTensor overload for block-scaled MMA input tiles. This overload accepts TileTensor directly for A, B, and scale factor tiles. The layout is extracted from TileTensor's compile-time type parameters (shape\_types, stride\_types) using direct VariadicType extraction for fast compile times. ### `commit` `commit(self, ptr_mbar: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED, origin=origin])` ### `wait` `wait(self)` ### `copy_sf_to_tmem` `copy_sf_to_tmem[sf_dtype: DType, sf_smem_layout: Layout, TILE_MN: Int, tile_k_idx: Int](self, sf_smem: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sf_tmem: UInt32)`
--- ## MmaOpSM100_SS
`@register_passable(trivial)` `struct MmaOpSM100_SS[c_type: DType, a_type: DType, b_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], /, *, accum_type: DType = DType.float32, cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = False]` ## Fields * ​idesc (`UMMAInsDescriptor[MmaOpSM100_SS._get_umma_kind[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type, cta_group, cluster_shape, a_swizzle, b_swizzle, transpose_b, a_type]()]`): * ​mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Defaultable`](/mojo/std/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `mma` `mma(self, a: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tmem: UInt32, init_c: Bool)` MMA input tiles. 
The layout assumes that coalesce(A) has shape (bm, sw\_k, num\_sw\_k); we currently assume bm = mma\_m. In the future, we can tile it to (mma\_m, sw\_k, num\_sw\_k, num\_mma\_m). The same logic applies to matrix B. `mma(self, a: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], b: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], c_tmem: UInt32, init_c: Bool)` TileTensor overload for MMA input tiles. This overload accepts TileTensor directly. The layout is extracted from TileTensor's compile-time type parameters (shape\_types, stride\_types). ### `commit` `commit(self, ptr_mbar: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED, origin=origin])` ### `wait` `wait(self)`
--- ## extract_first_2_modes
`extract_first_2_modes[l: Layout]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## mma
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`Major`](./Major): * [​`MmaOpSM100_BlockScaled_SS`](./MmaOpSM100_BlockScaled_SS): * [​`MmaOpSM100_SS`](./MmaOpSM100_SS): ## Functions * [​`extract_first_2_modes`](./extract_first_2_modes): * [​`max_contiguous_tile_shape`](./max_contiguous_tile_shape): Returns the maximum shape of a tile that's contiguous in memory for an mma op. This is used to create a TMA descriptor. * [​`smem_descriptor`](./smem_descriptor):
--- ## max_contiguous_tile_shape
`max_contiguous_tile_shape[rank: Int, //, dtype: DType, tile_shape: IndexList[rank], /, *, major: Major = Major.K, swizzle_mode: SwizzleMode = SwizzleMode.NONE]() -> IntTuple` Returns the maximum shape of a tile that's contiguous in memory for an mma op. This is used to create a TMA descriptor. **Returns:** `IntTuple`
--- ## smem_descriptor
`smem_descriptor[dtype: DType, //, *, BMN: Int, BK: Int, swizzle_mode: TensorMapSwizzle, is_k_major: Bool](ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, origin=origin]) -> MMASmemDescriptorPair` **Returns:** `MMASmemDescriptorPair`
--- ## batched_matmul
`batched_matmul[rank: Int, a_type: DType, b_type: DType, c_type: DType, //, *, transpose_a: Bool, transpose_b: Bool, elementwise_epilogue_fn: Optional[elementwise_epilogue_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: NDBuffer[c_type, rank, origin, shape, strides], a_buf: NDBuffer[a_type, rank, origin, shape, strides], b_buf: NDBuffer[b_type, rank, origin, shape, strides], *, context: DeviceContextPtr = DeviceContextPtr())` `batched_matmul[rank: Int, a_type: DType, b_type: DType, c_type: DType, //, *, transpose_b: Bool, elementwise_epilogue_fn: Optional[elementwise_epilogue_type] = None, saturated_vnni: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: NDBuffer[c_type, rank, origin, shape, strides], a_buf: NDBuffer[a_type, rank, origin, shape, strides], b_buf: NDBuffer[b_type, rank, origin, shape, strides], *, context: DeviceContextPtr = DeviceContextPtr())`
--- ## batched_matmul_dynamic_scaled_fp8
`batched_matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[c_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## batched_matmul_dynamic_scaled_fp8_naive
`batched_matmul_dynamic_scaled_fp8_naive[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, scales_granularity_mnk: IndexList[3], transpose_b: Bool = False](c_: LayoutTensor[c_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_: LayoutTensor[a_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_: LayoutTensor[b_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales_: LayoutTensor[a_scales_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales_: LayoutTensor[b_scales_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## batched_matmul_kernel_gpu
`batched_matmul_kernel_gpu[c_type: DType, a_type: DType, b_type: DType, CTensorType: TensorLayout, ATensorType: TensorLayout, BTensorType: TensorLayout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c_tensor: TileTensor[c_type, CTensorType, MutAnyOrigin], a_tensor: TileTensor[a_type, ATensorType, MutAnyOrigin], b_tensor: TileTensor[b_type, BTensorType, MutAnyOrigin], m: Int, n: Int, k: Int)`
--- ## batched_matmul_shape
`batched_matmul_shape[rank: Int, a_type: DType, b_type: DType, single_thread_blocking_override: Bool](a_buff: NDBuffer[a_type, rank, origin], b_buff: NDBuffer[b_type, rank, origin]) -> IndexList[rank]` Compute the output shape of a `batch_matmul` operation, and assert the inputs are compatible. **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): Rank of the input and output tensors. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the lhs input tensor. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the rhs input tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​a\_buff ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The lhs input tensor. * ​b\_buff ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The rhs input tensor. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The output shape.
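The shape rule that `batched_matmul_shape` enforces can be illustrated with a small plain-Python sketch (an assumption-laden model of the standard batched-matmul rule, not the Mojo implementation): batch dimensions must match, the inner (contraction) dimensions must match, and the output keeps A's rows and B's columns.

```python
def batched_matmul_shape(a_shape, b_shape):
    # a_shape, b_shape: tuples like (batch..., M, K) and (batch..., K, N).
    assert len(a_shape) == len(b_shape) >= 2
    assert a_shape[:-2] == b_shape[:-2], "batch dims must match"
    assert a_shape[-1] == b_shape[-2], "inner dims must match"
    # Output is (batch..., M, N).
    return a_shape[:-2] + (a_shape[-2], b_shape[-1])

print(batched_matmul_shape((8, 32, 64), (8, 64, 16)))  # -> (8, 32, 16)
```

Incompatible inputs fail the assertions rather than producing a shape, mirroring the "assert the inputs are compatible" behavior described above.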
--- ## bmm_sm100_blockwise_scaled_fp8
`bmm_sm100_blockwise_scaled_fp8[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## get_shape_index_list
`get_shape_index_list[rank: Int, dtype: DType, layout: Layout](tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[rank]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## bmm
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None` ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`batched_matmul`](./batched_matmul): * [​`batched_matmul_dynamic_scaled_fp8`](./batched_matmul_dynamic_scaled_fp8): * [​`batched_matmul_dynamic_scaled_fp8_naive`](./batched_matmul_dynamic_scaled_fp8_naive): * [​`batched_matmul_kernel_gpu`](./batched_matmul_kernel_gpu): * [​`batched_matmul_shape`](./batched_matmul_shape): Compute the output shape of a `batch_matmul` operation, and assert the inputs are compatible. * [​`bmm_sm100_blockwise_scaled_fp8`](./bmm_sm100_blockwise_scaled_fp8): * [​`get_shape_index_list`](./get_shape_index_list): * [​`naive_batched_matmul_kernel`](./naive_batched_matmul_kernel):
--- ## naive_batched_matmul_kernel
`naive_batched_matmul_kernel[rank: Int, c_type: DType, a_type: DType, b_type: DType, CTensorType: TensorLayout, ATensorType: TensorLayout, BTensorType: TensorLayout, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, accum_type: DType = get_accum_type[c_type]()](c_tensor: TileTensor[c_type, CTensorType, MutAnyOrigin], a_tensor: TileTensor[a_type, ATensorType, MutAnyOrigin], b_tensor: TileTensor[b_type, BTensorType, MutAnyOrigin], c_buff_nd_shape: IndexList[rank])`
--- ## distributed_matmul
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[input_index: Int, dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None` ## Functions * [​`matmul_allreduce`](./matmul_allreduce): Performs C = matmul(A, B^T) followed by Out = allreduce(C) across multiple GPUs. Splits the A (or B) and C matrices into `num_partitions` submatrices at dimension `partition_dim`, so that `num_partitions` independent matmul + allreduce kernels can run and some of the computation can overlap.
--- ## matmul_allreduce
`matmul_allreduce[ngpus: Int, partition_dim: Int, outputs_lambda: elementwise_epilogue_type, a_dtype: DType, b_dtype: DType, out_dtype: DType, a_static_shape: DimList, b_static_shape: DimList, c_static_shape: DimList, out_static_shape: DimList, overlap_with_dpl: Bool = True](a_buffers: InlineArray[NDBuffer[a_dtype, 2, MutAnyOrigin, a_static_shape], ngpus], b_buffers: InlineArray[NDBuffer[b_dtype, 2, MutAnyOrigin, b_static_shape], ngpus], c_temp_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, c_static_shape], ngpus], output_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, out_static_shape], ngpus], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctxs: List[DeviceContext], num_partitions: ValOrDim[dim])` Performs C = matmul(A, B^T) followed by Out = allreduce(C) across multiple GPUs. Splits the A (or B) and C matrices into `num_partitions` submatrices at dimension `partition_dim`, so that `num_partitions` independent matmul + allreduce kernels can run and some of the computation can overlap. `matmul_allreduce[ngpus: Int, outputs_lambda: elementwise_epilogue_type, a_dtype: DType, b_dtype: DType, out_dtype: DType, a_static_shape: DimList, b_static_shape: DimList, c_static_shape: DimList, out_static_shape: DimList](a_buffers: InlineArray[NDBuffer[a_dtype, 2, MutAnyOrigin, a_static_shape], ngpus], b_buffers: InlineArray[NDBuffer[b_dtype, 2, MutAnyOrigin, b_static_shape], ngpus], c_temp_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, c_static_shape], ngpus], output_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, out_static_shape], ngpus], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctxs: List[DeviceContext])` Performs C = matmul(A, B^T) followed by Out = allreduce(C) across multiple GPUs. The implementation may split the A/B/C matrices and overlap computation to improve performance.
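To illustrate the semantics (not the kernel's implementation), here is a minimal pure-Python sketch: each "GPU" computes a partial matmul against its B shard, the partials are summed (the allreduce), and splitting the rows into partitions only changes scheduling on real hardware, never the result. The helper names are hypothetical.

```python
def matmul_bt(a, b):
    # C = A @ B^T, with matrices represented as lists of rows.
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]

def add(c0, c1):
    # Elementwise matrix sum (the allreduce reduction).
    return [[x + y for x, y in zip(r0, r1)] for r0, r1 in zip(c0, c1)]

def matmul_allreduce_ref(a_list, b_list, num_partitions=2):
    """Out = allreduce(matmul(A_g, B_g^T)) over all "GPUs" g.

    Rows of A (and C) are split into num_partitions chunks; on hardware each
    chunk's matmul can overlap with the allreduce of the previous chunk.
    Here the split only changes iteration order, not the final values.
    """
    m = len(a_list[0])
    step = (m + num_partitions - 1) // num_partitions
    out_rows = []
    for start in range(0, m, step):
        partials = [matmul_bt(a[start:start + step], b)
                    for a, b in zip(a_list, b_list)]
        reduced = partials[0]
        for p in partials[1:]:
            reduced = add(reduced, p)  # the allreduce (sum) step
        out_rows.extend(reduced)
    return out_rows  # every GPU receives this same matrix
```

Because the reduction is a plain sum, partitioning along M is safe: each chunk of output rows depends only on the matching chunk of A rows on every GPU.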
--- ## config_in_smem
`config_in_smem[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, //, max_smem: Int](config: MatmulConfig[a_type, b_type, c_type, transpose_b]) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## dual_gemm
`dual_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, binary_lambda_fn: binary_fn_type = swilu, config: Optional[MatmulConfig[a_type, b_type, c_type, transpose_b]] = None, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], ctx: DeviceContext)`
--- ## dual_gemv
`dual_gemv[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, binary_lambda_fn: binary_fn_type = swilu, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], ctx: DeviceContext)`
--- ## dual_gemv_kernel
`dual_gemv_kernel[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, simd_width: Scalar[DType.uint], tile_m: Scalar[DType.uint], tile_n: Scalar[DType.uint], num_threads: Scalar[DType.uint], binary_lambda_fn: binary_fn_type, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type]()](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape])`
--- ## dual_gemm (Dual_gemm)
## `comptime` values ### `binary_fn_type` `comptime binary_fn_type = fn[type: DType, width: Int](SIMD[type, width], SIMD[type, width]) -> SIMD[type, width]` ## Functions * [​`config_in_smem`](./config_in_smem): * [​`dual_gemm`](./dual_gemm): * [​`dual_gemv`](./dual_gemv): * [​`dual_gemv_kernel`](./dual_gemv_kernel): * [​`multistage_dual_gemm`](./multistage_dual_gemm): * [​`multistage_dual_gemm_kernel`](./multistage_dual_gemm_kernel): * [​`multistage_dual_mma`](./multistage_dual_mma): * [​`swilu`](./swilu): * [​`swishGLU`](./swishGLU): Reference: "GLU Variants Improve Transformer" by Noam Shazeer. The implementation follows CUTLASS, using one kernel invocation and writing to the destination once.
--- ## multistage_dual_gemm
`multistage_dual_gemm[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, //, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], binary_lambda_fn: binary_fn_type = swilu, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin], a: LayoutTensor[a_type, a_layout, origin], b0: LayoutTensor[b_type, b_layout, origin], b1: LayoutTensor[b_type, b_layout, origin], ctx: DeviceContext)` `multistage_dual_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], binary_lambda_fn: binary_fn_type = swilu, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, num_k_partitions: Int = 1](c: NDBuffer[c_type, 2, origin, c_shape], a: NDBuffer[a_type, 2, origin, a_shape], b0: NDBuffer[b_type, 2, origin, b_shape], b1: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)`
--- ## multistage_dual_gemm_kernel
`multistage_dual_gemm_kernel[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], binary_lambda_fn: binary_fn_type, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b0: LayoutTensor[b_type, b_layout, MutAnyOrigin], b1: LayoutTensor[b_type, b_layout, MutAnyOrigin])`
--- ## multistage_dual_mma
`multistage_dual_mma[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, a_smem_layout: Layout, b_type: DType, b_layout: Layout, b_smem_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_threads: Int, num_pipeline_stages: Int, transpose_b: Bool, /, *, swizzle_a: Bool = True, static_num_iters: Dim = Dim(), k_group_size: Scalar[DType.uint] = 1](c0: LayoutTensor[c_type, c_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c1: LayoutTensor[c_type, c_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_iter_arg: LayoutTensorIter[dtype, a_layout, MutAnyOrigin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b0_iter_arg: LayoutTensorIter[b_type, b_layout, MutAnyOrigin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b1_iter_arg: LayoutTensorIter[b_type, b_layout, MutAnyOrigin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], a_smem_iter_arg: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b0_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b1_smem_iter: 
LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], num_iters: Int, /, *, num_b_rows: Optional[Int] = None)`
--- ## swilu
`swilu[dtype: DType, width: Int](x: SIMD[dtype, width], y: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## swishGLU
`swishGLU[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, target: StringSlice[StaticConstantOrigin] = "cpu"](a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], ctx: DeviceContextPtr)` Reference: "GLU Variants Improve Transformer" by Noam Shazeer. The implementation follows CUTLASS, using one kernel invocation and writing to the destination once.
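A sketch of the math this dual-GEMM computes, assuming (from the `swilu` default and the SwiGLU paper cited above) that the binary combiner is SiLU(c0) * c1, where c0 = a @ b0^T and c1 = a @ b1^T. This is an illustrative reference, not the fused kernel:

```python
import math

def silu(x):
    # SiLU / Swish activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def swiglu_ref(a, b0, b1):
    """SwishGLU(a) = SiLU(a @ b0^T) * (a @ b1^T), elementwise.

    Matrices are lists of rows; b0 and b1 play the roles of the two GEMM
    weight operands. The fused kernel computes both GEMMs in one invocation
    and writes the combined result to the destination once.
    """
    def matmul_bt(x, w):
        return [[sum(p * q for p, q in zip(r, c)) for c in w] for r in x]
    c0 = matmul_bt(a, b0)
    c1 = matmul_bt(a, b1)
    return [[silu(x) * y for x, y in zip(r0, r1)] for r0, r1 in zip(c0, c1)]
```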
--- ## block_scaled_matmul
`block_scaled_matmul[c_type: DType, a_type: DType, b_type: DType, scales_dtype: DType, //, *, SF_VECTOR_SIZE: Int, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](c_device: NDBuffer[c_type, 2, MutAnyOrigin, shape], a_device: NDBuffer[a_type, 2, MutAnyOrigin, shape], b_device: NDBuffer[b_type, 2, MutAnyOrigin, shape], a_scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], b_scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], tensor_sf: Float32, ctx: DeviceContext)`
--- ## block_scaled_matmul_with_epilogue
`block_scaled_matmul_with_epilogue[c_type: DType, a_type: DType, b_type: DType, scales_dtype: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, //, *, SF_VECTOR_SIZE: Int, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[scales_dtype, sfa_layout, MutAnyOrigin], b_scales: LayoutTensor[scales_dtype, sfb_layout, MutAnyOrigin], tensor_sf: Float32, ctx: DeviceContext)` Our sm100 block scaled matmul kernel does not yet support fusion of elementwise operations. This is a temporary implementation that uses our sm100 block scaled matmul kernel and dispatches a separate epilogue kernel to apply the elementwise operations.
--- ## block_scales_interleave
`block_scales_interleave[scales_dtype: DType, //, *, SF_VECTOR_SIZE: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](output_scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], input_scales_device: NDBuffer[scales_dtype, 2, MutAnyOrigin, shape], ctx: DeviceContext)`
--- ## block_scales_interleave_fp4
`block_scales_interleave_fp4[scales_dtype: DType, input_scales_layout: Layout, output_scales_layout: Layout, //, *, SF_VECTOR_SIZE: Int = 16, num_max_threads: Int = 1024](ctx: DeviceContext, input_scales: LayoutTensor[scales_dtype, input_scales_layout, MutAnyOrigin], output_scales: LayoutTensor[scales_dtype, output_scales_layout, MutAnyOrigin])`
--- ## block_scales_interleave_fp4_kernel
`block_scales_interleave_fp4_kernel[scales_dtype: DType, input_scales_layout: Layout, output_scales_layout: Layout, *, SF_VECTOR_SIZE: Int = 16, num_max_threads: Int = 1024](input_scales: LayoutTensor[scales_dtype, input_scales_layout, MutAnyOrigin], output_scales: LayoutTensor[scales_dtype, output_scales_layout, MutAnyOrigin])`
--- ## fp4_quantization
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`block_scaled_matmul`](./block_scaled_matmul): * [​`block_scaled_matmul_with_epilogue`](./block_scaled_matmul_with_epilogue): Our sm100 block scaled matmul kernel does not yet support fusion of elementwise operations. This is a temporary implementation that uses our sm100 block scaled matmul kernel and dispatches a separate epilogue kernel to apply the elementwise operations. * [​`block_scales_interleave`](./block_scales_interleave): * [​`block_scales_interleave_fp4`](./block_scales_interleave_fp4): * [​`block_scales_interleave_fp4_kernel`](./block_scales_interleave_fp4_kernel): * [​`naive_block_scaled_matmul`](./naive_block_scaled_matmul): * [​`naive_block_scaled_matmul_kernel`](./naive_block_scaled_matmul_kernel): * [​`quantize_dynamic_block_scaled`](./quantize_dynamic_block_scaled): * [​`quantize_dynamic_scaled_async_fp4_kernel`](./quantize_dynamic_scaled_async_fp4_kernel): * [​`quantize_dynamic_scaled_fp4_async`](./quantize_dynamic_scaled_fp4_async): * [​`quantize_dynamic_scaled_fp4fp8`](./quantize_dynamic_scaled_fp4fp8): * [​`quantize_dynamic_scaled_fp4fp8_kernel`](./quantize_dynamic_scaled_fp4fp8_kernel):
--- ## naive_block_scaled_matmul
`naive_block_scaled_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, scaling_kind: UMMAKind, SF_VECTOR_SIZE: Int, accum_type: DType = DType.float32, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, BLOCK_DIM: Int = 16](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext, alpha: Float32 = 1)`
--- ## naive_block_scaled_matmul_kernel
`naive_block_scaled_matmul_kernel[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, accum_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, scaling_kind: UMMAKind, SF_VECTOR_SIZE: Int, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin], alpha: Float32)`
--- ## quantize_dynamic_block_scaled
`quantize_dynamic_block_scaled[out_dtype: DType, scales_dtype: DType, in_dtype: DType, //, *, SF_VECTOR_SIZE: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](output_device: NDBuffer[out_dtype, 2, MutAnyOrigin, shape], scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], input_device: NDBuffer[in_dtype, 2, MutAnyOrigin, shape], tensor_sf: Float32, ctx: DeviceContext)`
--- ## quantize_dynamic_scaled_async_fp4_kernel
`quantize_dynamic_scaled_async_fp4_kernel[input_dtype: DType, input_cta_tile_layout: Layout, input_desc_layout: Layout, output_dtype: DType, output_cta_tile_layout: Layout, output_desc_layout: Layout, scales_dtype: DType, scales_tma_tile_layout: Layout, scales_desc_layout: Layout, input_swizzle_mode: TensorMapSwizzle, output_swizzle_mode: TensorMapSwizzle, scales_swizzle_mode: TensorMapSwizzle, SF_VECTOR_SIZE: Scalar[DType.uint], NUM_PIPELINES_STAGES: Scalar[DType.uint]](input_tma_op: TMATensorTile[input_dtype, input_cta_tile_layout, input_desc_layout], output_tma_op: TMATensorTile[output_dtype, output_cta_tile_layout, output_desc_layout], scales_tma_op: TMATensorTile[scales_dtype, scales_tma_tile_layout, scales_desc_layout], tensor_sf: Float32)`
--- ## quantize_dynamic_scaled_fp4_async
`quantize_dynamic_scaled_fp4_async[input_dtype: DType, output_dtype: DType, scales_dtype: DType, input_layout: Layout, output_layout: Layout, scales_layout: Layout, //, SF_VECTOR_SIZE: Int](ctx: DeviceContext, output_tensor: LayoutTensor[output_dtype, output_layout, MutAnyOrigin], scales_tensor: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], input_tensor: LayoutTensor[input_dtype, input_layout, MutAnyOrigin], tensor_sf: Float32 = 1)`
--- ## quantize_dynamic_scaled_fp4fp8
`quantize_dynamic_scaled_fp4fp8[out_dtype: DType, scales_dtype: DType, in_dtype: DType, output_layout: Layout, scales_layout: Layout, input_layout: Layout, //, *, SF_VECTOR_SIZE: Int = 16, num_max_threads: Int = 512](ctx: DeviceContext, output: LayoutTensor[out_dtype, output_layout, MutAnyOrigin], scales: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], input: LayoutTensor[in_dtype, input_layout, MutAnyOrigin], num_cols: Int, num_cols_padded: Int, tensor_sf: Float32 = 1)`
--- ## quantize_dynamic_scaled_fp4fp8_kernel
`quantize_dynamic_scaled_fp4fp8_kernel[out_dtype: DType, scales_dtype: DType, in_dtype: DType, output_layout: Layout, scales_layout: Layout, input_layout: Layout, *, SF_VECTOR_SIZE: Int = 16, ELEMENTS_PER_THREAD: Int = 8, num_max_threads: Int = 512](output: LayoutTensor[out_dtype, output_layout, MutAnyOrigin], scales: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], input: LayoutTensor[in_dtype, input_layout, MutAnyOrigin], num_cols: Int, num_cols_padded: Int, tensor_sf: Float32)`
--- ## cast_f4e2m1x2_to_fp16x2
`cast_f4e2m1x2_to_fp16x2(x: UInt8) -> SIMD[DType.float16, 2]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
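A Python sketch of unpacking two packed FP4 (E2M1) values from one byte via the `E2M1_TO_FLOAT32` lookup table defined in `fp4_utils` (sign bit in the MSB of the 4-bit code, then a 2-bit exponent and 1-bit mantissa). The nibble order (low nibble first) is an assumption of this sketch, not something the signature specifies:

```python
# E2M1 code -> value, matching the E2M1_TO_FLOAT32 table in fp4_utils.
E2M1_TO_FLOAT32 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                   -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def cast_f4e2m1x2(byte):
    """Unpack one byte holding two FP4 (E2M1) values.

    Assumes the low nibble holds the first value (hypothetical ordering).
    """
    return E2M1_TO_FLOAT32[byte & 0xF], E2M1_TO_FLOAT32[byte >> 4]
```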
--- ## cast_fp32_to_fp4e2m1
`cast_fp32_to_fp4e2m1[width: Int, //](x: SIMD[DType.float32, width]) -> UInt32` **Returns:** `UInt32`
--- ## cast_fp_to_fp4e2m1
`cast_fp_to_fp4e2m1[dtype: DType, width: Int, //](x: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## cast_uint_to_fp4e2m1
`cast_uint_to_fp4e2m1[in_dtype: DType, in_width: Int, //, *, out_dtype: DType, out_width: Int](x: SIMD[in_dtype, in_width]) -> SIMD[out_dtype, out_width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## convert_ref_scales_to_mxfp8_format
`convert_ref_scales_to_mxfp8_format[ref_scales_type: DType, scales_type: DType, ref_a_scales_layout: Layout, ref_b_scales_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_scales_origin: MutOrigin, b_scales_origin: MutOrigin, *, REF_BLOCK_SIZE: Int, SF_VECTOR_SIZE: Int](m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim], ref_a_scales: LayoutTensor[ref_scales_type, ref_a_scales_layout, origin], ref_b_scales: LayoutTensor[ref_scales_type, ref_b_scales_layout, origin], a_scales: LayoutTensor[scales_type, a_scales_layout, a_scales_origin], b_scales: LayoutTensor[scales_type, b_scales_layout, b_scales_origin])`
--- ## get_batched_scale_factor
`get_batched_scale_factor[scales_dtype: DType, scales_layout: Layout, //, SF_VECTOR_SIZE: Int](scales_tensor: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], batch_idx: Int, row_idx: Int, col_idx: Int) -> Scalar[scales_dtype]` **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar)
--- ## get_scale_factor
`get_scale_factor[scales_dtype: DType, scales_layout: Layout, //, SF_VECTOR_SIZE: Int](scales_tensor: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], row_idx: Int, col_idx: Int) -> Scalar[scales_dtype]` **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar)
--- ## fp4_utils
## `comptime` values ### `E2M1_TO_FLOAT32` `comptime E2M1_TO_FLOAT32 = SIMD[DType.float32, 16](0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6, Tuple[]())` ### `MXFP4_SF_DTYPE` `comptime MXFP4_SF_DTYPE = DType.float8_e8m0fnu` ### `MXFP4_SF_VECTOR_SIZE` `comptime MXFP4_SF_VECTOR_SIZE = 32` ### `MXFP8_SF_DTYPE` `comptime MXFP8_SF_DTYPE = DType.float8_e8m0fnu` ### `MXFP8_SF_VECTOR_SIZE` `comptime MXFP8_SF_VECTOR_SIZE = 32` ### `NVFP4_SF_DTYPE` `comptime NVFP4_SF_DTYPE = DType.float8_e4m3fn` ### `NVFP4_SF_VECTOR_SIZE` `comptime NVFP4_SF_VECTOR_SIZE = 16` ### `SF_ATOM_K` `comptime SF_ATOM_K = 4` ### `SF_ATOM_M` `comptime SF_ATOM_M = Tuple[Int, Int](VariadicPack[True, MutExternalOrigin, True, Movable, Int, Int](32, 4))` ### `SF_K_GROUP_SIZE` `comptime SF_K_GROUP_SIZE[SF_VECTOR_SIZE: Int] = (4 * SF_VECTOR_SIZE)` #### Parameters * ​SF\_VECTOR\_SIZE ([`Int`](/std/builtin/int/Int)): ### `SF_MN_GROUP_SIZE` `comptime SF_MN_GROUP_SIZE = ((load_from_mem SF_ATOM_M.__getitem__[Int, Int, 0]()) * (load_from_mem SF_ATOM_M.__getitem__[Int, Int, 1]()))` ## Functions * [​`cast_f4e2m1x2_to_fp16x2`](./cast_f4e2m1x2_to_fp16x2): * [​`cast_fp32_to_fp4e2m1`](./cast_fp32_to_fp4e2m1): * [​`cast_fp_to_fp4e2m1`](./cast_fp_to_fp4e2m1): * [​`cast_uint_to_fp4e2m1`](./cast_uint_to_fp4e2m1): * [​`convert_ref_scales_to_mxfp8_format`](./convert_ref_scales_to_mxfp8_format): * [​`get_batched_scale_factor`](./get_batched_scale_factor): * [​`get_scale_factor`](./get_scale_factor): * [​`set_batched_scale_factor`](./set_batched_scale_factor): * [​`set_scale_factor`](./set_scale_factor):
--- ## set_batched_scale_factor
`set_batched_scale_factor[scales_dtype: DType, scales_layout: Layout, //, SF_VECTOR_SIZE: Int](scales_tensor: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], batch_idx: Int, row_idx: Int, col_idx: Int, scale_value: Scalar[scales_dtype])`
--- ## set_scale_factor
`set_scale_factor[scales_dtype: DType, scales_layout: Layout, //, SF_VECTOR_SIZE: Int, width: Int](scales_tensor: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], row_idx: Int, col_idx: Int, scale_value: SIMD[scales_dtype, width])`
--- ## batched_quantize_dynamic_scaled_fp8
`batched_quantize_dynamic_scaled_fp8[out_dtype: DType, in_dtype: DType, scales_dtype: DType, //, input_fn: fn[width: Int, alignment: Int](batch: Int, row: Int, col: Int) capturing -> SIMD[in_dtype, width], group_size_or_per_token: Int, num_cols: Int](scaled_output: NDBuffer[out_dtype, 3, MutAnyOrigin], scales: NDBuffer[scales_dtype, 3, MutAnyOrigin], scale_ub: Float32, ctx: DeviceContext, num_rows: Int, batch_size: Int)`
--- ## batched_quantize_fp8_kernel
`batched_quantize_fp8_kernel[out_type: DType, scales_type: DType, in_type: DType, input_fn: fn[width: Int, alignment: Int](batch: Int, row: Int, col: Int) capturing -> SIMD[in_type, width], num_threads: Int, group_size: Int, simd_width: Int](output: NDBuffer[out_type, 3, MutAnyOrigin], scales: NDBuffer[scales_type, 3, MutAnyOrigin], scale_ub: Scalar[scales_type])`
--- ## blockwise_scaled_fp8_with_epilogue
`blockwise_scaled_fp8_with_epilogue[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, scales_granularity_mnk: IndexList[3], transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)` Our sm100 blockwise scaled fp8 matmul kernel does not yet support fusion of elementwise operations. This is a temporary implementation that uses our sm100 blockwise scaled fp8 matmul kernel and dispatches a separate epilogue kernel to apply the elementwise operations. For non-B200 GPUs, we use the naive blockwise scaled fp8 matmul, which supports a normal epilogue natively.
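The numeric semantics of a blockwise-scaled matmul can be sketched as dequantize-then-multiply: every quantized element is multiplied by the scale of the block it belongs to before entering the dot product. The granularities and scale layouts below are illustrative stand-ins, not the kernel's actual `scales_granularity_mnk` layout:

```python
def blockwise_scaled_matmul_ref(a, b, a_scales, b_scales, gran_k=2, gran_n=2):
    """Reference blockwise-scaled matmul (dequantize, then accumulate).

    a_scales[i][kb] scales the kb-th K-block of row i of A;
    b_scales[kb][nb] scales the (kb, nb) block of B.
    Block granularities (gran_k, gran_n) are illustrative assumptions.
    """
    m, k, n = len(a), len(a[0]), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for kk in range(k):
                sa = a_scales[i][kk // gran_k]
                sb = b_scales[kk // gran_k][j // gran_n]
                c[i][j] += (a[i][kk] * sa) * (b[kk][j] * sb)
    return c
```

The real kernel fuses the scale multiplications into the FP8 tensor-core accumulation rather than materializing dequantized operands.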
--- ## convert_e4m3fn_to_e4m3fnuz
`convert_e4m3fn_to_e4m3fnuz(input_buffer: LayoutTensor[DType.float8_e4m3fn, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_buffer: LayoutTensor[DType.float8_e4m3fnuz, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContext)` Convert E4M3FN weights to E4M3FNUZ format for AMD GPU compatibility. This conversion handles a key difference between E4M3FN and E4M3FNUZ: the bit pattern 10000000 (-128) represents negative zero in E4M3FN but NaN in E4M3FNUZ. **Args:** * ​input\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor in E4M3FN format. * ​output\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor to store E4M3FNUZ format. * ​context ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for kernel execution.
--- ## fp8_quantization
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`batched_quantize_dynamic_scaled_fp8`](./batched_quantize_dynamic_scaled_fp8): * [​`batched_quantize_fp8_kernel`](./batched_quantize_fp8_kernel): * [​`blockwise_scaled_fp8_with_epilogue`](./blockwise_scaled_fp8_with_epilogue): Our sm100 blockwise scaled fp8 matmul kernel does not yet support fusion of elementwise operations. This is a temporary implementation that uses our sm100 blockwise scaled fp8 matmul kernel and dispatches a separate epilogue kernel to apply the elementwise operations. For non-B200 GPUs, we use the naive blockwise scaled fp8 matmul, which supports a normal epilogue natively. * [​`convert_e4m3fn_to_e4m3fnuz`](./convert_e4m3fn_to_e4m3fnuz): Convert E4M3FN weights to E4M3FNUZ format for AMD GPU compatibility. * [​`matmul_dynamic_scaled_fp8`](./matmul_dynamic_scaled_fp8): * [​`naive_blockwise_scaled_fp8_grouped_matmul`](./naive_blockwise_scaled_fp8_grouped_matmul): * [​`naive_blockwise_scaled_fp8_grouped_matmul_kernel`](./naive_blockwise_scaled_fp8_grouped_matmul_kernel): * [​`naive_blockwise_scaled_fp8_matmul`](./naive_blockwise_scaled_fp8_matmul): * [​`naive_blockwise_scaled_fp8_matmul_kernel`](./naive_blockwise_scaled_fp8_matmul_kernel): * [​`quantize_dynamic_scaled_fp8`](./quantize_dynamic_scaled_fp8): * [​`quantize_fp8_kernel`](./quantize_fp8_kernel): * [​`quantize_static_scaled_fp8`](./quantize_static_scaled_fp8):
--- ## matmul_dynamic_scaled_fp8
`matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: TileTensor[c_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], a: TileTensor[a_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], b: TileTensor[b_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], a_scales: TileTensor[a_scales_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], b_scales: TileTensor[b_scales_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)` `matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[c_type, 2, origin, shape, strides], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], a_scales: NDBuffer[a_scales_type, 2, origin, shape, strides], b_scales: NDBuffer[b_scales_type, 2, origin, shape, strides], ctx: DeviceContext)`
--- ## naive_blockwise_scaled_fp8_grouped_matmul
`naive_blockwise_scaled_fp8_grouped_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, //, BLOCK_DIM_N: Int = 32, BLOCK_DIM_M: Int = 16, transpose_b: Bool = True, scales_granularity_mnk: Optional[IndexList[3]] = None, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## naive_blockwise_scaled_fp8_grouped_matmul_kernel
`naive_blockwise_scaled_fp8_grouped_matmul_kernel[c_layout: Layout, a_layout: Layout, b_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, accum_type: DType, transpose_b: Bool = True, scales_granularity_mnk: Optional[IndexList[3]] = None, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin])`
--- ## naive_blockwise_scaled_fp8_matmul
`naive_blockwise_scaled_fp8_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, BLOCK_DIM: Int = 16, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, accum_type: DType = get_accum_type[c_type](), scales_granularity_mnk: Optional[IndexList[3]] = None](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)` `naive_blockwise_scaled_fp8_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, c_shape: DimList, a_shape: DimList, b_shape: DimList, a_scale_shape: DimList, b_scale_shape: DimList, //, *, BLOCK_DIM: Int = 16, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, accum_type: DType = get_accum_type[c_type](), scales_granularity_mnk: Optional[IndexList[3]] = None](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], a_scales_device: NDBuffer[a_scales_type, 2, origin, a_scale_shape], b_scales_device: NDBuffer[b_scales_type, 2, origin, b_scale_shape], ctx: 
DeviceContext)`
--- ## naive_blockwise_scaled_fp8_matmul_kernel
`naive_blockwise_scaled_fp8_matmul_kernel[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, accum_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, BLOCK_DIM: Int, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, scales_granularity_mnk: Optional[IndexList[3]] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin])`
--- ## quantize_dynamic_scaled_fp8
`quantize_dynamic_scaled_fp8[out_dtype: DType, in_dtype: DType, scales_dtype: DType, //, input_fn: fn[width: Int, alignment: Int](row: Int, col: Int) capturing -> SIMD[in_dtype, width], group_size_or_per_token: Int, num_cols: Int](scaled_output: NDBuffer[out_dtype, 2, MutAnyOrigin], scales: NDBuffer[scales_dtype, 2, MutAnyOrigin], scale_ub: Float32, ctx: DeviceContext, num_rows: Int)`
--- ## quantize_fp8_kernel
`quantize_fp8_kernel[out_type: DType, scales_type: DType, in_type: DType, input_fn: fn[width: Int, alignment: Int](row: Int, col: Int) capturing -> SIMD[in_type, width], num_threads: Int, group_size: Int, simd_width: Int](output: NDBuffer[out_type, 2, MutAnyOrigin], scales: NDBuffer[scales_type, 2, MutAnyOrigin], scale_ub: Scalar[scales_type])`
--- ## quantize_static_scaled_fp8
`quantize_static_scaled_fp8[out_dtype: DType, in_dtype: DType, scale_is_inverted: Bool = True](out_buffer: NDBuffer[out_dtype, 2, origin, shape, strides], in_buffer: NDBuffer[in_dtype, 2, origin, shape, strides], scale: Float32, context: DeviceContext)`
--- ## compute_dynamic_fp8_scale
`compute_dynamic_fp8_scale[out_dtype: DType](row_max: Scalar[dtype], scale_ub: Scalar[dtype]) -> Tuple[Scalar[dtype], Scalar[dtype]]` Compute dynamic FP8 scale factor and its reciprocal from a row max. Computes scale\_factor = min(row\_max, scale\_ub) / fp8\_max and its reciprocal. Does not use `math.recip` to avoid a reciprocal approximation that gives up too much precision. **Parameters:** * ​out\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The FP8 output dtype (float8\_e4m3fn or float8\_e4m3fnuz). **Args:** * ​row\_max ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Maximum absolute value across the row/group. * ​scale\_ub ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Upper bound to clamp the scale factor. **Returns:** `Tuple`: A tuple of (scale\_factor, scale\_factor\_recip).
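The scale math above can be sketched in plain Python (a hypothetical illustration, not the Mojo API; 448.0 is the maximum finite value of `float8_e4m3fn`):

```python
# Illustrative sketch of dynamic FP8 scale computation.
# FP8_E4M3FN_MAX is the max finite value of float8_e4m3fn.
FP8_E4M3FN_MAX = 448.0

def compute_dynamic_fp8_scale(row_max: float, scale_ub: float):
    """Return (scale_factor, scale_factor_recip) for one row/group."""
    # Clamp the observed row maximum so the scale never exceeds the bound.
    clamped = min(row_max, scale_ub)
    scale = clamped / FP8_E4M3FN_MAX
    # Plain division rather than a fast reciprocal approximation,
    # mirroring the precision note in the description above.
    return scale, 1.0 / scale

scale, recip = compute_dynamic_fp8_scale(row_max=224.0, scale_ub=1200.0)
# scale == 0.5, recip == 2.0
```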
--- ## compute_static_fp8_scale_recip
`compute_static_fp8_scale_recip[accum_type: DType, out_dtype: DType](static_scale: Float32) -> Scalar[accum_type]` Compute reciprocal scale for static FP8 quantization. **Parameters:** * ​accum\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The accumulation dtype. * ​out\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The FP8 output dtype. **Args:** * ​static\_scale ([`Float32`](/mojo/std/builtin/simd/#float32)): The static scale value. **Returns:** [`Scalar`](/mojo/std/builtin/simd/#scalar): The reciprocal scale: fp8\_max / static\_scale.
--- ## fp8_quantize
`fp8_quantize[out_dtype: DType, *, use_clamp: Bool = is_amd_gpu()](values: SIMD[dtype, size], scale_recip: Scalar[dtype]) -> SIMD[out_dtype, size]` Quantize values to FP8, optionally clamping to the representable range. On AMD GPUs, clamping is faster because of NaN handling. **Parameters:** * ​out\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The FP8 output dtype. * ​use\_clamp ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to clamp to \[min\_finite, max\_finite] before the cast. Defaults to True on AMD GPUs, False otherwise. **Args:** * ​values ([`SIMD`](/mojo/std/builtin/simd/SIMD)): Values to quantize (already normalized as needed, not yet scaled). * ​scale\_recip ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Reciprocal of the FP8 scale factor. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): FP8-quantized values.
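The scale-then-clamp behavior can be sketched in Python (hypothetical names; the 448.0 limit assumes `float8_e4m3fn`, and the real kernel casts to the FP8 dtype rather than returning floats):

```python
# Illustrative sketch of FP8 quantization with optional clamping.
FP8_E4M3FN_MAX = 448.0  # max finite value of float8_e4m3fn

def fp8_quantize(values, scale_recip, use_clamp=True):
    out = []
    for v in values:
        scaled = v * scale_recip
        if use_clamp:
            # Clamp to the representable range before the cast; per the
            # note above, this path is faster on AMD GPUs.
            scaled = max(-FP8_E4M3FN_MAX, min(FP8_E4M3FN_MAX, scaled))
        out.append(scaled)  # a real kernel would cast to the FP8 dtype here
    return out

print(fp8_quantize([100.0, 900.0], scale_recip=1.0))  # [100.0, 448.0]
```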
--- ## fp8_utils
Shared FP8 quantization utilities. Provides common functions for FP8 scale computation and quantization used across fused normalization kernels and standalone quantization kernels. ## Functions * [​`compute_dynamic_fp8_scale`](./compute_dynamic_fp8_scale): Compute dynamic FP8 scale factor and its reciprocal from a row max. * [​`compute_static_fp8_scale_recip`](./compute_static_fp8_scale_recip): Compute reciprocal scale for static FP8 quantization. * [​`fp8_quantize`](./fp8_quantize): Quantize values to FP8, optionally clamping to the representable range.
--- ## GEMVAlgorithm
`struct GEMVAlgorithm` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `GEMV_KERNEL` `comptime GEMV_KERNEL = GEMVAlgorithm(0)` ### `GEMV_KERNEL_VECTOR` `comptime GEMV_KERNEL_VECTOR = GEMVAlgorithm(1)` ### `GEMV_SPLIT_K` `comptime GEMV_SPLIT_K = GEMVAlgorithm(2)` ### `GEVM_KERNEL` `comptime GEVM_KERNEL = GEMVAlgorithm(4)` ### `GEVM_KERNEL_VECTOR` `comptime GEVM_KERNEL_VECTOR = GEMVAlgorithm(3)` ### `MATMUL_NAIVE` `comptime MATMUL_NAIVE = GEMVAlgorithm(5)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__is__` `__is__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__isnot__` `__isnot__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` Returns the string representation of this algorithm. **Returns:** [`String`](/mojo/std/collections/string/string/String): A human-readable string representation of the algorithm. ### `write_to` `write_to(self, mut writer: T)`
--- ## gemv
`gemv[parallelize: Bool, c_size: Dim, c_type: DType, a_shape: DimList, a_type: DType, b_size: Dim, b_type: DType, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c_buf: NDBuffer[c_type, 1, origin, c_size], a_buf: NDBuffer[a_type, 2, origin, a_shape], b_buf: NDBuffer[b_type, 1, origin, b_size])`
--- ## gemv_gpu
`gemv_gpu[transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], ctx: DeviceContext)`
--- ## gemv_gpu_dispatch
`gemv_gpu_dispatch[transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, pdl_level: PDLLevel = PDLLevel()](kernel_func: GEMVAlgorithm, c: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], ctx: DeviceContext)`
--- ## gemv_kernel
`gemv_kernel[c_type: DType, a_type: DType, b_type: DType, *, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type](), pdl_level: PDLLevel = PDLLevel()](c: LegacyUnsafePointer[Scalar[c_type]], a: LegacyUnsafePointer[Scalar[a_type]], b: LegacyUnsafePointer[Scalar[b_type]], m: Int, n: Int, k: Int)`
--- ## gemv_kernel_vector
`gemv_kernel_vector[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, *, simd_width: Scalar[DType.uint], transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type](), pdl_level: PDLLevel = PDLLevel()](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], m: Int, n: Int, k: Int)`
--- ## gemv_split_k
`gemv_split_k[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, simd_width: Scalar[DType.uint], tile_m: Scalar[DType.uint], tile_n: Scalar[DType.uint], num_threads: Scalar[DType.uint], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type](), check_bounds: Bool = True, pdl_level: PDLLevel = PDLLevel()](output: LayoutTensor[c_type, c_layout, MutAnyOrigin], act: LayoutTensor[a_type, a_layout, MutAnyOrigin], weight: LayoutTensor[b_type, b_layout, MutAnyOrigin], m: Int, n: Int, k: Int)` GEMV with tiling in the K dimension. Assuming the B (weight) matrix is transposed (i.e., row-major N x K), this kernel computes a vector (1 x K) times a matrix (N x K). The implementation can handle M > 1, but it is only optimal for very small M, so it is used for M = 1 only.
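The split-K idea—partition the K dimension into tiles, accumulate partial dot products per tile, then reduce—can be sketched in Python (an illustrative sequential sketch, not the kernel's parallel implementation; the tile size is an arbitrary choice):

```python
# Illustrative split-K GEMV: act is a length-K vector, weight is the
# transposed B matrix (N x K). Each K tile produces a partial dot
# product; the partials are reduced into the output.
def gemv_split_k(act, weight, tile_k=4):
    k = len(act)
    n = len(weight)
    out = [0.0] * n
    for k0 in range(0, k, tile_k):      # each K tile could run in parallel
        k1 = min(k0 + tile_k, k)
        for row in range(n):
            partial = sum(act[j] * weight[row][j] for j in range(k0, k1))
            out[row] += partial          # reduction across K tiles
    return out

a = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [[1.0] * 5, [0.0, 1.0, 0.0, 1.0, 0.0]]
print(gemv_split_k(a, w))  # [15.0, 6.0]
```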
--- ## gevm_kernel
`gevm_kernel[c_type: DType, a_type: DType, b_type: DType, *, tile_size: Int, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type](), pdl_level: PDLLevel = PDLLevel()](c: LegacyUnsafePointer[Scalar[c_type]], a: LegacyUnsafePointer[Scalar[a_type]], b: LegacyUnsafePointer[Scalar[b_type]], m: Int, n: Int, k: Int)`
--- ## gemv (Gemv)
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`GEMVAlgorithm`](./GEMVAlgorithm): ## Functions * [​`gemv`](./gemv): * [​`gemv_gpu`](./gemv_gpu): * [​`gemv_gpu_dispatch`](./gemv_gpu_dispatch): * [​`gemv_kernel`](./gemv_kernel): * [​`gemv_kernel_vector`](./gemv_kernel_vector): * [​`gemv_split_k`](./gemv_split_k): GEMV with tiling in the K dimension. Assuming the B (weight) matrix is transposed (i.e., row-major N x K), this kernel computes a vector (1 x K) times a matrix (N x K). The implementation can handle M > 1, but it is only optimal for very small M, so it is used for M = 1 only. * [​`gevm_kernel`](./gevm_kernel): * [​`log_shape`](./log_shape): * [​`naive_gemv`](./naive_gemv): * [​`reverse_idx`](./reverse_idx):
--- ## log_shape
`log_shape[has_mode_1: Bool, has_mode_2: Bool, name: String](mode_1: Int, mode_2: Int)`
--- ## naive_gemv
`naive_gemv[c_size: Dim, a_shape: DimList, b_size: Dim, dtype: DType](c_buf: NDBuffer[dtype, 1, origin, c_size], a_buf: NDBuffer[dtype, 2, origin, a_shape], b_buf: NDBuffer[dtype, 1, origin, b_size])`
--- ## reverse_idx
`reverse_idx[transpose: Bool](x: Int, y: Int) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## dispatch_amd_matmul_by_block_shape
`dispatch_amd_matmul_by_block_shape[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool, N: Int, K: Int, launcher_fn: fn[config: MatmulConfig[a_type, b_type, c_type, transpose_b]]() raises capturing -> None, default_block_tile_shape: IndexList[3], use_heuristic: Bool = False](M: Int, ctx: DeviceContext)` Dispatches to the best kernel configuration based on runtime M dimension.
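Dispatch by runtime M can be sketched as a simple threshold table (the cutoffs and tile shapes below are purely illustrative assumptions, not the kernel's actual heuristic):

```python
# Illustrative sketch of selecting a block tile shape from runtime M.
# Smaller M favors narrower M tiles so fewer threads sit idle; the
# thresholds here are hypothetical, chosen only to show the pattern.
def dispatch_by_m(m, default_block_tile_shape=(128, 128, 64)):
    if m <= 16:
        return (16, 128, 64)
    if m <= 64:
        return (64, 128, 64)
    return default_block_tile_shape

print(dispatch_by_m(8))    # (16, 128, 64)
print(dispatch_by_m(256))  # (128, 128, 64)
```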
--- ## grouped_matmul
`grouped_matmul[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_amd
`grouped_matmul_amd[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool = True, block_tile_shape: IndexList[3] = Index(128, 128, 64), elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_amd_kernel_launcher
`grouped_matmul_amd_kernel_launcher[c_type: DType, a_type: DType, b_type: DType, layout_c: Layout, layout_a: Layout, layout_b: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c_tensor: LayoutTensor[c_type, layout_c, MutAnyOrigin], a_tensor: LayoutTensor[a_type, layout_a, MutAnyOrigin], b_tensor: LayoutTensor[b_type, layout_b, MutAnyOrigin], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int)`
--- ## grouped_matmul_kernel_sm100
`grouped_matmul_kernel_sm100[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, c_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], a_desc_layout: Layout, b_desc_layout: Layout, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, transpose_b: Bool = True, num_threads: Int = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)`
--- ## grouped_matmul_sm100
`grouped_matmul_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool = True, mma_shape: IndexList[3] = Index(64, 128, 16), block_tile_shape: IndexList[3] = Index(64, 128, 64), elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_vendor
`grouped_matmul_vendor[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool = True, use_tf32: Bool = False](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul (Grouped_matmul)
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`dispatch_amd_matmul_by_block_shape`](./dispatch_amd_matmul_by_block_shape): Dispatches to the best kernel configuration based on runtime M dimension. * [​`grouped_matmul`](./grouped_matmul): * [​`grouped_matmul_amd`](./grouped_matmul_amd): * [​`grouped_matmul_amd_kernel_launcher`](./grouped_matmul_amd_kernel_launcher): * [​`grouped_matmul_kernel_sm100`](./grouped_matmul_kernel_sm100): * [​`grouped_matmul_sm100`](./grouped_matmul_sm100): * [​`grouped_matmul_vendor`](./grouped_matmul_vendor): * [​`naive_epilogue`](./naive_epilogue): * [​`naive_epilogue_kernel`](./naive_epilogue_kernel): * [​`naive_grouped_matmul`](./naive_grouped_matmul): * [​`naive_grouped_matmul_kernel`](./naive_grouped_matmul_kernel):
--- ## naive_epilogue
`naive_epilogue[c_type: DType, c_shape: DimList, *, elementwise_lambda_fn: elementwise_epilogue_type](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], ctx: DeviceContext)`
--- ## naive_epilogue_kernel
`naive_epilogue_kernel[c_type: DType, c_shape: DimList, *, elementwise_lambda_fn: elementwise_epilogue_type](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape])`
--- ## naive_grouped_matmul
`naive_grouped_matmul[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## naive_grouped_matmul_kernel
`naive_grouped_matmul_kernel[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin])`
--- ## WarpRole
`@register_passable(trivial)` `struct WarpRole` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Epilogue` `comptime Epilogue = WarpRole(3)` ### `MainLoad` `comptime MainLoad = WarpRole(4)` ### `Mma` `comptime Mma = WarpRole(5)` ## Methods ### `__eq__` `__eq__(self, other: Scalar[DType.uint]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ge__` `__ge__(self, other: Scalar[DType.uint]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_main_load` `static is_main_load() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_epilogue` `static is_epilogue() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## blackwell_tma_umma_warp_specialized_kernel
`blackwell_tma_umma_warp_specialized_kernel[a_type: DType, b_type: DType, c_type: DType, expert_m: Int, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Scalar[DType.uint], num_accum_pipeline_stages: Int, num_output_stages: Int = 2, output_tile_shape: IndexList[2] = Index(128, 32), transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 2, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, transpose_c: Bool = False](num_active_experts: Int, a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], b_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], mnk: StaticTuple[UInt32, 3])`
--- ## consumer_main_loop
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, a_smem_layout: Layout, b_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1)](tmem_addr: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], mma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], consumer_phase: PipelineState[pipeline_stages], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32)`
--- ## grouped_matmul_sm100_persistent
`grouped_matmul_sm100_persistent[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cta_group: Int = 1, num_pipeline_stages: Optional[UInt] = None, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100 (Grouped_matmul_sm100)
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`WarpRole`](./WarpRole): ## Functions * [​`blackwell_tma_umma_warp_specialized_kernel`](./blackwell_tma_umma_warp_specialized_kernel): * [​`consumer_main_loop`](./consumer_main_loop): * [​`grouped_matmul_sm100_persistent`](./grouped_matmul_sm100_persistent): * [​`load_AB`](./load_AB): * [​`multi_stage_store_C`](./multi_stage_store_C): * [​`stsm_helper`](./stsm_helper): * [​`zero_output`](./zero_output):
--- ## load_AB
`load_AB[a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, num_pipeline_stages: Scalar[DType.uint], /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], mma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_phase: PipelineState[Int.__init__[Scalar[DType.uint]](num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB])`
--- ## multi_stage_store_C
`multi_stage_store_C[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: Int, /, *, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: Int = 4, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, transpose_c: Bool = False](c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], accum_pipeline_consumer_state: PipelineState[num_accum_pipeline_stages], accum_full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], accum_empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tmem_addr: UInt32, work_tile_coord: Tuple[UInt, UInt], group_end_idx: UInt32, elect_one_warp: Bool, M: UInt32, N: UInt32)`
--- ## stsm_helper
`stsm_helper[swizzle: Swizzle, transpose_c: Bool = False](vec: SIMD[dtype, size], dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## zero_output
`zero_output[c_type: DType, c_layout: Layout, *, output_tile_shape: IndexList[2]](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], coord: Tuple[UInt32, UInt32], group_end_idx: UInt32)`
--- ## B200BlockScaledMatmulSmem
`struct B200BlockScaledMatmulSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]` ## Fields * ​a\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].a_smem_size]`): * ​b\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].b_smem_size]`): * ​c\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].c_smem_size]`): * ​sfa\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AScalesType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfa_smem_size]`): * ​sfb\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BScalesType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfb_smem_size]`): * ​tma\_mma\_mbars (`InlineArray[SharedMemBarrier, (B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages * 2)]`): * ​accum\_mbars (`InlineArray[SharedMemBarrier, (config * 2)]`): * ​tmem\_dealloc\_mbar (`InlineArray[SharedMemBarrier, 1]`): * ​tmem\_addr (`InlineArray[UInt32, 1]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), 
[`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_size` `comptime a_smem_size = ((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK) * config)` ### `AScalesType` `comptime AScalesType = Scalar[sfa_dtype]` ### `AType` `comptime AType = Scalar[a_type]` ### `b_smem_size` `comptime b_smem_size = ((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK) * config)` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BScalesType` `comptime BScalesType = Scalar[sfb_dtype]` ### `BType` `comptime BType = Scalar[b_type]` ### `c_smem_size` `comptime c_smem_size = ((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN) * config)` ### `CType` `comptime CType = Scalar[c_type]` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (config // config)` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, 
DType.int64, Int](1)` ### `sfa_smem_size` `comptime sfa_smem_size = (((config * (B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM // SF_MN_GROUP_SIZE)) * BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b].sf_block_atom_size) * config)` ### `sfb_smem_size` `comptime sfb_smem_size = (((config * (B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N // SF_MN_GROUP_SIZE)) * BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b].sf_block_atom_size) * config)`
--- ## blackwell_block_scaled_matmul_tma_umma_warp_specialized
`blackwell_block_scaled_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, b_type: DType, b_layout: Layout, expert_ids_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, _sfb_layout: Layout, expert_scale_layout: Layout, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: Optional[UInt32] = None](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], _b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], 
a_scales: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], _b_scales: LayoutTensor[sfb_dtype, _sfb_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scale_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## blackwell_block_scaled_tma_umma_warp_specialized_kernel
`blackwell_block_scaled_tma_umma_warp_specialized_kernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, sfa_tile_layout: Layout, sfb_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], expert_n: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0](num_active_experts: Int, a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_tile_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_tile_layout, sfb_desc_layout], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])`
--- ## consumer_main_loop (grouped_matmul_sm100_1d1d)
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_smem_layout: Layout, b_smem_layout: Layout, sfa_smem_layout: Layout, sfb_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, scaling_kind: UMMAKind, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], SFA_NUM_COLS: Int, SFB_NUM_COLS: Int, cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), k_group_size: Int = 1](tmem_addr: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfa_smem_iter: LayoutTensorIter[sfa_dtype, sfa_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfb_smem_iter: LayoutTensorIter[sfb_dtype, sfb_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, scaling_kind, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)`
--- ## copy_accum_to_gmem
`copy_accum_to_gmem[c_type: DType, c_layout: Layout, c_smem_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: Int, c_tensor_layout: Layout, /, *, repeat: Int, accum_type: DType, cta_group: Int, epilogue_dtype: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], num_output_warps: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False, scale_c_coord: Bool = True](c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], mma_output_pipeline: ProducerConsumerPipeline[num_accum_pipeline_stages], mma_output_stage: UInt32, tmem_offset: UInt32, c_coord: Tuple[UInt32, UInt32], c_shape: Tuple[UInt32, UInt32], expert_scale: Float32)`
--- ## grouped_matmul_dynamic_scaled_nvfp4
`grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, scales_type: DType, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)` Performs grouped matrix multiplication with NVFP4 quantization. Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. Inputs A and B are NVFP4 quantized (4-bit floating point), packed as uint8 (2 values per byte), with float8\_e4m3fn scale factors. Each group of 16 elements along the K dimension shares a single scale factor (1D block scaling). **Constraints:** * The target device must be SM100 (B200). **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the output tensor C. * ​c\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of the output tensor C. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of input tensor A. Constraints: Must be `uint8`. * ​a\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of input tensor A. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of input tensor B. 
Constraints: Must be `uint8`. * ​b\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of input tensor B. * ​scales\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of scale factors. Constraints: Must be `float8_e4m3fn`. * ​a\_scales\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of A's scale factors. * ​b\_scales\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of B's scale factors. * ​a\_offsets\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of the token offset indices. * ​a\_scale\_offsets\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of A's scale offset indices. * ​expert\_ids\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of the expert ID tensor. * ​expert\_scales\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of the per-expert scale tensor. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether B is transposed. Constraints: Must be `True`. * ​target (`StringSlice`): The target device. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor of shape (total\_tokens, N). * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor of shape (total\_tokens, K // 2), packed NVFP4. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The weight tensor of shape (num\_experts, N, K // 2), packed NVFP4. * ​a\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The scale factors for A in tcgen05 5D layout. * ​b\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The scale factors for B in tcgen05 6D layout. * ​a\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The starting token index for each expert group. 
* ​a\_scale\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The starting scale index for each expert group. * ​expert\_ids ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The expert ID for each group. * ​expert\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The per-expert scaling factors applied in the epilogue. * ​num\_active\_experts ([`Int`](/mojo/std/builtin/int/Int)): The number of active experts in this batch. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The device context for GPU execution.
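To make the 1D block-scaling layout concrete, here is a hypothetical NumPy sketch of the dequantization that the signature implies: each `uint8` element of A packs two FP4 (e2m1) values, and every 16 elements along K share one scale factor. The flat `(M, K // 16)` scale layout, the `float32` scale dtype, and the nibble order are illustrative assumptions for readability; the kernel itself consumes `float8_e4m3fn` scales in the tcgen05 tiled layouts described above.

```python
import numpy as np

# e2m1 magnitude table: sign bit + 2 exponent bits + 1 mantissa bit.
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_LUT = np.concatenate([E2M1, -E2M1])  # codes 0..7 positive, 8..15 negative


def dequantize_nvfp4(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """packed: (M, K // 2) uint8, scales: (M, K // 16) float -> (M, K) float32.

    Assumes the low nibble holds the even-indexed element of each pair
    (an illustrative choice, not taken from the kernel).
    """
    lo = packed & 0x0F   # first FP4 value in each byte
    hi = packed >> 4     # second FP4 value
    vals = np.empty((packed.shape[0], packed.shape[1] * 2), dtype=np.float32)
    vals[:, 0::2] = FP4_LUT[lo]
    vals[:, 1::2] = FP4_LUT[hi]
    # One scale factor is broadcast over each group of 16 K-elements.
    return vals * np.repeat(scales, 16, axis=1)
```

With `transpose_b=True`, the grouped matmul then computes `C[start:end] = dequantize(A[start:end]) @ dequantize(B[expert]).T` for each active expert's token range, scaled by that expert's entry in `expert_scales` in the epilogue.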
--- ## grouped_matmul_sm100_1d1d
## `comptime` values ### `WarpRole` `comptime WarpRole = WarpRole[False]` ## Structs * [​`B200BlockScaledMatmulSmem`](./B200BlockScaledMatmulSmem): ## Functions * [​`blackwell_block_scaled_matmul_tma_umma_warp_specialized`](./blackwell_block_scaled_matmul_tma_umma_warp_specialized): * [​`blackwell_block_scaled_tma_umma_warp_specialized_kernel`](./blackwell_block_scaled_tma_umma_warp_specialized_kernel): * [​`consumer_main_loop`](./consumer_main_loop): * [​`copy_accum_to_gmem`](./copy_accum_to_gmem): * [​`grouped_matmul_dynamic_scaled_nvfp4`](./grouped_matmul_dynamic_scaled_nvfp4): Performs grouped matrix multiplication with NVFP4 quantization. * [​`load_AB`](./load_AB): * [​`multi_stage_store_C`](./multi_stage_store_C):
--- ## load_AB (grouped_matmul_sm100_1d1d)
`load_AB[a_type: DType, b_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, sfa_smem_layout: Layout, sfb_smem_layout: Layout, num_pipeline_stages: Int, a_scale_offsets_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], num_sf_k_tiles: Int, cta_group: Int = 1, k_group_size: Scalar[DType.uint] = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfa_smem: LayoutTensorIter[sfa_dtype, sfa_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfb_smem: LayoutTensorIter[sfb_dtype, sfb_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[num_pipeline_stages], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB], expert_id: Int32, a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin])`
--- ## multi_stage_store_C (grouped_matmul_sm100_1d1d)
`multi_stage_store_C[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: Int, c_tensor_layout: Layout, /, *, input_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: Int = 4, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False, scale_c_coord: Bool = True](c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], mma_output_pipeline: ProducerConsumerPipeline[num_accum_pipeline_stages], tmem_addr: UInt32, work_tile_coord: Tuple[UInt32, UInt32], elect_one_warp: Bool, expert_scale: Float32, M: UInt32, N: UInt32)`
--- ## blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel
`blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, a_scales_tile_layout: Layout, a_scales_type: DType, a_offsets_layout: Layout, b_scales_type: DType, b_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, a_scales_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], num_pipeline_stages: Scalar[DType.uint], cluster_shape: StaticTuple[Int32, 3], expert_n: Int, expert_ids_layout: Layout, b_scales_n: Int](num_active_experts: Int, a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], num_iters: Scalar[DType.uint], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], problem_shape: StaticTuple[Int32, 3])`
--- ## grouped_matmul_dynamic_scaled_fp8
`grouped_matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, tokens_padded_per_expert: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[c_type, 2, MutAnyOrigin, shape], a: NDBuffer[a_type, 2, MutAnyOrigin, shape], b: NDBuffer[b_type, 3, MutAnyOrigin, shape], a_scales: NDBuffer[a_scales_type, 2, MutAnyOrigin, shape], b_scales: NDBuffer[b_scales_type, 3, MutAnyOrigin, shape], a_offsets: NDBuffer[a_offsets_type, 1, MutAnyOrigin, shape], expert_ids: NDBuffer[expert_ids_type, 1, MutAnyOrigin, shape], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100_blockwise_scaled_fp8
`grouped_matmul_sm100_blockwise_scaled_fp8[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, transpose_b: Bool, //, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, 
alignment=alignment], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100_blockwise_scaled_fp8_persistent
`grouped_matmul_sm100_blockwise_scaled_fp8_persistent[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, transpose_b: Bool, //, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, 
masked=masked, alignment=alignment], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100_blockwise_fp8
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel`](./blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel): * [​`grouped_matmul_dynamic_scaled_fp8`](./grouped_matmul_dynamic_scaled_fp8): * [​`grouped_matmul_sm100_blockwise_scaled_fp8`](./grouped_matmul_sm100_blockwise_scaled_fp8): * [​`grouped_matmul_sm100_blockwise_scaled_fp8_persistent`](./grouped_matmul_sm100_blockwise_scaled_fp8_persistent): * [​`load_AB`](./load_AB): * [​`matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel`](./matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel): * [​`multi_stage_reg_epilogue`](./multi_stage_reg_epilogue): * [​`promote_accumulators`](./promote_accumulators):
--- ## load_AB (grouped_matmul_sm100_blockwise_fp8)
`load_AB[a_type: DType, b_type: DType, a_scales_type: DType, a_layout: Layout, b_layout: Layout, a_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, a_scales_smem_layout: Layout, num_pipeline_stages: Scalar[DType.uint], expert_ids_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], a_scales_smem: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[Int.__init__[Scalar[DType.uint]](num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: Scalar[DType.uint], elect_one_cta: Bool, scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin])`
--- ## matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel
`matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, accum_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Scalar[DType.uint] = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], num_iters: Scalar[DType.uint])`
--- ## multi_stage_reg_epilogue
`multi_stage_reg_epilogue[c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, accum_type: DType, accum_layout: Layout, c_tensor_layout: Layout, /, *, c_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle](c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], c_coord: Tuple[UInt, UInt], elect_one_warp: Bool, group_end_idx: UInt32)`
--- ## promote_accumulators
`promote_accumulators[pipeline_stages: Scalar[DType.uint], num_accum_pipeline_stages: Int, accum_type: DType, accum_layout: Layout, a_scales_type: DType, b_scales_type: DType, b_scales_layout: Layout, a_scales_smem_layout: Layout, expert_ids_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, CLUSTER_SIZE: Int32, is_lower_frag_required: Bool, num_output_warps: Int](b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], b_scales_n: Int, a_scales_smem_iter: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_output_pipeline: ProducerConsumerPipeline[num_accum_pipeline_stages], tmem_addr: UInt32, load_mma_pipeline: ProducerConsumerPipeline[Int.__init__[Scalar[DType.uint]](pipeline_stages)], work_tile_coord: Tuple[UInt, UInt], elect_one_warp: Bool, stage_stride_cols: Scalar[DType.uint], k_iter: Scalar[DType.uint], problem_shape: StaticTuple[Int32, 3], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB])`
--- ## RasterOrder
`@register_passable(trivial)` `struct RasterOrder` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AlongM` `comptime AlongM = RasterOrder(1)` ### `AlongN` `comptime AlongN = RasterOrder(0)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## TileScheduler
`@register_passable(trivial)` `struct TileScheduler[offsets_layout: Layout, //, *, static_MN: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index(1, 1, 1), cta_group: Int = 1, swizzle: Bool = False, swapAB: Bool = True]` ## Fields * ​num\_active\_experts (`Int`): * ​group\_offsets (`LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin]`): * ​current\_iter (`Int32`): * ​current\_group\_idx (`UInt32`): * ​current\_dynamic\_dim\_cumsum (`UInt32`): * ​block\_idx\_start (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `cta_group_tile_shape` `comptime cta_group_tile_shape = Index((tile_shape.__getitem__[3, DType.int64, Int](0) * cta_group), (tile_shape.__getitem__[3, DType.int64, Int](1) * cta_group))` ### `div_dynamic_block` `comptime div_dynamic_block = FastDiv[DType.uint32](TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB].cta_group_tile_shape.__getitem__[2, DType.int64, Int](TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB].dynamic_dim))` ### `dynamic_dim` `comptime dynamic_dim = 1 if swapAB else 0` ### `kNum1DBlocksPerGroup` `comptime kNum1DBlocksPerGroup = 16` ### `num_static_dim_blocks` `comptime num_static_dim_blocks = SIMD[DType.uint32, 1](ceildiv(static_MN, 
tile_shape.__getitem__[3, DType.int64, Int](TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB].static_dim)))` ### `static_dim` `comptime static_dim = 0 if swapAB else 1` ## Methods ### `__init__` `__init__(num_active_experts: Int, group_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin]) -> Self` ### `fetch_next_work` `fetch_next_work(mut self) -> WorkInfo` **Returns:** `WorkInfo`
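The fields above suggest the enumeration this scheduler performs: one output dimension is dynamic per expert group (its extent comes from consecutive entries of `group_offsets`), while the other has static extent `static_MN`. A hypothetical Python sketch of that iteration order, ignoring swizzling, `swapAB`, cluster shape, and the persistent-CTA striding that the real `fetch_next_work` performs (here the static dimension is assumed to be N):

```python
from math import ceil


def work_tiles(group_offsets, static_n, tile_m, tile_n):
    """Yield (group, row_tile_origin, n_tile_origin) work items.

    group_offsets has length num_groups + 1, starts at 0, and ends at the
    total token count; group g owns rows [group_offsets[g], group_offsets[g+1]).
    """
    n_blocks = ceil(static_n / tile_n)
    for g in range(len(group_offsets) - 1):
        rows = group_offsets[g + 1] - group_offsets[g]  # dynamic extent
        for mb in range(ceil(rows / tile_m)):
            for nb in range(n_blocks):
                yield (g, group_offsets[g] + mb * tile_m, nb * tile_n)
```

In the actual kernel each persistent CTA starts at its own `block_idx_start` and strides through this flattened sequence, so calls to `fetch_next_work` return tiles from different expert groups without any host-side per-group launch.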
--- ## WorkInfo
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​is\_valid\_tile (`Bool`): * ​terminate (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_done` `is_done(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## grouped_matmul_tile_scheduler
## Structs * [​`RasterOrder`](./RasterOrder): * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo):
--- ## linalg
Provides CPU and GPU implementations of linear algebra functions. ## Packages * [​`arch`](./arch/): Provides architecture specific utility functions. * [​`matmul`](./matmul/): Provides the backend implementation for matmuls. ## Modules * [​`accumulate`](./accumulate/): * [​`bmm`](./bmm/): * [​`distributed_matmul`](./distributed_matmul/): * [​`dual_gemm`](./dual_gemm/): * [​`fp4_quantization`](./fp4_quantization/): * [​`fp4_utils`](./fp4_utils/): * [​`fp8_quantization`](./fp8_quantization/): * [​`fp8_utils`](./fp8_utils/): Shared FP8 quantization utilities. * [​`gemv`](./gemv/): * [​`grouped_matmul`](./grouped_matmul/): * [​`grouped_matmul_sm100`](./grouped_matmul_sm100/): * [​`grouped_matmul_sm100_1d1d`](./grouped_matmul_sm100_1d1d/): * [​`grouped_matmul_sm100_blockwise_fp8`](./grouped_matmul_sm100_blockwise_fp8/): * [​`grouped_matmul_tile_scheduler`](./grouped_matmul_tile_scheduler/): * [​`lora`](./lora/): * [​`matrix_band_part`](./matrix_band_part/): The module implements matrix band part functions. * [​`packing`](./packing/): * [​`qr_factorization`](./qr_factorization/): * [​`structuring`](./structuring/): * [​`transpose`](./transpose/): The module implements Transpose functions. * [​`utils`](./utils/): * [​`utils_gpu`](./utils_gpu/):
--- ## lora (Lora)
## Functions * [​`shrink_qkv_permute_3mn_sm100`](./shrink_qkv_permute_3mn_sm100): LoRA shrink GMM with planar Q/K/V output on SM100.
--- ## shrink_qkv_permute_3mn_sm100
`shrink_qkv_permute_3mn_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList](c_lora: NDBuffer[c_type, 3, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)` LoRA shrink GMM with planar Q/K/V output on SM100. Performs the LoRA 'shrink' grouped matmul for routed tokens: computes `[M, K] @ [G, 3N, K]^T` per active expert, then **permutes** the flat `[M, 3N]` result into a planar layout `[3, M, N]` (Q, K, V) using an elementwise epilogue, while reusing the same storage. **Constraints:** * c\_lora must be rank 3 with static first dimension B == 3. * a must be rank 2 with trailing dimension K that matches b\[..., K]. * b must be rank 3 with shape (G, 3N, K). * The temporary 2D view of c\_lora is (M, 3N) in row-major order and **aliases the same storage** as c\_lora. * a\_offsets is non-decreasing with a\_offsets\[0] == 0 and a\_offsets\[num\_active\_experts] == M. * expert\_ids\[i] ∈ \[0, G) for valid experts; kernel may treat -1 as inactive. * The epilogue assumes `N % vector_width == 0` for aligned vector stores. **Args:** * ​c\_lora ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output tensor with planar Q/K/V layout, shape (3, M, N). Backed by row-major storage, used both as a 3D view and as a temporary 2D view (M, 3N) during compute. * ​a ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Routed activation matrix, shape (M, K). * ​b ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Shrink weights per expert, shape (G, 3N, K). * ​a\_offsets ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Inclusive prefix sums of tokens per (active) expert, length (num\_experts + 1). Defines per-expert \[start, end) in A/C. 
* ​expert\_ids ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Expert indices for the active groups, length ≥ num\_active\_experts. * ​max\_num\_tokens\_per\_expert ([`Int`](/mojo/std/builtin/int/Int)): Upper bound on tokens for any active expert. * ​num\_active\_experts ([`Int`](/mojo/std/builtin/int/Int)): Number of experts participating in this call. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): DeviceContext used for enqueues and synchronization.
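The offset and permute semantics described above can be sketched in plain Python (a simplified model with hypothetical helper names; the real kernel performs the permute in place via an elementwise epilogue, reusing the same storage rather than allocating a separate output):

```python
def expert_ranges(a_offsets, num_active_experts):
    # a_offsets has length num_active_experts + 1, with a_offsets[0] == 0 and
    # a_offsets[num_active_experts] == M; expert i owns rows [start, end) of A/C.
    return [(a_offsets[i], a_offsets[i + 1]) for i in range(num_active_experts)]

def permute_flat_to_planar(c_flat, M, N):
    # Reinterpret the flat [M, 3N] GMM result as planar [3, M, N] (Q, K, V):
    # element (m, p*N + n) of the 2D view maps to (p, m, n) in the 3D view.
    out = [[[0] * N for _ in range(M)] for _ in range(3)]
    for m in range(M):
        for p in range(3):
            for n in range(N):
                out[p][m][n] = c_flat[m][p * N + n]
    return out
```

For example, with `a_offsets = [0, 2, 4]` and two active experts, expert 0 owns rows `[0, 2)` and expert 1 owns rows `[2, 4)`; the permute then splits each row's `3N` columns into the Q, K, and V planes.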
--- ## apple_batched_matmul
`apple_batched_matmul[*, transpose_b: Bool = False, elementwise_epilogue_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive])`
--- ## apple_gemv
`apple_gemv[*, b_packed: Bool, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape])`
--- ## apple_matmul
`apple_matmul[*, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](cblas_gemm_fn: cblas_gemm_type, c: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive])` `apple_matmul[*, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive])`
--- ## get_cblas_f32_function
`get_cblas_f32_function() -> cblas_gemm_type` **Returns:** [`cblas_gemm_type`](/mojo/kernels/linalg/matmul/cpu/apple_accelerate/#cblas_gemm_type)
--- ## apple_accelerate
## `comptime` values ### `APPLE_ACCELERATE` `comptime APPLE_ACCELERATE = _Global["APPLE_ACCELERATE", _init_dylib, _on_error_msg]` ### `cblas_gemm_type` `comptime cblas_gemm_type = fn(_CBLASOrder, _CBLASTranspose, _CBLASTranspose, Int32, Int32, Int32, Float32, UnsafePointer[Float32, ImmutAnyOrigin], Int32, UnsafePointer[Float32, ImmutAnyOrigin], Int32, Float32, UnsafePointer[Float32, MutAnyOrigin], Int32) -> None` ### `LIB_ACC_PATH` `comptime LIB_ACC_PATH = "/System/Library/Frameworks/Accelerate.framework/Accelerate"` ## Functions * [​`apple_batched_matmul`](./apple_batched_matmul): * [​`apple_gemv`](./apple_gemv): * [​`apple_matmul`](./apple_matmul): * [​`get_cblas_f32_function`](./get_cblas_f32_function): * [​`use_apple_accelerate_lib`](./use_apple_accelerate_lib):
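The `cblas_gemm_type` alias above mirrors the CBLAS single-precision GEMM entry point (`cblas_sgemm`). As a rough pure-Python reference of what a call of that shape computes (a sketch, not a binding; the helper name is hypothetical, and transposes and leading dimensions are omitted for the simple row-major case):

```python
def sgemm_reference(alpha, A, B, beta, C):
    # What a cblas_sgemm-style call computes (row-major, no transposes):
    # C <- alpha * A @ B + beta * C, with A: (M, K), B: (K, N), C: (M, N).
    M, K, N = len(A), len(B), len(B[0])
    for i in range(M):
        for j in range(N):
            acc = sum(A[i][k] * B[k][j] for k in range(K))
            C[i][j] = alpha * acc + beta * C[i][j]
    return C
```

The `Int32` leading-dimension arguments in the real signature (`lda`, `ldb`, `ldc`) describe the row stride of each matrix in memory, which this simplified sketch assumes equals the logical column count.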
--- ## use_apple_accelerate_lib
`use_apple_accelerate_lib[c_type: DType, a_type: DType, b_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## Inner_matmul_default
`struct Inner_matmul_default` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
--- ## default
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`Inner_matmul_default`](./Inner_matmul_default):
--- ## Inner_matmul_i8mm
`struct Inner_matmul_i8mm` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows2, TileN, TileK) tile.
--- ## LoadStore_i8mm
`struct LoadStore_i8mm[dtype: DType, simd_size: Int, single_row: Bool, tile_rows: Int, tile_columns: Int]` ## Fields * ​output\_tile (`_Accumulator[dtype, tile_rows, LoadStore_i8mm[dtype, simd_size, single_row, tile_rows, tile_columns].num_simd_cols, simd_size]`): * ​skip\_boundary\_check (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `num_simd_cols` `comptime num_simd_cols = (tile_columns // simd_size)` ## Methods ### `__init__` `__init__(out self, skip_boundary_check: Bool)`
--- ## i8mm
## Structs * [​`Inner_matmul_i8mm`](./Inner_matmul_i8mm): * [​`LoadStore_i8mm`](./LoadStore_i8mm):
--- ## InnerMatmulKernel
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. 
**Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ## Provided methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self: _Self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## TiledMatmul
`struct TiledMatmul[config: KernelConfig, transpose_b: Bool, b_packed: Bool, elementwise_epilogue_enabled: Bool, kernel_id: InnerKernelID, a_type: DType, a_shape: DimList, a_origin: ImmutOrigin, b_type: DType, b_shape: DimList, b_origin: ImmutOrigin, c_type: DType, c_shape: DimList, c_origin: MutOrigin, algorithm: InnerMatmulKernel]` Tiled matmul implementation integrating packing, inner loop and tile partitions. TODO: add tag based implementation dispatch. TODO: add fusion hooks. ## Fields * ​alg (`algorithm`): * ​c (`NDBuffer[c_type, 2, c_origin, c_shape]`): * ​a (`NDBuffer[a_type, 2, a_origin, a_shape]`): * ​b (`NDBuffer[b_type, 2, b_origin, b_shape]`): * ​tile\_n\_k (`IndexList[2]`): * ​global\_tile\_offset (`GemmShape`): * ​global\_tile\_shape (`GemmShape`): * ​b\_tile\_generator (`BTileGenerator[config, a_type, b_type, c_type, b_shape, transpose_b, b_packed, b_origin]`): * ​elementwise\_epilogue\_fn (`fn(GemmShape, GemmShape) escaping -> None`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` Compiler-generated flag; its expanded conditional reduces to the `__copyinit__is_trivial` flags of the `algorithm` parameter and of the escaping `fn(GemmShape, GemmShape) escaping -> None` closure field. ### `__del__is_trivial` `comptime __del__is_trivial` Compiler-generated flag; likewise reduces to the `__del__is_trivial` flags of `algorithm` and of the escaping closure field. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` Compiler-generated flag; likewise reduces to the `__moveinit__is_trivial` flags of `algorithm` and of the escaping closure field.
__moveinit__is_trivial = fn(GemmShape, GemmShape) escaping -> None.__moveinit__is_trivial if True if True if True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if 
algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else 
True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else 
True if True if True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial 
else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if 
algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial`
--- ## elementwise_epilogue_c_tile
`elementwise_epilogue_c_tile[simd_width: Int, dtype: DType, origin: MutOrigin, c_shape: DimList, func: fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None](offset: GemmShape, tile_len: GemmShape, c: NDBuffer[dtype, 2, origin, c_shape])`
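As a rough illustration of what an elementwise epilogue over a C tile does, here is a hedged Python sketch (invented names, scalar loops; the real kernel passes SIMD vectors to a capturing function rather than returning values):

```python
# Hypothetical sketch: apply an elementwise function to every element of a
# C sub-tile, passing the global (row, col) index and the current value.
def epilogue_over_c_tile(c, offset, tile_len, func):
    m0, n0 = offset      # tile origin within C
    tm, tn = tile_len    # tile extent
    for i in range(m0, m0 + tm):
        for j in range(n0, n0 + tn):
            c[i][j] = func((i, j), c[i][j])
```

This mirrors the contract of the `func` parameter above: it receives a 2-D index and the element(s) at that index, and its result is written back into `c`.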
--- ## impl
## Structs * [​`TiledMatmul`](./TiledMatmul): Tiled matmul implementation integrating packing, inner loop and tile partitions. ## Traits * [​`InnerMatmulKernel`](./InnerMatmulKernel): ## Functions * [​`elementwise_epilogue_c_tile`](./elementwise_epilogue_c_tile): * [​`matmul`](./matmul): * [​`tiled_matmul_run`](./tiled_matmul_run): Interface function to run tiled matmul on a given sub-tile.
--- ## matmul
`matmul[*, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], kernel_type_m: Int, num_threads: Int = -1)`
--- ## tiled_matmul_run
`tiled_matmul_run[config: KernelConfig, transpose_b: Bool, b_packed: Bool, simd_size: Int, elementwise_epilogue_enabled: Bool, kernel_id: InnerKernelID, algorithm: InnerMatmulKernel](alg: algorithm, c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], elementwise_epilogue_fn: fn(GemmShape, GemmShape) escaping -> None, global_tile_shape: GemmShape, global_tile_offset: GemmShape)` Interface function to run tiled matmul on a given sub-tile. **Args:** * ​alg (`algorithm`): InnerMatmulKernel algorithm for the microkernel. * ​c ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Pre-allocated buffer space for the result. * ​a ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Operand A of the matmul. * ​b ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Operand B of the matmul. * ​elementwise\_epilogue\_fn (`fn(GemmShape, GemmShape) escaping -> None`): The elementwise epilogue function. * ​global\_tile\_shape ([`GemmShape`](/mojo/kernels/linalg/utils/GemmShape)): Tile shape this call will process. * ​global\_tile\_offset ([`GemmShape`](/mojo/kernels/linalg/utils/GemmShape)): Tile offset on the original buffer.
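The call contract above can be sketched in plain Python (a hedged sketch with hypothetical names and scalar loops in place of the packed microkernel): compute the requested (M, N, K) sub-tile of C = A @ B into the pre-allocated buffer, then invoke the epilogue with the tile's offset and shape.

```python
# Hypothetical sketch of the tiled_matmul_run contract: compute one sub-tile
# of C = A @ B at a given offset, then run the elementwise epilogue.
def tiled_matmul_run_sketch(c, a, b, epilogue, tile_shape, tile_offset):
    m0, n0, k0 = tile_offset   # GemmShape-style (M, N, K) offset
    tm, tn, tk = tile_shape    # GemmShape-style (M, N, K) extent
    for i in range(m0, m0 + tm):
        for j in range(n0, n0 + tn):
            acc = c[i][j]      # accumulate into the pre-allocated C buffer
            for k in range(k0, k0 + tk):
                acc += a[i][k] * b[k][j]
            c[i][j] = acc
    # The epilogue receives the tile's offset and shape, mirroring the
    # fn(GemmShape, GemmShape) escaping -> None signature.
    epilogue(tile_offset, tile_shape)
```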
--- ## cpu (Cpu)
Provides the CPU backend implementations for matmuls. ## Modules * [​`apple_accelerate`](./apple_accelerate/): * [​`default`](./default/): * [​`i8mm`](./i8mm/): * [​`impl`](./impl/): * [​`neon`](./neon/): * [​`vnni`](./vnni/):
--- ## Inner_matmul_neon
`struct Inner_matmul_neon` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
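A rough scalar model of the `__inner_matmul__` loop structure (a hedged Python sketch with invented names, not the NEON implementation): hold a small kernel_rows × kernel_cols accumulator "in registers" while sweeping the K tile, and write back to C once at the end.

```python
# Hypothetical sketch of a register-blocked inner kernel: a small accumulator
# tile stays live across the whole K sweep, amortizing loads and stores of C.
def inner_kernel_sketch(c, a, b, kernel_rows, kernel_cols, tile_k):
    acc = [[0.0] * kernel_cols for _ in range(kernel_rows)]
    for k in range(tile_k):            # single pass over the K tile
        for r in range(kernel_rows):   # unrolled/vectorized in a real kernel
            a_rk = a[r][k]
            for s in range(kernel_cols):
                acc[r][s] += a_rk * b[k][s]
    for r in range(kernel_rows):       # write back once per tile
        for s in range(kernel_cols):
            c[r][s] += acc[r][s]
```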
--- ## neon
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`Inner_matmul_neon`](./Inner_matmul_neon):
--- ## Inner_matmul_vnni
`struct Inner_matmul_vnni[saturated_vnni: Bool]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
--- ## vnni
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`Inner_matmul_vnni`](./Inner_matmul_vnni):
--- ## amd
Provides the AMD GPU backend implementations for matmuls. ## Modules * [​`matmul`](./matmul/): * [​`pingpong_kernel`](./pingpong_kernel/): * [​`ring_buffer`](./ring_buffer/): Ring Buffer implementation for producer-consumer synchronization in GPU kernels. * [​`ring_buffer_traits`](./ring_buffer_traits/): Trait definitions and utilities for ring buffer synchronization strategies. * [​`structured`](./structured/): * [​`warp_spec_matmul`](./warp_spec_matmul/): AMD Warp-Specialized Matrix Multiplication
--- ## MMATileBuffers
`struct MMATileBuffers[mut: Bool, dtype: DType, layout: Layout, origin: Origin[mut=mut], address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, //, _dtype: DType, /, smem_layout: Layout, reg_tile_layout: Layout, tensor_type: AnyStruct[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]], thread_layout: Layout, warp_rows: Int, warp_cols: Int, swizzle: Swizzle]` Manages memory for a single matrix (A or B) in GEMM computation. This struct encapsulates all memory handling for a matrix, including: * Shared memory allocation and tiling * Register buffer allocation * Data movement between memory levels (DRAM→local→shared) ## Fields * ​smem\_tile (`MMATileBuffers[_dtype, smem_layout, reg_tile_layout, tensor_type, thread_layout, warp_rows, warp_cols, swizzle].SMemTile`): * ​smem\_warp\_tile (`LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(smem_layout, AddressSpace.SHARED), _get_index_type(smem_layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(smem_layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(smem_layout, AddressSpace.SHARED), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols]()]`): * ​load\_reg\_tile (`MMATileBuffers[_dtype, smem_layout, reg_tile_layout, tensor_type, thread_layout, warp_rows, warp_cols, swizzle].MMARegTile`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `MMARegTile` `comptime MMARegTile = 
LayoutTensor[_dtype, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `SMemTile` `comptime SMemTile = LayoutTensor[_dtype, smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_idx: Int, warp_k_idx: Int, block_idx: Int)` Initialize memory regions for a matrix based on warp coordinates. **Args:** * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to load from global memory. * ​warp\_idx ([`Int`](/mojo/std/builtin/int/Int)): The warp index within the computation grid (used for MMA operations). * ​warp\_k\_idx ([`Int`](/mojo/std/builtin/int/Int)): The warp index along the K dimension (used for MMA operations). * ​block\_idx ([`Int`](/mojo/std/builtin/int/Int)): The block index within the computation grid (used for warp tiling). ### `copy_to_smem` `copy_to_smem(self)` Copy data from thread-local memory to shared memory. Uses structured thread cooperation to efficiently transfer data.
--- ## MmaOpAMD
`struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, out_frag_size: Int, swizzle: Swizzle]` ## Fields * ​out\_reg\_tile (`MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].OutRegTile`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `alignment` `comptime alignment = align_of[SIMD[in_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width]]()` ### `out_reg_layout` `comptime out_reg_layout = Layout.row_major((num_m_mmas * num_n_mmas), out_frag_size)` ### `OutRegTile` `comptime OutRegTile = LayoutTensor[out_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `reg_tile_layout` `comptime reg_tile_layout[num_mmas: Int] = Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width)` #### Parameters * ​num\_mmas ([`Int`](/mojo/std/builtin/int/Int)): ### `RegTile` `comptime RegTile[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` #### Parameters * ​num\_mmas ([`Int`](/mojo/std/builtin/int/Int)): ### `simd_width` `comptime simd_width = simd_width_of[in_type]()` ### `tensor_core_mma` `comptime tensor_core_mma = TiledTensorCore[out_type, in_type, shape, 
k_group_size, transpose_b]()` ## Methods ### `__init__` `__init__(out self)` ### `a_reg_tile` `a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), False, align_of[in_type](), num_m_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_m_mmas, simd_width_of[in_type]()]()]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `b_reg_tile` `b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, 
in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), False, align_of[in_type](), num_n_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_n_mmas, simd_width_of[in_type]()]()]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `mma` `mma[k_tile_idx: Int](self)` ### `load_tile_fragment` `load_tile_fragment[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(layout, AddressSpace.SHARED), _get_index_type(layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, 
layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()], b_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(layout, AddressSpace.SHARED), _get_index_type(layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()])` ### `reset_accumulator` `reset_accumulator(self)`
--- ## gemm_kernel_amd
`gemm_kernel_amd[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, c_layout_int_type: DType, a_layout_int_type: DType, b_layout_int_type: DType, c_linear_idx_type: DType, a_linear_idx_type: DType, b_linear_idx_type: DType, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin, layout_int_type=c_layout_int_type, linear_idx_type=c_linear_idx_type], a: LayoutTensor[a_type, a_layout, MutAnyOrigin, layout_int_type=a_layout_int_type, linear_idx_type=a_linear_idx_type], b: LayoutTensor[b_type, b_layout, MutAnyOrigin, layout_int_type=b_layout_int_type, linear_idx_type=b_linear_idx_type])` AMD-optimized GEMM kernel for matrix multiplication C = A \* B. This kernel implements an efficient matrix multiplication algorithm optimized for AMD GPUs, with hierarchical tiling and structured memory access patterns. **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the output matrix C. * ​c\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for matrix C. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the input matrix A. * ​a\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for matrix A. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the input matrix B. * ​b\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for matrix B. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether matrix B should be transposed. * ​c\_layout\_int\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the integer part of matrix C. * ​a\_layout\_int\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the integer part of matrix A. * ​b\_layout\_int\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the integer part of matrix B. 
* ​c\_linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the linear index of matrix C. * ​a\_linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the linear index of matrix A. * ​b\_linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the linear index of matrix B. * ​config ([`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)): GEMM configuration parameters (tile sizes, etc.). * ​elementwise\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional function to apply to output elements. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C (result). * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix A. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix B (must be transposed).
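The "hierarchical tiling" mentioned above can be made concrete with a small arithmetic sketch. The tile sizes below are hypothetical placeholders chosen for illustration, not values taken from any actual `MatmulConfig`; the sketch only shows how a block → warp → MMA-instruction hierarchy partitions a GEMM problem.

```python
# Hypothetical tile sizes illustrating block -> warp -> MMA tiling.
M, N, K = 256, 256, 128           # problem shape (example)
BM, BN, BK = 128, 128, 64         # block tile (one thread block)
WM, WN = 64, 64                   # warp tile within a block tile
MMA_M, MMA_N, MMA_K = 16, 16, 32  # one matrix-core instruction

grid = (M // BM) * (N // BN)                   # thread blocks launched
warps_per_block = (BM // WM) * (BN // WN)      # warps covering one block tile
mmas_per_warp = (WM // MMA_M) * (WN // MMA_N)  # MMA ops per warp per K-step
k_steps = K // BK                              # K-dimension iterations

print(grid, warps_per_block, mmas_per_warp, k_steps)
```

Each level of the hierarchy must evenly divide the level above it, which is why the divisions here are exact.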
--- ## matmul (Matmul)
## `comptime` values ### `SMemWarpTileType` `comptime SMemWarpTileType[_dtype: DType, layout: Layout, warp_rows: Int, warp_cols: Int] = LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(layout, AddressSpace.SHARED), _get_index_type(layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()]` Type alias for warp-level shared memory tiles with specified dimensions. #### Parameters * ​\_dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): * ​warp\_rows ([`Int`](/std/builtin/int/Int)): * ​warp\_cols ([`Int`](/std/builtin/int/Int)): ## Structs * [​`MmaOpAMD`](./MmaOpAMD): * [​`MMATileBuffers`](./MMATileBuffers): Manages memory for a single matrix (A or B) in GEMM computation. ## Functions * [​`gemm_kernel_amd`](./gemm_kernel_amd): AMD-optimized GEMM kernel for matrix multiplication C = A \* B. * [​`write_output_fragments`](./write_output_fragments): Write output fragments from registers to global memory with optional elementwise operations.
--- ## write_output_fragments
`write_output_fragments[c_type: DType, c_frag_size: Int, MMA_M: Int, MMA_N: Int, output_thread_layout: Layout, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c_reg_fragment: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_gmem_fragment: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_tile_m: Int, warp_tile_n: Int, M: Int, N: Int)` Write output fragments from registers to global memory with optional elementwise operations. **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for the output matrix C. * ​c\_frag\_size ([`Int`](/mojo/std/builtin/int/Int)): Size of each output fragment. * ​MMA\_M ([`Int`](/mojo/std/builtin/int/Int)): Matrix multiply instruction M dimension. * ​MMA\_N ([`Int`](/mojo/std/builtin/int/Int)): Matrix multiply instruction N dimension. * ​output\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread layout for output operations. * ​elementwise\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional elementwise operation to apply. **Args:** * ​c\_reg\_fragment ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Register fragments containing computation results. * ​c\_gmem\_fragment ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Global memory fragment for output. * ​warp\_tile\_m ([`Int`](/mojo/std/builtin/int/Int)): M coordinate of the warp tile. * ​warp\_tile\_n ([`Int`](/mojo/std/builtin/int/Int)): N coordinate of the warp tile. * ​M ([`Int`](/mojo/std/builtin/int/Int)): Total M dimension of the output matrix. * ​N ([`Int`](/mojo/std/builtin/int/Int)): Total N dimension of the output matrix.
--- ## AMDPingPongMatmul
`struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, config: KernelConfig, /, enable_swizzle: Bool]` 8-warp ping-pong matmul for AMD MI355X. Warps are split into 2 groups of 4, alternating between load and compute phases for overlapped execution. Uses double-buffered LDS with swizzled access patterns to avoid bank conflicts. Key features: * load\_to\_lds for direct DRAM→LDS transfer (bypasses L1/L2) * Swizzle pattern for bank-conflict-free LDS access * Fine-grained lgkmcnt/vmcnt waits for maximum overlap ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_dtype` `comptime accum_dtype = get_accum_type[c_type]()` ### `accum_width` `comptime accum_width = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_N) // WARP_SIZE)` ### `BK` `comptime BK = config.block_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_shape.__getitem__[3, DType.int64, Int](1)` ### `LGKM_PER_LOAD_A` `comptime LGKM_PER_LOAD_A = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].quadrant_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_k_mmas)` ### `LGKM_PER_LOAD_AB` `comptime LGKM_PER_LOAD_AB = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].LGKM_PER_LOAD_A + AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].LGKM_PER_LOAD_B)` ### `LGKM_PER_LOAD_B` `comptime LGKM_PER_LOAD_B = 
(AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].quadrant_n_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_k_mmas)` ### `load_width` `comptime load_width = simd_width_of[a_type]()` ### `loading_threads_4warp` `comptime loading_threads_4warp = (4 * WARP_SIZE)` ### `loading_threads_8warp` `comptime loading_threads_8warp = (8 * WARP_SIZE)` ### `loads_per_row` `comptime loads_per_row = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BK // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].load_width)` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accums` `comptime num_accums = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_n_mmas)` ### `num_k_mmas` `comptime num_k_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WK // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WM // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WN // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_N)` ### `num_warps_m` `comptime num_warps_m = 
(AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WM)` ### `num_warps_n` `comptime num_warps_n = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WN)` ### `ping_pong_stages` `comptime ping_pong_stages = 2` ### `quadrant_m_mmas` `comptime quadrant_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_m_mmas // 2)` ### `quadrant_n_mmas` `comptime quadrant_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_n_mmas // 2)` ### `rows_per_iter_4warp` `comptime rows_per_iter_4warp = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loading_threads_4warp // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loads_per_row)` ### `rows_per_iter_8warp` `comptime rows_per_iter_8warp = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loading_threads_8warp // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loads_per_row)` ### `total_smem_a` `comptime total_smem_a = ((2 * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM) * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BK)` ### `total_smem_b` `comptime total_smem_b = ((2 * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN) * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BK)` ### `total_warps` `comptime total_warps = (AMDPingPongMatmul[a_type, 
b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_warps_m * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_warps_n)` ### `VMCNT_PER_LOAD_A` `comptime VMCNT_PER_LOAD_A = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_8warp)` ### `VMCNT_PER_LOAD_A_4WARP` `comptime VMCNT_PER_LOAD_A_4WARP = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_4warp)` ### `VMCNT_PER_LOAD_B` `comptime VMCNT_PER_LOAD_B = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_8warp)` ### `VMCNT_PER_LOAD_B_4WARP` `comptime VMCNT_PER_LOAD_B_4WARP = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_4warp)` ### `WK` `comptime WK = config.warp_shape.__getitem__[3, DType.int64, Int](2)` ### `WM` `comptime WM = config.warp_shape.__getitem__[3, DType.int64, Int](0)` ### `WN` `comptime WN = config.warp_shape.__getitem__[3, DType.int64, Int](1)` ## Methods ### `validate_config` `static validate_config()` Validate the kernel configuration. ### `matmul_ping_pong` `static matmul_ping_pong(a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])`
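The `comptime` members listed above are simple ratios of the configured shapes. The following Python sketch reproduces those derivations for a hypothetical configuration (the shape values are illustrative assumptions, chosen so that the warp count matches the 8-warp ping-pong scheme described above; `WARP_SIZE` is 64 on AMD CDNA GPUs):

```python
WARP_SIZE = 64  # AMD CDNA wavefront size

# Hypothetical KernelConfig shapes for illustration.
BM, BN, BK = 256, 128, 64          # block_shape
WM, WN, WK = 64, 64, 64            # warp_shape
MMA_M, MMA_N, MMA_K = 16, 16, 32   # mma_shape

num_warps_m, num_warps_n = BM // WM, BN // WN   # warps tiling the block
num_m_mmas, num_n_mmas = WM // MMA_M, WN // MMA_N
num_k_mmas = WK // MMA_K
accum_width = (MMA_M * MMA_N) // WARP_SIZE      # accumulator elements per lane
quadrant_m_mmas = num_m_mmas // 2               # warp tile split into 4 quadrants
quadrant_n_mmas = num_n_mmas // 2
total_smem_a = 2 * BM * BK                      # double-buffered LDS, in elements
total_smem_b = 2 * BN * BK

print(num_warps_m * num_warps_n, num_m_mmas, num_k_mmas, accum_width)
```

With these shapes the kernel runs 4 × 2 = 8 warps, i.e. two ping-pong groups of 4, and the factor of 2 in `total_smem_a`/`total_smem_b` reflects the double-buffered LDS.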
--- ## KernelConfig
`struct KernelConfig` ## Fields * ​block\_shape (`IndexList[3]`): * ​warp\_shape (`IndexList[3]`): * ​mma\_shape (`IndexList[3]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, *, block_shape: IndexList[3], warp_shape: IndexList[3], mma_shape: IndexList[3])` ### `num_threads` `num_threads(self) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `write_to` `write_to(self, mut writer: T)` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String)
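`KernelConfig` stores only the three shape triples. As a hedged sketch, a Python analogue might look like the following; note that the body of `num_threads` is an assumption (the actual implementation is not shown in the signatures above), derived from the usual convention of one lane per thread across all warps tiling the block:

```python
from dataclasses import dataclass

WARP_SIZE = 64  # AMD wavefront size

@dataclass(frozen=True)
class KernelConfig:
    block_shape: tuple[int, int, int]  # (BM, BN, BK)
    warp_shape: tuple[int, int, int]   # (WM, WN, WK)
    mma_shape: tuple[int, int, int]    # (MMA_M, MMA_N, MMA_K)

    def num_threads(self) -> int:
        # Assumption: one thread per lane of every warp tiling the block tile.
        warps = (self.block_shape[0] // self.warp_shape[0]) * (
            self.block_shape[1] // self.warp_shape[1]
        )
        return warps * WARP_SIZE

cfg = KernelConfig((256, 128, 64), (64, 64, 64), (16, 16, 32))
print(cfg.num_threads())  # 8 warps x 64 lanes
```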
--- ## MmaOp
`struct MmaOp[in_type: DType, accum_type: DType, WM: Int, WN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, alignment: Int, swizzle: Optional[Swizzle]]` Encapsulates MMA register tiles and operations for matrix multiplication. This struct manages register tiles and MMA operations for a single warp. It processes warp-sized tiles (WM × BK for A, WN × BK for B) without knowledge of the broader kernel architecture. MmaOp accepts generic SMemTile and validates compatibility at compile-time via load\_lds\_fragment constraints. Note: Several values are derived from other parameters: * num\_m\_mmas = WM // MMA\_M * num\_n\_mmas = WN // MMA\_N * num\_k\_mmas = BK // MMA\_K * load\_width = simd\_width\_of[in\_type]() (SIMD width for input type) * accum\_width = (MMA\_M \* MMA\_N) // WARP\_SIZE (elements per thread) Quadrant Processing: The warp tile is divided into 4 quadrants for MMA scheduling: * quadrant\_m\_mmas = num\_m\_mmas // 2 (M-dimension quadrant size) * quadrant\_n\_mmas = num\_n\_mmas // 2 (N-dimension quadrant size) This enables efficient interleaving of loads and computes. Thread Layout for MMA: AMD's expected pattern: 64 threads → 4 rows × 16 cols (row-major) Lane offset computed on-the-fly via lane\_id() Swizzle Configuration: MmaOp receives the swizzle pattern from the kernel/TileBuffers, since it's determined by how data is loaded into LDS. MmaOp must read using the same swizzle pattern that was used for writing. 
* BF16: Swizzle(1, 5, 4) - 1 bit XOR * FP8 16×128: Swizzle(3, 4, 4) - 3 bit XOR (HipKittens st\_16x128) ## Fields * ​a\_reg\_tile (`MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ARegTile`): * ​b\_reg\_tile (`MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].BRegTile`): * ​out\_reg\_tile (`MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].OutRegTile`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_width` `comptime accum_width = ((MMA_M * MMA_N) // WARP_SIZE)` ### `ARegTile` `comptime ARegTile = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` ### `BRegTile` `comptime BRegTile = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` ### `bytes_per_frag` `comptime bytes_per_frag = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width * size_of[in_type]())` ### `col_groups` `comptime col_groups = (WARP_SIZE // MMA_M)` ### `ds_reads_per_frag` `comptime ds_reads_per_frag = ceildiv(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].bytes_per_frag, 
16)` ### `elem_swizzle` `comptime elem_swizzle = swizzle` ### `k_loads_per_mma` `comptime k_loads_per_mma = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width // MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width)` ### `lds_frag_width` `comptime lds_frag_width = 16 if MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].use_fp8_16x16x128_mma else MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width` ### `lgkm_per_load_a` `comptime lgkm_per_load_a = (((MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_m_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].k_loads_per_mma) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ds_reads_per_frag)` ### `lgkm_per_load_ab` `comptime lgkm_per_load_ab = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lgkm_per_load_a + MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lgkm_per_load_b)` ### `lgkm_per_load_b` `comptime lgkm_per_load_b = (((MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].k_loads_per_mma) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ds_reads_per_frag)` ### `load_width` `comptime load_width = simd_width_of[in_type]()` ### `mma_access_layout` `comptime mma_access_layout = Layout(IntTuple(MMA_M, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].col_groups), IntTuple(MMA_K, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, 
MMA_N, MMA_K, alignment, swizzle].lds_frag_width))` ### `mma_frag_width` `comptime mma_frag_width = ((MMA_M * MMA_K) // WARP_SIZE)` ### `num_k_mmas` `comptime num_k_mmas = (BK // MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (WM // MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (WN // MMA_N)` ### `out_reg_layout` `comptime out_reg_layout = Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].accum_width))` ### `OutRegTile` `comptime OutRegTile = LayoutTensor[accum_type, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` ### `quadrant_m_mmas` `comptime quadrant_m_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas // 2)` ### `quadrant_m_size` `comptime quadrant_m_size = MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_m_mmas` ### `quadrant_n_mmas` `comptime quadrant_n_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas // 2)` ### `quadrant_n_size` `comptime quadrant_n_size = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].accum_width)` ### `RegTile` `comptime RegTile[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major(num_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` #### Parameters * ​num\_mmas ([`Int`](/mojo/std/builtin/int/Int)): ### 
`use_fp8_16x16x128_mma` `comptime use_fp8_16x16x128_mma = (MMA_K == 128) if (MMA_M == 16) if (in_type == DType.float8_e4m3fn)._mlir_value else (in_type == DType.float8_e4m3fn) else (MMA_M == 16) if (in_type == DType.float8_e4m3fn)._mlir_value else (in_type == DType.float8_e4m3fn)` ### `use_fp8_32x32x64_mma` `comptime use_fp8_32x32x64_mma = (MMA_K == 64) if (MMA_M == 32) if (in_type == DType.float8_e4m3fn)._mlir_value else (in_type == DType.float8_e4m3fn) else (MMA_M == 32) if (in_type == DType.float8_e4m3fn)._mlir_value else (in_type == DType.float8_e4m3fn)` ## Methods ### `__init__` `__init__(out self)` Initialize MMA operation with register tiles. ### `reset_accumulator` `reset_accumulator(self)` Reset output register tile to zero. ### `load_a` `load_a[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load A\[which] from LDS → registers. Accepts SMemTile with matching dtype - layout compatibility validated at compile-time via load\_lds\_fragment constraints. For FP8 16×16×128: Uses lds\_frag\_width=16 with 2 K-iterations per MMA. For FP8 32×32×64: Uses lds\_frag\_width=32 with single load. For BF16 16×16×32: Uses lds\_frag\_width=8. ### `load_b` `load_b[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load B\[which] from LDS → registers. Accepts SMemTile with matching dtype - layout compatibility validated at compile-time via load\_lds\_fragment constraints. For FP8 16×16×128: Uses lds\_frag\_width=16 with 2 K-iterations per MMA. For FP8 32×32×64: Uses lds\_frag\_width=32 with single load. For BF16 16×16×32: Uses lds\_frag\_width=8. 
### `mma` `mma[which_a: Int, which_b: Int](self)` Execute MMA operations for a quadrant of the output tile. Accesses quadrant via .tile\[] view into the contiguous out\_reg\_tile. Uses mma\_frag\_width for fragment sizing (4 for BF16, 8 for FP8). Works for both BF16 and FP8 via stdlib mma() dispatch. **Parameters:** * ​which\_a ([`Int`](/mojo/std/builtin/int/Int)): A quadrant index (0 or 1). * ​which\_b ([`Int`](/mojo/std/builtin/int/Int)): B quadrant index (0 or 1).
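The derivation formulas in the `MmaOp` note above can be checked directly. The helper below encodes them verbatim; the BF16 16×16×32 shape passed at the end is a hypothetical example, for which `mma_frag_width` comes out to 8, consistent with the `lds_frag_width=8` case described in `load_a`/`load_b`:

```python
WARP_SIZE = 64  # AMD wavefront size

def mma_derived(WM, WN, BK, MMA_M, MMA_N, MMA_K):
    """Derived values exactly as listed in the MmaOp note."""
    return {
        "num_m_mmas": WM // MMA_M,
        "num_n_mmas": WN // MMA_N,
        "num_k_mmas": BK // MMA_K,
        "accum_width": (MMA_M * MMA_N) // WARP_SIZE,     # accum elems per thread
        "mma_frag_width": (MMA_M * MMA_K) // WARP_SIZE,  # input frag elems per thread
        "quadrant_m_mmas": (WM // MMA_M) // 2,           # quadrant scheduling
        "quadrant_n_mmas": (WN // MMA_N) // 2,
    }

# Hypothetical BF16 warp tile with a 16x16x32 matrix-core shape.
print(mma_derived(WM=64, WN=64, BK=64, MMA_M=16, MMA_N=16, MMA_K=32))
```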
--- ## TileBuffers
`struct TileBuffers[in_type: DType, a_layout: Layout, b_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, MMA_K: Int, num_threads: Int, alignment: Int, enable_swizzle: Bool, load_width: Int, loading_warps: Int = 8]` Double-buffered LDS tiles and TileLoaders for ping-pong matmul. a\_layout and b\_layout are infer-only parameters (note `//`), automatically extracted from the input tensors passed to `__init__`. K is derived as a comptime value from a\_layout.shape\[1]. ## Fields * ​a\_mma\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair]`): * ​b\_mma\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair]`): * ​a\_load\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair]`): * ​b\_load\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair]`): * ​loader\_a (`TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].ATileLoader`): * ​loader\_b (`TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BTileLoader`): * ​warp\_id\_m (`Int`): * ​warp\_shift\_rows (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), 
[`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `AHalfTile` `comptime AHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` ### `AHalfTilePair` `comptime AHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile]` ### `AMmaTile` `comptime AMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), False, alignment, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), 
masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK](), alignment=alignment]` ### `AMmaTilePair` `comptime AMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile]` ### `ATileLoader` `comptime ATileLoader = TileLoaderLDS[in_type, a_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]` ### `BHalfTile` `comptime BHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` ### `BHalfTilePair` `comptime BHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile]` ### `BMmaTile` `comptime BMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), 
_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK](), alignment, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() if _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, 
num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK](), alignment=alignment]` ### `BMmaTilePair` `comptime BMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile]` ### `BTileLoader` `comptime BTileLoader = TileLoaderLDS[in_type, b_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]` ### `byte_swizzle` `comptime byte_swizzle = Optional[Swizzle](Swizzle(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_log_tile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_base, 4)) if enable_swizzle else Optional[Swizzle]()` ### `elem_size` `comptime elem_size = size_of[in_type]()` ### `elements_per_warp` `comptime elements_per_warp = (WARP_SIZE * load_width)` ### `frag_bytes` `comptime frag_bytes = (TileBuffers[BM, BN, BK, 
WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].lds_frag_width * size_of[in_type]())` ### `half_BM` `comptime half_BM = (BM // 2)` ### `half_BN` `comptime half_BN = (BN // 2)` ### `half_tile_layout` `comptime half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK)` ### `HalfTile` `comptime HalfTile[rows: Int] = LayoutTensor[in_type, Layout.row_major(rows, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` #### Parameters * ​rows ([`Int`](/mojo/std/builtin/int/Int)): ### `K` `comptime K = a_layout.shape[1].value()` ### `lds_frag_width` `comptime lds_frag_width = 16 if TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_split_k else TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_frag_width` ### `loading_threads` `comptime loading_threads = (loading_warps * WARP_SIZE)` ### `loads_per_row` `comptime loads_per_row = (BK // load_width)` ### `mma_frag_width` `comptime mma_frag_width = ((16 * MMA_K) // WARP_SIZE)` ### `mma_tile_m` `comptime mma_tile_m = (WM // 2)` ### `mma_tile_n` `comptime mma_tile_n = (WN // 2)` ### `rows_per_iter_4warp` `comptime rows_per_iter_4warp = ((4 * WARP_SIZE) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)` ### `rows_per_load_iteration` `comptime rows_per_load_iteration = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loading_threads // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)` ### `rows_per_warp` `comptime rows_per_warp = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elements_per_warp // BK)` ### 
`smem_ptr` `comptime smem_ptr = LegacyUnsafePointer[Scalar[in_type], address_space=AddressSpace.SHARED]` ### `SMemTile` `comptime SMemTile[rows: Int, cols: Int] = LayoutTensor[in_type, Layout.row_major(rows, cols), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` #### Parameters * ​rows ([`Int`](/mojo/std/builtin/int/Int)): * ​cols ([`Int`](/mojo/std/builtin/int/Int)): ### `swizzle_base` `comptime swizzle_base = log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].frag_bytes) if (in_type == DType.float8_e4m3fn)._mlir_value else (log2_floor((TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_subtile_cols // 2)) + log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elem_size))` ### `swizzle_log_tile` `comptime swizzle_log_tile = (log2_floor((MMA_K // 32)) + 1)` ### `swizzle_shift` `comptime swizzle_shift = 4` ### `swizzle_subtile_cols` `comptime swizzle_subtile_cols = (4 * load_width)` ### `TileLoader` `comptime TileLoader[src_layout: Layout] = TileLoaderLDS[in_type, src_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]` #### Parameters * ​src\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ### `total_warps` `comptime total_warps = 8` ### `use_fp8_row_major` `comptime use_fp8_row_major = (in_type == DType.float8_e4m3fn)` ### `use_split_k` `comptime use_split_k = (MMA_K == 128) if (in_type == DType.float8_e4m3fn)._mlir_value else (in_type == DType.float8_e4m3fn)` ### `vmcnt_per_load_a` `comptime 
vmcnt_per_load_a = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)` ### `vmcnt_per_load_a_4warp` `comptime vmcnt_per_load_a_4warp = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)` ### `vmcnt_per_load_ab` `comptime vmcnt_per_load_ab = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_a + TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_b)` ### `vmcnt_per_load_b` `comptime vmcnt_per_load_b = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)` ### `vmcnt_per_load_b_4warp` `comptime vmcnt_per_load_b_4warp = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)` ## Methods ### `__init__` `__init__(out self, a: LayoutTensor[in_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], block_row: Int, block_col: Int, warp_id: Int, warp_id_m: Int, warp_id_n: Int, lane_id: Int)` Initialize LDS tiles and loaders. Layouts inferred from a and b tensors. ### `load_a` `load_a[stage: Int, which: Int](self, *, k: Int)` Load A\[stage]\[which] from global to LDS using all 8 warps. ### `load_b` `load_b[stage: Int, which: Int](self, *, k: Int)` Load B\[stage]\[which] from global to LDS using all 8 warps. 
### `load_a_as_group` `load_a_as_group[stage: Int, target_group: Int](self, caller_group: Int, *, k: Int)` Load A\[stage]\[target\_group] from global to LDS using 4 warps. ### `load_b_as_group` `load_b_as_group[stage: Int, which: Int](self, caller_group: Int, loading_group: Int, *, k: Int)` Load B\[stage]\[which] from global to LDS using 4 warps.
--- ## TileLoaderLDS
`@register_passable(trivial)` `struct TileLoaderLDS[dtype: DType, src_layout: Layout, src_tile_layout: Layout, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional[Swizzle](), load_width: Int = simd_width_of[dtype](), use_full_tile_width: Bool = False]` Cooperative global→LDS tile loader with swizzle support. Loads tiles from global memory to LDS using AMDBufferResource which provides automatic out-of-bounds clamping to zero - critical for partial block support. Loading Modes (controlled by use\_full\_tile\_width): * False (default): Interleaved layout. Each warp handles 32-col subtile. Used for BF16 where MMA\_K (32) < BK (64). * True: Row-major layout. Each source row maps 1:1 to LDS row. Used for FP8 where MMA\_K == BK, enabling correct partial block handling. ## Fields * ​buffer (`AMDBufferResource`): * ​thread\_row (`Int`): * ​thread\_col (`Int`): * ​warp\_id (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `elements_per_warp` `comptime elements_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_warp * load_width)` ### `loading_threads` `comptime loading_threads = (num_loading_warps * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_warp)` ### `loads_per_row` `comptime 
loads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // load_width)` ### `num_iterations` `comptime num_iterations = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_rows // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_iteration)` ### `num_warp_cols` `comptime num_warp_cols = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols)` ### `num_warp_rows` `comptime num_warp_rows = (num_loading_warps // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].num_warp_cols)` ### `rows_per_iteration` `comptime rows_per_iteration = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].loading_threads // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].loads_per_row)` ### `rows_per_warp` `comptime rows_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].elements_per_warp // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols)` ### `stride` `comptime stride = src_layout.shape[1].value()` ### `subtile_cols` `comptime subtile_cols = TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols if use_full_tile_width else 32` ### `thread_rows` `comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, 
use_full_tile_width].threads_per_row)` ### `threads_per_row` `comptime threads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols // load_width)` ### `threads_per_warp` `comptime threads_per_warp = WARP_SIZE` ### `tile_cols` `comptime tile_cols = src_tile_layout.shape[1].value()` ### `tile_rows` `comptime tile_rows = src_tile_layout.shape[0].value()` ## Methods ### `__init__` `__init__(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_id: Int, lane_id: Int) -> Self` Pre-compute thread position with swizzle inversion for bank-conflict-free reads. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor (tile view for BF16, full tensor for FP8). * ​warp\_id ([`Int`](/mojo/std/builtin/int/Int)): Warp ID within the block. * ​lane\_id ([`Int`](/mojo/std/builtin/int/Int)): Lane ID within the warp. ### `load_tile` `load_tile[dst_layout: Layout, //](self, dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_row: Int, src_col: Int)` Load a tile from source coordinates to LDS. Combines pre-computed thread position with source coordinates. Uses the buffer resource stored at init time. Only warps 0 to num\_loading\_warps-1 participate; others return immediately. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination LDS tile. * ​src\_row ([`Int`](/mojo/std/builtin/int/Int)): Starting row in source tensor. * ​src\_col ([`Int`](/mojo/std/builtin/int/Int)): Starting column in source tensor (typically k\_offset).
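The `comptime` members above derive the loader's thread geometry from the tile shape. As an illustrative sketch (Python, not Mojo; the example tile shape and warp count are assumptions, not values taken from any specific kernel config), the arithmetic is:

```python
WARP_SIZE = 64  # AMD wavefront size, matching threads_per_warp above

def loader_geometry(tile_rows: int, tile_cols: int,
                    num_loading_warps: int, load_width: int):
    """Recompute TileLoaderLDS's derived geometry for a given configuration."""
    loading_threads = num_loading_warps * WARP_SIZE
    loads_per_row = tile_cols // load_width           # vector loads per tile row
    rows_per_iteration = loading_threads // loads_per_row
    num_iterations = tile_rows // rows_per_iteration  # cooperative load passes
    return loads_per_row, rows_per_iteration, num_iterations

# Hypothetical example: a 64x64 tile loaded by 4 warps with 8-element vectors
print(loader_geometry(64, 64, 4, 8))  # -> (8, 32, 2)
```

With twice the warps, the same tile is covered in a single pass: `loader_geometry(64, 64, 8, 8)` yields `(8, 64, 1)`.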
--- ## pingpong_kernel
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`AMDPingPongMatmul`](./AMDPingPongMatmul): 8-warp ping-pong matmul for AMD MI355X. * [​`KernelConfig`](./KernelConfig): * [​`MmaOp`](./MmaOp): Encapsulates MMA register tiles and operations for matrix multiplication. * [​`TileBuffers`](./TileBuffers): Double-buffered LDS tiles and TileLoaders for ping-pong matmul. * [​`TileLoaderLDS`](./TileLoaderLDS): Cooperative global→LDS tile loader with swizzle support. ## Functions * [​`load_lds_fragment`](./load_lds_fragment): Load LDS → registers with MMA access pattern. * [​`make_mma_swizzle`](./make_mma_swizzle): Create swizzle pattern for MMA LDS access. * [​`ping_pong_matmul`](./ping_pong_matmul):
--- ## load_lds_fragment
`load_lds_fragment[dtype: DType, smem_layout: Layout, smem_element_layout: Layout, frag_layout: Layout, frag_element_layout: Layout, //, mma_access_layout: Layout, swizzle: Optional[Swizzle] = Optional[Swizzle]()](smem_tile: LayoutTensor[dtype, smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=smem_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], reg_frag: LayoutTensor[dtype, frag_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=frag_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load LDS → registers with MMA access pattern. Why mma\_access\_layout differs from the global→LDS thread layout:

| Layout | Purpose | Constraint |
|---|---|---|
| load\_thread | Global → LDS write | Coalesced global reads |
| mma\_access | LDS → Registers read | AMD WMMA hardware pattern |

mma\_access\_layout encodes how AMD's WMMA instruction expects data: * Lane decomposition: (lane % 16, lane // 16) = (col\_group, row\_group) * Offset computation: col\_group \* 32 + row\_group \* 8 Using RuntimeLayout ensures compile-time evaluation (no GPU heap alloc). Layout compatibility requirements: * mma\_access\_layout must map exactly WARP\_SIZE (64) threads * smem must have enough elements for: num\_iterations \* WARP\_SIZE \* frag\_width * frag must store: num\_iterations \* frag\_width elements
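The lane decomposition and offset computation documented above can be sketched in plain Python (illustrative only; `mma_lane_offset` is a hypothetical helper name, not part of the module):

```python
WARP_SIZE = 64  # AMD wavefront size

def mma_lane_offset(lane: int) -> int:
    """Element offset per lane under the documented WMMA access pattern:
    (lane % 16, lane // 16) = (col_group, row_group),
    offset = col_group * 32 + row_group * 8."""
    col_group = lane % 16
    row_group = lane // 16
    return col_group * 32 + row_group * 8

# Every lane in the warp reads a distinct offset (no two lanes collide),
# since row_group * 8 < 32 never carries into the col_group * 32 term.
offsets = [mma_lane_offset(lane) for lane in range(WARP_SIZE)]
assert len(set(offsets)) == WARP_SIZE
```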
--- ## make_mma_swizzle
`make_mma_swizzle[dtype: DType, MMA_M: Int, MMA_K: Int]() -> Swizzle` Create swizzle pattern for MMA LDS access. The AMD MI355X has 64 LDS banks × 4 bytes each. Without swizzling, the MMA thread access pattern causes 4-way bank conflicts. The swizzle XORs high-order address bits into the bank-selection bits to distribute accesses across banks. Swizzle parameters: * log\_tile: Number of bits to XOR; scales with MMA\_K * base: Log2 of read granularity in bytes (lds\_frag\_width \* elem\_size) * shift: Fixed at 4 for AMD LDS bank geometry Configuration examples: BF16 16×16×32: lds\_frag=8, bytes=16 → Swizzle(1, 4, 4) FP8 16×16×128: lds\_frag=16, bytes=16 → Swizzle(3, 4, 4) FP8 32×32×64: lds\_frag=32, bytes=32 → Swizzle(2, 5, 4) **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type (affects byte size). * ​MMA\_M ([`Int`](/mojo/std/builtin/int/Int)): M dimension of the MMA instruction. * ​MMA\_K ([`Int`](/mojo/std/builtin/int/Int)): K dimension of the MMA instruction. **Returns:** [`Swizzle`](/mojo/kernels/layout/swizzle/Swizzle): Swizzle pattern for bank-conflict-free LDS access.
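The parameter rules and configuration examples above can be checked with a small Python sketch. `derive_swizzle_params` follows the docstring's rules directly; `apply_swizzle` is one common XOR-swizzle formulation (fold `log_tile` high bits into the bank-select bits), offered as an assumption about how a `Swizzle(log_tile, base, shift)` acts, not as the exact implementation:

```python
from math import log2

def derive_swizzle_params(mma_k: int, lds_frag_width: int, elem_size: int):
    """Derive (log_tile, base, shift) per the docstring's parameter rules."""
    log_tile = int(log2(mma_k // 32)) + 1         # bits to XOR, scales with MMA_K
    base = int(log2(lds_frag_width * elem_size))  # log2 of read granularity in bytes
    shift = 4                                     # fixed for AMD LDS bank geometry
    return log_tile, base, shift

def apply_swizzle(offset: int, log_tile: int, base: int, shift: int) -> int:
    """Common XOR-swizzle form: XOR high-order bits into bank-selection bits."""
    mask = (1 << log_tile) - 1
    return offset ^ (((offset >> (base + shift)) & mask) << base)

# Reproduce the configuration examples from the docstring
assert derive_swizzle_params(32, 8, 2) == (1, 4, 4)    # BF16 16x16x32
assert derive_swizzle_params(128, 16, 1) == (3, 4, 4)  # FP8 16x16x128
assert derive_swizzle_params(64, 32, 1) == (2, 5, 4)   # FP8 32x32x64
```

Note that the XOR form is an involution, so it permutes addresses without losing any: applying it over a power-of-two range maps the range onto itself.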
--- ## ping_pong_matmul
`ping_pong_matmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, //, enable_swizzle: Bool = True](a_device_tensor: LayoutTensor[a_type, a_layout, origin], b_device_tensor: LayoutTensor[b_type, b_layout, origin], c_device_tensor: LayoutTensor[c_type, c_layout, origin], ctx: DeviceContext)`
--- ## ConsumerTile
`@register_passable(trivial)` `struct ConsumerTile[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_computed_per_consumer: Int]` Context manager for consumer access to a single ring buffer tile. ## Fields * ​consumer\_view\_ptr (`ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerViewPtrType`): * ​stage (`Int`): * ​consumer\_iteration (`Int`): * ​warp\_tile\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ConsumerViewPtrType` `comptime ConsumerViewPtrType = Pointer[ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerViewType, origin]` ### `ConsumerViewType` `comptime ConsumerViewType = ConsumerView[origin, ring_buffer_type, warps_computed_per_consumer]` ## Methods ### `__init__` `__init__(consumer_view_ptr: Pointer[ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerViewType, origin], stage: Int, consumer_iteration: Int, warp_tile_idx: Int) -> Self` ### `__enter__` `__enter__(mut self) -> 
ring_buffer_type.WarpTileTupleType` Acquire the tile for use. **Returns:** `ring_buffer_type.WarpTileTupleType` ### `__exit__` `__exit__(mut self)` Release the tile back to producers.
--- ## ConsumerView
`@register_passable(trivial)` `struct ConsumerView[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_computed_per_consumer: Int]` Consumer view of the unified ring buffer. ## Fields * ​ring\_buffer\_ptr (`ConsumerView[origin, ring_buffer_type, warps_computed_per_consumer].RingBufferPtrType`): * ​phases (`StaticTuple[Int32, (pipeline_stages * warps_computed_per_consumer)]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ConsumerTileType` `comptime ConsumerTileType = ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer]` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` Context manager entry. ### `__exit__` `__exit__(mut self)` Context manager exit. 
### `acquire_tiles` `acquire_tiles(mut self, stage: Int, consumer_iteration: Int, warp_tile_idx: Int) -> ring_buffer_type.WarpTileTupleType` Acquire tiles for reading by this consumer. **Args:** * ​stage ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stage to read from. * ​consumer\_iteration ([`Int`](/mojo/std/builtin/int/Int)): Which iteration this consumer is on (`0` to `warps_computed_per_consumer - 1`). * ​warp\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): Which tile this consumer wants to read. **Returns:** `ring_buffer_type.WarpTileTupleType` ### `release_tiles` `release_tiles(mut self, stage: Int, warp_tile_idx: Int)` Signal to producers that the tile is free. ### `get_tile` `get_tile(mut self, stage: Int, consumer_iteration: Int, warp_tile_idx: Int) -> ConsumerView[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerTileType` Get a context manager for accessing a tile. **Args:** * ​stage ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stage. * ​consumer\_iteration ([`Int`](/mojo/std/builtin/int/Int)): Current iteration of this consumer. * ​warp\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): Which tile to access. **Returns:** `ConsumerTile`
--- ## ProducerTile
`@register_passable(trivial)` `struct ProducerTile[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_processed_per_producer: Int]` Context manager for producer access to a single ring buffer tile. ## Fields * ​producer\_view\_ptr (`ProducerTile[origin, ring_buffer_type, warps_processed_per_producer].ProducerViewPtrType`): * ​stage (`Int`): * ​producer\_iteration (`Int`): * ​warp\_tile\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ProducerViewPtrType` `comptime ProducerViewPtrType = Pointer[ProducerTile[origin, ring_buffer_type, warps_processed_per_producer].ProducerViewType, origin]` ### `ProducerViewType` `comptime ProducerViewType = ProducerView[origin, ring_buffer_type, warps_processed_per_producer]` ## Methods ### `__init__` `__init__(producer_view_ptr: Pointer[ProducerTile[origin, ring_buffer_type, warps_processed_per_producer].ProducerViewType, origin], stage: Int, producer_iteration: Int, warp_tile_idx: Int) -> Self` ### `__enter__` `__enter__(mut self) -> 
ring_buffer_type.WarpTileTupleType` Acquire the tile for use. **Returns:** `ring_buffer_type.WarpTileTupleType` ### `__exit__` `__exit__(mut self)` Release the tile back to consumers.
--- ## ProducerView
`@register_passable(trivial)` `struct ProducerView[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_processed_per_producer: Int]` Producer view of the unified ring buffer. ## Fields * ​ring\_buffer\_ptr (`ProducerView[origin, ring_buffer_type, warps_processed_per_producer].RingBufferPtrType`): * ​phases (`StaticTuple[Int32, (pipeline_stages * warps_processed_per_producer)]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ProducerTileType` `comptime ProducerTileType = ProducerTile[origin, ring_buffer_type, warps_processed_per_producer]` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` Context manager entry. ### `__exit__` `__exit__(mut self)` Context manager exit. 
### `acquire_tiles` `acquire_tiles(mut self, stage: Int, producer_iteration: Int, warp_tile_idx: Int) -> ring_buffer_type.WarpTileTupleType` Acquire tiles for writing by this producer. **Args:** * ​stage ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stage to write to. * ​producer\_iteration ([`Int`](/mojo/std/builtin/int/Int)): Which iteration this producer is on (`0` to `warps_processed_per_producer - 1`). * ​warp\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): Which tile this producer is responsible for. **Returns:** `ring_buffer_type.WarpTileTupleType` ### `release_tiles` `release_tiles(mut self, stage: Int, warp_tile_idx: Int)` Signal to consumers that the tile is ready. ### `get_tile` `get_tile(mut self, stage: Int, warp_tile_idx: Int, producer_iteration: Int) -> ProducerView[origin, ring_buffer_type, warps_processed_per_producer].ProducerTileType` Get a context manager for accessing a tile. **Args:** * ​stage ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stage. * ​warp\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): Which tile to access. * ​producer\_iteration ([`Int`](/mojo/std/builtin/int/Int)): Current iteration of this producer. **Returns:** `ProducerTile`
--- ## RingBuffer
`struct RingBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy]` Ring buffer for coordinating producer-consumer warps in matrix multiplication. ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for shared memory tiles. * ​pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of stages for software pipelining. * ​block\_rows ([`Int`](/mojo/std/builtin/int/Int)): Number of rows in block-level tiles. * ​block\_cols ([`Int`](/mojo/std/builtin/int/Int)): Number of columns in block-level tiles. * ​warp\_rows ([`Int`](/mojo/std/builtin/int/Int)): Number of rows in warp-level tiles. * ​warp\_cols ([`Int`](/mojo/std/builtin/int/Int)): Number of columns in warp-level tiles. * ​reads\_per\_warp\_block ([`Int`](/mojo/std/builtin/int/Int)): How many consumer warps read each tile. * ​tile\_buffers ([`Int`](/mojo/std/builtin/int/Int)): Number of separate tile buffers (usually 1). * ​sync\_strategy\_type ([`SyncStrategy`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer_traits/SyncStrategy)): Synchronization strategy (SingleCounterSync or SplitCounterSync). 
## Fields * ​smem\_buffers (`RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SMemBuffersType`): * ​sync\_strategy (`sync_strategy_type`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = sync_strategy_type.__del__is_trivial` ### `block_warps` `comptime block_warps = (block_rows // warp_rows)` ### `SMemBuffersType` `comptime SMemBuffersType = StaticTuple[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SmemBufferType, tile_buffers]` ### `SmemBufferType` `comptime SmemBufferType = SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols]` ### `total_tiles` `comptime total_tiles = (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].block_warps * pipeline_stages)` ### `WarpTileTupleType` `comptime WarpTileTupleType = StaticTuple[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileType, tile_buffers]` ### `WarpTileType` `comptime WarpTileType = RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SmemBufferType.WarpTileType` ## Methods ### `__init__` `__init__(out self)` ### `get_tiles` `get_tiles(self, stage: Int, warp_tile_idx: Int) -> RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileTupleType` Get tiles from shared memory. 
**Returns:** `RingBuffer.WarpTileTupleType` ### `producer` `producer[warps_processed_per_producer: Int](mut self) -> ProducerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_processed_per_producer]` Create a producer view of this ring buffer. **Returns:** `ProducerView` ### `consumer` `consumer[warps_computed_per_consumer: Int](mut self) -> ConsumerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_computed_per_consumer]` Create a consumer view of this ring buffer. **Returns:** `ConsumerView` ### `get_staged_idx` `get_staged_idx(self, tile_idx: Int, stage: Int) -> Int` Get the staged index for a tile and stage. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `wait_producer_acquire` `wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Producer waits to acquire a tile. ### `signal_producer_release` `signal_producer_release(mut self, tile_idx: Int, stage: Int)` Producer signals that it has released a tile. ### `wait_consumer_acquire` `wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Consumer waits to acquire a tile. ### `signal_consumer_release` `signal_consumer_release(mut self, tile_idx: Int, stage: Int)` Consumer signals that it has released a tile. ### `get_producer_phase_increment` `get_producer_phase_increment(self) -> Int32` Get the phase increment for producers. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self) -> Int32` Get the phase increment for consumers. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32)
--- ## ring_buffer
Ring Buffer implementation for producer-consumer synchronization in GPU kernels. This ring buffer coordinates data transfer between producer warps (loading from global memory) and consumer warps (performing computation) through shared memory tiles. Key features: * Configurable synchronization strategies via the SyncStrategy trait * Pipeline stages for overlapping data transfer and computation * Context managers for automatic acquire/release of tiles * Phase-based synchronization to prevent data races ## Structs * [​`ConsumerTile`](./ConsumerTile): Context manager for consumer access to a single ring buffer tile. * [​`ConsumerView`](./ConsumerView): Consumer view of the unified ring buffer. * [​`ProducerTile`](./ProducerTile): Context manager for producer access to a single ring buffer tile. * [​`ProducerView`](./ProducerView): Producer view of the unified ring buffer. * [​`RingBuffer`](./RingBuffer): Ring buffer for coordinating producer-consumer warps in matrix multiplication.
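The acquire/release protocol described above can be modeled on the host with ordinary threads. The following is an illustrative Python analogue, not the Mojo implementation: a condition variable stands in for the shared-memory atomic counters, and one boolean flag per pipeline stage plays the role of the phase counters.

```python
import threading

class RingBufferModel:
    """Toy host-side model of the ring buffer's producer-consumer protocol.

    Illustrative sketch: a condition variable replaces the GPU's
    shared-memory atomic counters, and a filled/empty flag per pipeline
    stage replaces phase-based synchronization.
    """

    def __init__(self, pipeline_stages):
        self.slots = [None] * pipeline_stages    # one tile slot per stage
        self.filled = [False] * pipeline_stages  # producer-filled flags
        self.cv = threading.Condition()

    def produce(self, stage, tile):
        with self.cv:
            # Producer waits until consumers have drained this stage.
            self.cv.wait_for(lambda: not self.filled[stage])
            self.slots[stage] = tile
            self.filled[stage] = True
            self.cv.notify_all()

    def consume(self, stage):
        with self.cv:
            # Consumer waits until the producer has filled this stage.
            self.cv.wait_for(lambda: self.filled[stage])
            tile = self.slots[stage]
            self.filled[stage] = False
            self.cv.notify_all()
            return tile
```

With a producer thread filling stages round-robin and a consumer thread draining them in the same order, the two overlap in time, which is the effect the pipeline stages provide on the GPU.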
--- ## SingleCounterSync
`@register_passable(trivial)` `struct SingleCounterSync[pipeline_stages: Int, block_rows: Int, warp_rows: Int, reads_per_warp_block: Int]` Single counter synchronization strategy. Uses one atomic counter per tile that tracks both producer and consumer progress. This is simpler but has higher contention as all warps compete for the same counter. Phase progression: * Each phase advances by (writes\_per\_warp\_block + reads\_per\_warp\_block) * Producers wait for phase N, increment counter by 1 * Consumers wait for phase N+1, increment counter by 1 ## Fields * ​sync\_counter (`SingleCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].SyncCounterArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`SyncStrategy`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer_traits/SyncStrategy), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `block_warps` `comptime block_warps = (block_rows // warp_rows)` ### `SyncCounterArray` `comptime SyncCounterArray = SMemArray[Int32, SingleCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].total_tiles]` ### `total_tiles` `comptime total_tiles = (SingleCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].block_warps * pipeline_stages)` ### `writes_per_warp_block` `comptime writes_per_warp_block = 1` ## Methods ### `__init__` `__init__() -> Self` Initialize with internally allocated sync counter. 
### `get_staged_idx` `get_staged_idx(self, tile_idx: Int, stage: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `wait_producer_acquire` `wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` ### `signal_producer_release` `signal_producer_release(mut self, tile_idx: Int, stage: Int)` ### `wait_consumer_acquire` `wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` ### `signal_consumer_release` `signal_consumer_release(mut self, tile_idx: Int, stage: Int)` ### `get_producer_phase_increment` `get_producer_phase_increment(self) -> Int32` **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self) -> Int32` **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32)
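The single-counter phase progression above can be traced sequentially. This is an illustrative Python model of the arithmetic, not the Mojo code: one counter serves both sides, the producer waits for phase N and adds 1, and each of the `reads_per_warp_block` consumers waits for phase N+1 and adds 1, so a full cycle advances the counter by `writes_per_warp_block + reads_per_warp_block`.

```python
def single_counter_schedule(iterations, reads_per_warp_block):
    """Sequential trace of the single-counter protocol (illustrative model).

    writes_per_warp_block is fixed at 1, matching the comptime member above.
    Returns the counter value observed after each produce/consume cycle.
    """
    writes_per_warp_block = 1
    phase_step = writes_per_warp_block + reads_per_warp_block
    counter = 0
    trace = []
    for it in range(iterations):
        base_phase = it * phase_step
        assert counter >= base_phase           # wait_producer_acquire succeeds
        counter += 1                           # signal_producer_release
        for _ in range(reads_per_warp_block):
            assert counter >= base_phase + 1   # wait_consumer_acquire succeeds
            counter += 1                       # signal_consumer_release
        trace.append(counter)
    return trace
```

Because every warp increments the same counter, each cycle generates `1 + reads_per_warp_block` atomic updates to one location, which is the contention the split-counter strategy avoids.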
--- ## SplitCounterSync
`@register_passable(trivial)` `struct SplitCounterSync[pipeline_stages: Int, block_rows: Int, warp_rows: Int, reads_per_warp_block: Int]` Split counter synchronization strategy. Uses separate producer and consumer counters per tile to reduce atomic contention. Producers only write to producer counters, consumers only write to consumer counters. Phase progression: * Producer phase advances by reads\_per\_warp\_block (waits for N consumers) * Consumer phase advances by writes\_per\_warp\_block (waits for 1 producer) * This asymmetry reflects the 1-producer-to-N-consumers relationship ## Fields * ​producer\_counters (`SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].ProducerCounterArray`): * ​consumer\_counters (`SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].ConsumerCounterArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`SyncStrategy`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer_traits/SyncStrategy), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `block_warps` `comptime block_warps = (block_rows // warp_rows)` ### `ConsumerCounterArray` `comptime ConsumerCounterArray = SMemArray[Int32, SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].total_tiles]` ### `ProducerCounterArray` `comptime ProducerCounterArray = SMemArray[Int32, SplitCounterSync[pipeline_stages, block_rows, warp_rows, 
reads_per_warp_block].total_tiles]` ### `total_tiles` `comptime total_tiles = (SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].block_warps * pipeline_stages)` ### `writes_per_warp_block` `comptime writes_per_warp_block = 1` ## Methods ### `__init__` `__init__() -> Self` Initialize with internally allocated producer and consumer counters. ### `get_staged_idx` `get_staged_idx(self, tile_idx: Int, stage: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `wait_producer_acquire` `wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Producer waits on consumer counter. ### `signal_producer_release` `signal_producer_release(mut self, tile_idx: Int, stage: Int)` Producer increments producer counter. ### `wait_consumer_acquire` `wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Consumer waits on producer counter. ### `signal_consumer_release` `signal_consumer_release(mut self, tile_idx: Int, stage: Int)` Consumer increments consumer counter by 1. ### `get_producer_phase_increment` `get_producer_phase_increment(self) -> Int32` Producer phase advances by reads\_per\_warp\_block. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self) -> Int32` Consumer phase advances by writes\_per\_warp\_block. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32)
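The asymmetric phase increments above can likewise be traced sequentially. This illustrative Python model (not the Mojo code) keeps the two counters separate: the producer waits on the consumer counter and its phase advances by `reads_per_warp_block`, while each consumer waits on the producer counter and its phase advances by `writes_per_warp_block` (= 1).

```python
def split_counter_schedule(iterations, reads_per_warp_block):
    """Sequential trace of the split-counter protocol (illustrative model).

    Producers only write the producer counter; consumers only write the
    consumer counter, reducing atomic contention. Returns the final
    (producer_counter, consumer_counter) pair.
    """
    writes_per_warp_block = 1
    producer_counter = consumer_counter = 0
    producer_phase = consumer_phase = 0
    for _ in range(iterations):
        # Producer acquires once all consumers of the prior round are done.
        assert consumer_counter >= producer_phase   # wait on consumer counter
        producer_counter += 1                       # signal_producer_release
        producer_phase += reads_per_warp_block      # waits for N consumers
        for _ in range(reads_per_warp_block):
            # Consumer acquires once the producer has filled the tile.
            assert producer_counter >= consumer_phase + writes_per_warp_block
            consumer_counter += 1                   # signal_consumer_release
        consumer_phase += writes_per_warp_block     # waits for 1 producer
    return producer_counter, consumer_counter
```

Per cycle the producer counter advances by 1 and the consumer counter by `reads_per_warp_block`, reflecting the 1-producer-to-N-consumers relationship.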
--- ## SyncStrategy
Interface for synchronization strategies between producers and consumers. All methods have the same signature regardless of the specific implementation, allowing the RingBuffer to be parameterized with any conforming strategy. Phase tracking ensures producers and consumers access different tiles: * Producers wait until consumers have finished with a tile (phase N) * Consumers wait until producers have filled a tile (phase N+1) ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. 
The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__init__` `__init__() -> _Self` Initialize with internally allocated sync counter. **Returns:** `_Self` ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `get_staged_idx` `get_staged_idx(self: _Self, tile_idx: Int, stage: Int) -> Int` Convert tile index and stage to a flat index in the counter arrays. **Args:** * ​tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): Index of the tile within a stage (0 to block\_warps-1). * ​stage ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stage (0 to pipeline\_stages-1). **Returns:** [`Int`](/mojo/std/builtin/int/Int): Flat index for accessing synchronization counters. ### `wait_producer_acquire` `wait_producer_acquire(self: _Self, tile_idx: Int, stage: Int, phase: Int32)` Producer waits until it can write to the specified tile. Blocks until all consumers have finished reading from this tile (counter >= phase). ### `signal_producer_release` `signal_producer_release(mut self: _Self, tile_idx: Int, stage: Int)` Producer signals that it has finished writing to the tile. Increments the appropriate counter to notify waiting consumers. ### `wait_consumer_acquire` `wait_consumer_acquire(self: _Self, tile_idx: Int, stage: Int, phase: Int32)` Consumer waits until it can read from the specified tile. 
Blocks until producer has finished writing to this tile (counter >= phase). ### `signal_consumer_release` `signal_consumer_release(mut self: _Self, tile_idx: Int, stage: Int)` Consumer signals that it has finished reading from the tile. Increments the appropriate counter to notify waiting producers. ### `get_producer_phase_increment` `get_producer_phase_increment(self: _Self) -> Int32` Returns how much to advance the producer phase after each acquisition. This determines when producers can reuse a tile after consumers finish. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self: _Self) -> Int32` Returns how much to advance the consumer phase after each acquisition. This determines when consumers can read a tile after producers finish. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
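The trait's required surface can be sketched as a Python `Protocol`, which mirrors how the `RingBuffer` accepts any structurally conforming strategy. Method names follow the docs above; the `NoOpSync` class, its name, and the `stage * block_warps + tile_idx` flattening are hypothetical illustrations, not the library's implementation.

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class SyncStrategyProtocol(Protocol):
    """Python mirror of the SyncStrategy trait's required methods
    (types simplified to int; illustrative only)."""
    def get_staged_idx(self, tile_idx: int, stage: int) -> int: ...
    def wait_producer_acquire(self, tile_idx: int, stage: int, phase: int) -> None: ...
    def signal_producer_release(self, tile_idx: int, stage: int) -> None: ...
    def wait_consumer_acquire(self, tile_idx: int, stage: int, phase: int) -> None: ...
    def signal_consumer_release(self, tile_idx: int, stage: int) -> None: ...
    def get_producer_phase_increment(self) -> int: ...
    def get_consumer_phase_increment(self) -> int: ...

class NoOpSync:
    """Hypothetical trivial strategy, e.g. for a single-warp test."""
    def __init__(self, block_warps: int):
        self.block_warps = block_warps
    def get_staged_idx(self, tile_idx, stage):
        # Assumed flattening: one counter slot per (stage, tile) pair.
        return stage * self.block_warps + tile_idx
    def wait_producer_acquire(self, tile_idx, stage, phase): pass
    def signal_producer_release(self, tile_idx, stage): pass
    def wait_consumer_acquire(self, tile_idx, stage, phase): pass
    def signal_consumer_release(self, tile_idx, stage): pass
    def get_producer_phase_increment(self): return 1
    def get_consumer_phase_increment(self): return 1
```

As with the Mojo trait, conformance is purely structural: any type providing these methods can parameterize a ring-buffer model.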
--- ## increment_counter_if_first_thread
`increment_counter_if_first_thread(counter: UnsafePointer[Int32, origin, address_space=AddressSpace.SHARED], increment: Int32)` Atomically increment the counter, but only from the first thread in the warp.
--- ## ring_buffer_traits
Trait definitions and utilities for ring buffer synchronization strategies. This module provides: * SyncStrategy trait: Interface for producer-consumer synchronization protocols * SingleCounterSync: Uses a single atomic counter per tile (original RingBuffer behavior) * SplitCounterSync: Uses separate producer/consumer counters to reduce contention * Atomic utility functions for thread-safe counter operations ## Structs * [​`SingleCounterSync`](./SingleCounterSync): Single counter synchronization strategy. * [​`SplitCounterSync`](./SplitCounterSync): Split counter synchronization strategy. ## Traits * [​`SyncStrategy`](./SyncStrategy): Interface for synchronization strategies between producers and consumers. ## Functions * [​`increment_counter_if_first_thread`](./increment_counter_if_first_thread): Atomically increment the counter, but only from the first thread in the warp. * [​`wait_for_counter`](./wait_for_counter): Spin-wait until counter reaches threshold.
--- ## wait_for_counter
`wait_for_counter(counter: UnsafePointer[Int32, origin, address_space=AddressSpace.SHARED], threshold: Int32)` Spin-wait until counter reaches threshold.
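The two utilities cooperate: every lane in a warp calls the increment helper, but only lane 0 performs the atomic add, so one warp contributes exactly one increment that waiters observe via the spin loop. The following is an illustrative Python model (a one-element list stands in for the shared-memory `Int32` pointer; the wavefront size of 64 is an assumption for the sketch).

```python
def increment_counter_if_first_thread(counter, lane_id, increment):
    """Model of the warp-level helper: only lane 0 performs the add,
    so a whole warp contributes a single increment (illustrative)."""
    if lane_id == 0:
        counter[0] += increment

def wait_for_counter(counter, threshold):
    """Model of the spin-wait: loop until the counter reaches threshold."""
    while counter[0] < threshold:
        pass  # on the GPU this busy-waits on a shared-memory location

WARP_SIZE = 64  # AMD wavefront size, assumed for this sketch
counter = [0]
for lane_id in range(WARP_SIZE):
    increment_counter_if_first_thread(counter, lane_id, 1)
# All 64 lanes called the helper, but the counter advanced only once,
# so wait_for_counter(counter, 1) returns immediately.
wait_for_counter(counter, 1)
```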
--- ## AMDSharedMemoryBarrier
`@register_passable(trivial)` `struct AMDSharedMemoryBarrier` ## Fields * ​\_\_repr (`Int32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `initialize` `initialize(ref[MutAnyOrigin, AddressSpace._value._mlir_value] self)` ### `value` `value(ref[AddressSpace._value._mlir_value] self) -> Int32` **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `increment` `increment(ref[MutAnyOrigin, AddressSpace._value._mlir_value] self, warp_id: Int)` ### `wait_until_greater_or_equal_to` `wait_until_greater_or_equal_to(ref[AddressSpace._value._mlir_value] self, v: Int32)`
--- ## AMDWarpSharedMemoryBarrier
`@register_passable(trivial)` `struct AMDWarpSharedMemoryBarrier[size: Int]` ## Fields * ​\_\_repr (`StaticTuple[Int32, size]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `initialize` `initialize(ref[MutAnyOrigin, AddressSpace._value._mlir_value] self)` ### `value` `value(ref[AddressSpace._value._mlir_value] self) -> Int32` **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `increment` `increment(ref[MutAnyOrigin, AddressSpace._value._mlir_value] self, warp_id: Int)` ### `wait_until_greater_or_equal_to` `wait_until_greater_or_equal_to(ref[AddressSpace._value._mlir_value] self, v: Int32)`
--- ## AmdTileOperator
`@register_passable(trivial)` `struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], swizzle: Optional[Swizzle] = None, transpose_b: Bool = True]` Manages tensor core operations for matrix multiplication on AMD GPUs. This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores. Requirements: \- warp\_block\_layout\_a.shape\[0] must be divisible by mma\_shape\[0] \- warp\_block\_layout\_b.shape\[0] must be divisible by mma\_shape\[1] \- warp\_block\_layout\_a.shape\[1] must be divisible by mma\_shape\[2] \- warp\_block\_layout\_b.shape\[1] must be divisible by mma\_shape\[2] \- The K dimension must align such that num\_k\_tiles is divisible by k\_group\_size ## Parameters * ​InType ([`DType`](/mojo/std/builtin/dtype/DType)): Input data type. * ​OutType ([`DType`](/mojo/std/builtin/dtype/DType)): Output data type. * ​warp\_block\_layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout for matrix A warp tiles. * ​warp\_block\_layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout for matrix B warp tiles. * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): Shape of the MMA operation \[M, N, K]. * ​swizzle ([`Optional`](/mojo/std/collections/optional/Optional)): Optional swizzle pattern for memory access. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether matrix B is transposed. 
## Fields * ​out\_reg\_tile (`AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].OutRegTile`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ARegTile` `comptime ARegTile = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `BRegTile` `comptime BRegTile = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, 
transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `k_group_size_a` `comptime k_group_size_a = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]())` ### `k_group_size_b` `comptime k_group_size_b = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]())` ### `k_tile_fragment_index` `comptime k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)` #### Parameters * ​k\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): ### `k_tile_group_index` `comptime k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)` #### Parameters * ​k\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): ### `num_k_tiles` `comptime num_k_tiles = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].WK // mma_shape.__getitem__[3, DType.int64, Int](2))` ### `num_m_mmas` `comptime num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))` ### `num_n_mmas` `comptime num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))` ### `out_frag_size` `comptime out_frag_size = ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](1)) // WARP_SIZE)` ### `OutRegTile` `comptime OutRegTile = LayoutTensor[OutType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]` ### `OutRegTileFragmentType` `comptime OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), _get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), False, align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, 
swizzle, transpose_b].simd_width]](), (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, 
Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()](), alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]` ### `simd_width` `comptime simd_width = simd_width_of[InType]()` ### `tensor_core` `comptime tensor_core = TensorCore[OutType, InType, mma_shape, transpose_b]()` ### `total_k_tiles` `comptime total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles` ### `WK` `comptime WK = product(warp_block_layout_a.shape[1])` ## Methods ### `__init__` `__init__() -> Self` ### `a_reg_tile` `a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[True, InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, 
warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), False, align_of[InType](), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, 
warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]` Get A register tile for a specific K tile. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `b_reg_tile` `b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[True, InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, 
warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), False, align_of[InType](), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, 
warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]` Get B register tile for a specific K tile. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `reset_accumulator` `reset_accumulator(self)` Reset the accumulator to zero for a new tile computation. ### `load_tile_fragment` `load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], smem_tile_b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load fragments from shared memory to registers for a specific K tile. **Parameters:** * ​k\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): K-tile index (0 to total\_k\_tiles-1). **Args:** * ​smem\_tile\_a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile for matrix A. 
* ​smem\_tile\_b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile for matrix B. ### `mma_compute` `mma_compute[k_tile_idx: Int](self)` Perform matrix multiply-accumulate for a specific K tile. This method assumes fragments are already loaded via load\_tile\_fragment. **Parameters:** * ​k\_tile\_idx ([`Int`](/mojo/std/builtin/int/Int)): K-tile index (0 to total\_k\_tiles-1).
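The per-K-tile flow above (reset the accumulator once, then for each K tile load fragments and issue the matrix multiply-accumulate) can be sketched in plain Python/NumPy. This is an illustration of the looping pattern only, not the Mojo `AmdTileOperator` API; the comments map each step back to the methods documented above.

```python
import numpy as np

def tiled_matmul(a, b, k_tile_size):
    """Illustrative only: mirrors the AmdTileOperator flow, where the
    accumulator is reset once, then each K tile's A and B fragments are
    loaded and multiply-accumulated."""
    m, k = a.shape
    _, n = b.shape
    assert k % k_tile_size == 0
    acc = np.zeros((m, n))                       # reset_accumulator()
    for k_tile_idx in range(k // k_tile_size):   # 0 .. total_k_tiles-1
        lo = k_tile_idx * k_tile_size
        a_frag = a[:, lo:lo + k_tile_size]       # load_tile_fragment: A
        b_frag = b[lo:lo + k_tile_size, :]       # load_tile_fragment: B
        acc += a_frag @ b_frag                   # mma_compute[k_tile_idx]
    return acc
```

In the real kernel the "fragments" live in registers and the MMA is a tensor-core instruction; the structure of the loop is the same.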
--- ## Enum
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. 
**Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `value` `value(self: _Self) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ## Provided methods ### `__eq__` `__eq__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__is__` `__is__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__isnot__` `__isnot__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## MMAConfig
`@register_passable(trivial)` `struct MMAConfig[InType: DType, OutType: DType, mma_shape: IndexList[3], transpose_b: Bool = True]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `k_group_size_a` `comptime k_group_size_a = (MMAConfig[InType, OutType, mma_shape, transpose_b].simd_width // MMAConfig[InType, OutType, mma_shape, transpose_b].registers_per_thread_a)` ### `k_group_size_b` `comptime k_group_size_b = (MMAConfig[InType, OutType, mma_shape, transpose_b].simd_width // MMAConfig[InType, OutType, mma_shape, transpose_b].registers_per_thread_b)` ### `mma` `comptime mma = TensorCore[OutType, InType, mma_shape, transpose_b]()` ### `registers_per_thread_a` `comptime registers_per_thread_a = num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]()` ### `registers_per_thread_b` `comptime registers_per_thread_b = num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]()` ### `simd_width` `comptime simd_width = simd_width_of[InType]()` ## Methods ### `adjusted_mma_k_shape_a` `static adjusted_mma_k_shape_a() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `adjusted_mma_k_shape_b` `static adjusted_mma_k_shape_b() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## SMemBuffer
`@register_passable(trivial)` `struct SMemBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, BM: Int, BN: Int, WM: Int, WN: Int]` Manages shared memory and returns 2D tile slices of the buffer. ## Fields * ​buffer (`SMemBuffer[dtype, layout, pipeline_stages, BM, BN, WM, WN].SMemTile`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BlockTileType` `comptime BlockTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, BM, BN]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN](), alignment=128]` ### `SMemTile` `comptime SMemTile = LayoutTensor[dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` ### `WarpTileType` `comptime WarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, 
dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, BM, BN]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN](), 128, WM, WN]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN]() if _tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, BM, BN]()[0], WM, WN](), alignment=128]` ## Methods ### `__init__` `__init__() -> Self` ### `get_tile` `get_tile(self, stage: Int) -> SMemBuffer[dtype, layout, pipeline_stages, BM, BN, WM, WN].BlockTileType` **Returns:** `SMemBuffer`
--- ## ThreadRole
`@register_passable(trivial)` `struct ThreadRole` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Enum`](/mojo/kernels/linalg/matmul/gpu/amd/structured/Enum), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CONSUMER` `comptime CONSUMER = ThreadRole(1)` ### `PRODUCER` `comptime PRODUCER = ThreadRole(0)` ### `PRODUCER_CONSUMER` `comptime PRODUCER_CONSUMER = ThreadRole(2)` ## Methods ### `value` `value(self) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `__str__` `__str__(self) -> String` Returns the string representation of this thread role. **Returns:** [`String`](/mojo/std/collections/string/string/String): A human-readable string representation of the thread role. ### `write_to` `write_to[W: Writer](self, mut writer: W)`
--- ## structured
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`AMDSharedMemoryBarrier`](./AMDSharedMemoryBarrier): * [​`AmdTileOperator`](./AmdTileOperator): Manages tensor core operations for matrix multiplication on AMD GPUs. * [​`AMDWarpSharedMemoryBarrier`](./AMDWarpSharedMemoryBarrier): * [​`MMAConfig`](./MMAConfig): * [​`SMemBuffer`](./SMemBuffer): Manages shared memory and returns 2D tile slices of the buffer. * [​`ThreadRole`](./ThreadRole): ## Traits * [​`Enum`](./Enum): ## Functions * [​`pipeline_layout`](./pipeline_layout):
--- ## pipeline_layout
`pipeline_layout[layout: Layout, pipeline_stages: Int]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## determine_thread_role
`determine_thread_role[producer_a_warps: Int, producer_b_warps: Int]() -> Tuple[ThreadRole, Int]` Returns (role, consumer\_warp\_id within role group). **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple)
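A minimal Python sketch of what such a role split could look like. The ordering of the role groups (A producers first, then B producers, then consumers) is an assumption for illustration, not taken from the implementation:

```python
def determine_thread_role(warp_id, producer_a_warps, producer_b_warps):
    """Map a warp id to (role, index within its role group).
    ASSUMED ordering: A-producer warps first, then B-producer warps,
    then consumer warps."""
    if warp_id < producer_a_warps:
        return ("PRODUCER_A", warp_id)
    if warp_id < producer_a_warps + producer_b_warps:
        return ("PRODUCER_B", warp_id - producer_a_warps)
    return ("CONSUMER", warp_id - producer_a_warps - producer_b_warps)
```

The second tuple element is what the docstring calls the "consumer\_warp\_id within role group": each role group gets its own zero-based numbering.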
--- ## get_producer_warp_thread_layout
`get_producer_warp_thread_layout[k_tile_size: Int, simd_width: Int, block_rows: Int, block_cols: Int]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## warp_spec_matmul
AMD Warp-Specialized Matrix Multiplication Architecture Overview: * Producer warps: Load tiles from global to shared memory * A producers: Load M×K tiles from matrix A * B producers: Load N×K tiles from matrix B * Consumer warps: Perform matrix multiplication using shared memory tiles * Ring buffer: Coordinates producer-consumer synchronization with barriers Data Flow: 1. Producers load tiles into shared memory stages 2. Barriers ensure data is ready before consumers access it 3. Consumers compute partial results and accumulate 4. Final results written back to global memory Memory Layout: * Shared memory is divided into pipeline stages for overlapping * Each stage contains block tiles that are further divided into warp tiles * Swizzling may be applied to avoid bank conflicts Ring Buffer Configuration: * Uses SingleCounterSync strategy by default (single atomic counter per tile) * Can be changed to SplitCounterSync in the RingBuffer type aliases for reduced contention * The trait-based design allows easy experimentation with different sync strategies ## `comptime` values ### `GlobalTensor` `comptime GlobalTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL]` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Functions * [​`determine_thread_role`](./determine_thread_role): Returns (role, consumer\_warp\_id within role group). * [​`get_producer_warp_thread_layout`](./get_producer_warp_thread_layout): * [​`lgkm_wait`](./lgkm_wait): * [​`run_producer`](./run_producer): Generic producer function for loading matrix tiles from global to shared memory. * [​`smem_tile_layout`](./smem_tile_layout): * [​`validate_config`](./validate_config): Validates the configuration parameters for the matrix multiplication kernel. * [​`warp_specialized_matmul`](./warp_specialized_matmul): * [​`warp_specialized_matmul_kernel`](./warp_specialized_matmul_kernel):
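The producer/consumer data flow above can be illustrated with a single-threaded simulation of a staged ring buffer; a minimal Python sketch, where the `filled` flags stand in for the barrier/atomic-counter handshake (in the real kernel producers and consumers run concurrently):

```python
def run_pipeline(num_tiles, num_stages):
    """Single-threaded simulation of the staged ring buffer: the producer
    fills shared-memory stages round-robin; a stage may be consumed only
    after it is filled, and refilled only after it is consumed."""
    smem = [None] * num_stages
    filled = [False] * num_stages
    consumed_tiles = []
    produced = consumed = 0
    while consumed < num_tiles:
        stage = produced % num_stages
        if produced < num_tiles and not filled[stage]:
            smem[stage] = produced              # producer: global -> shared
            filled[stage] = True
            produced += 1
        stage = consumed % num_stages
        if filled[stage]:
            consumed_tiles.append(smem[stage])  # consumer: compute with tile
            filled[stage] = False
            consumed += 1
    return consumed_tiles
```

With more than one stage, a real producer can be loading tile `i+1` while the consumer computes on tile `i`, which is the overlap the pipeline stages exist to enable.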
--- ## lgkm_wait
`lgkm_wait()`
--- ## run_producer
`run_producer[dtype: DType, layout: Layout, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, producer_warps: Int, pipeline_stages: Int, k_tile_size: Int, simd_width: Int, warps_processed_per_producer: Int, tile_count: Int, swizzle: Optional[Swizzle]](matrix: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL], mut ring_buffer: RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warp_id: Scalar[DType.uint], block_idx_dim: Int)` Generic producer function for loading matrix tiles from global to shared memory.
--- ## smem_tile_layout
`smem_tile_layout[k_tile_size: Int, block_rows: Int, block_cols: Int]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## validate_config
`validate_config[BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, m_warps: Int, n_warps: Int, producer_a: Int, producer_b: Int, consumer: Int]()` Validates the configuration parameters for the matrix multiplication kernel.
--- ## warp_specialized_matmul
`warp_specialized_matmul[M: Int, N: Int, K: Int, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, a_producer_warps: Int, b_producer_warps: Int, consumer_warps: Int, pipeline_stages: Int = 1](a_device_tensor: LayoutTensor[DType.bfloat16, Layout.row_major(M, K), origin], b_device_tensor: LayoutTensor[DType.bfloat16, Layout.row_major(N, K), origin], c_device_tensor: LayoutTensor[DType.float32, Layout.row_major(M, N), origin], ctx: DeviceContext)`
--- ## warp_specialized_matmul_kernel
`warp_specialized_matmul_kernel[in_type: DType, out_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, a_producer_warps: Int, b_producer_warps: Int, consumer_warps: Int, pipeline_stages: Int](a: LayoutTensor[in_type, a_layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL], b: LayoutTensor[in_type, b_layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL], c: LayoutTensor[out_type, c_layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL])`
--- ## gpu (Gpu)
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Packages * [​`amd`](./amd/): Provides the AMD GPU backend implementations for matmuls. * [​`sm100`](./sm100/): Provides the Nvidia Blackwell backend implementations for matmuls. * [​`sm100_structured`](./sm100_structured/): SM100 Structured Kernels - Blackwell matmul implementation. * [​`sm80`](./sm80/): Provides the Nvidia Ampere backend implementations for matmuls. * [​`sm90`](./sm90/): Provides the Nvidia Hopper backend implementations for matmuls. ## Modules * [​`profiler`](./profiler/): * [​`tile_scheduler`](./tile_scheduler/): * [​`tile_scheduler_splitk`](./tile_scheduler_splitk/): ## Functions * [​`matmul_kernel`](./matmul_kernel): Matrix Multiplication using shared memory. This version loads blocks of size tile\_size x tile\_size from A and B and updates a tile\_size x tile\_size tile in C. The thread block should have shape (tile\_size, tile\_size, 1). Each thread is mapped to one element in C. The grid should have shape (N/tile\_size, M/tile\_size, 1). N is the first dimension for coalesced access. * [​`matmul_kernel_naive`](./matmul_kernel_naive): * [​`multistage_gemm`](./multistage_gemm): * [​`split_k_reduce`](./split_k_reduce):
--- ## matmul_kernel
`matmul_kernel[c_type: DType, a_type: DType, b_type: DType, tile_size: Int, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type]()](c_ptr: LegacyUnsafePointer[Scalar[c_type]], a_ptr: LegacyUnsafePointer[Scalar[a_type]], b_ptr: LegacyUnsafePointer[Scalar[b_type]], m: Int, n: Int, k: Int)` Matrix Multiplication using shared memory. This version loads blocks of size tile\_size x tile\_size from A and B and updates a tile\_size x tile\_size tile in C. The thread block should have shape (tile\_size, tile\_size, 1). Each thread is mapped to one element in C. The grid should have shape (N/tile\_size, M/tile\_size, 1). N is the first dimension for coalesced access.
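The blocking scheme this docstring describes can be sketched sequentially in Python/NumPy, assuming the dimensions divide evenly by `tile_size`; the two outer loops stand in for the grid, the inner loop for the shared-memory tile loads:

```python
import numpy as np

def matmul_tiled(a, b, tile_size):
    """Sequential sketch of the kernel's blocking: each (bi, bj) pair
    plays the role of one thread block computing one tile_size x
    tile_size tile of C; the inner loop walks the K dimension in
    tile_size-wide slices, as the shared-memory loads do."""
    m, k = a.shape
    _, n = b.shape
    assert m % tile_size == 0 and n % tile_size == 0 and k % tile_size == 0
    c = np.zeros((m, n))
    for bi in range(0, m, tile_size):            # grid dim over M
        for bj in range(0, n, tile_size):        # grid dim over N
            acc = np.zeros((tile_size, tile_size))
            for bk in range(0, k, tile_size):    # K-tile loop
                acc += (a[bi:bi+tile_size, bk:bk+tile_size]
                        @ b[bk:bk+tile_size, bj:bj+tile_size])
            c[bi:bi+tile_size, bj:bj+tile_size] = acc
    return c
```

On the GPU the per-block `acc` update is distributed across the (tile\_size, tile\_size) threads of the block, one C element each.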
--- ## matmul_kernel_naive
`matmul_kernel_naive[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, BLOCK_DIM: Int, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[c_type]()](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], m: Int, n: Int, k: Int)`
--- ## multistage_gemm
`multistage_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, origin, c_shape], a: NDBuffer[a_type, 2, origin, a_shape], b: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)` `multistage_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, origin, c_shape], a: NDBuffer[a_type, 2, origin, a_shape], b: NDBuffer[b_type, 2, origin, b_shape], runtime_config: MatmulConfig[a_type, b_type, c_type, transpose_b], ctx: DeviceContext)`
--- ## BlackwellProfileWarp
`struct BlackwellProfileWarp[load_warps: UInt32, mma_warps: UInt32, scheduler_warps: UInt32, epilogue_warps: UInt32, max_entries_per_warp: UInt32, //, WorkspaceManager: BlackwellWarpProfilingWorkspaceManager[load_warps, mma_warps, scheduler_warps, epilogue_warps, max_entries_per_warp], warp_role: UInt32 = 0]` This struct measures the execution time of one or more warps and writes a single entry to the workspace. ## Fields * ​timeline (`Tuple[UInt64, UInt64]`): * ​workspace (`Span[UInt64, MutAnyOrigin]`): * ​entry\_idx (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `enable_profiling` `comptime enable_profiling = (max_entries_per_warp > 0)` ## Methods ### `__init__` `__init__(out self, workspace: Span[UInt64, MutAnyOrigin], entry_idx: UInt32)` ### `__enter__` `__enter__(mut self)` ### `__exit__` `__exit__(mut self)`
--- ## BlackwellWarpProfilingWorkspaceManager
`@register_passable(trivial)` `struct BlackwellWarpProfilingWorkspaceManager[load_warps: UInt32, mma_warps: UInt32, scheduler_warps: UInt32, epilogue_warps: UInt32, max_entries_per_warp: UInt32]` This struct manages the profiling workspace. The workspace consists of equal-sized chunks, one per active SM. Each SM chunk consists of sequences of entries, with a maximum number of entries per warp role. Template Parameters: * ​load\_warps: Number of warps specialized for load operations. * ​mma\_warps: Number of warps specialized for matrix multiply-accumulate operations. * ​scheduler\_warps: Number of warps specialized for scheduling operations. * ​epilogue\_warps: Number of warps specialized for epilogue operations. * ​max\_entries\_per\_warp: Maximum number of entries per warp (common across all warp roles). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `entries_per_sm` `comptime entries_per_sm = max_entries_per_warp.__rmul__[DType.uint32, 1](BlackwellWarpProfilingWorkspaceManager[load_warps, mma_warps, scheduler_warps, epilogue_warps, max_entries_per_warp].total_warp_roles)` ### `header` `comptime header = "time_start,time_end,sm_id,block_idx_x,block_idx_y,role,entry_idx\n"` ### `sm_count` `comptime sm_count = B200.sm_count` ### `total_data_points` `comptime total_data_points = 7` ### `total_warp_roles` `comptime 
total_warp_roles = 4` ## Methods ### `get_workspace` `static get_workspace(ctx: DeviceContext) -> Span[UInt64, MutAnyOrigin]` **Returns:** [`Span`](/mojo/std/memory/span/Span) ### `write_to_workspace` `static write_to_workspace[warp_role: UInt32](sm_idx: UInt32, entry_idx: UInt32, workspace: Span[UInt64, MutAnyOrigin], timeline: Tuple[UInt64, UInt64])` ### `dump_workspace_as_csv` `static dump_workspace_as_csv(ctx: DeviceContext, workspace: Span[UInt64, MutAnyOrigin], filename: StringSlice[StaticConstantOrigin])`
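The layout implied by the members above (`entries_per_sm = total_warp_roles * max_entries_per_warp`, `total_data_points = 7` values per entry) suggests offset arithmetic like the following Python sketch; the exact chunk ordering is an assumption:

```python
def workspace_offset(sm_idx, warp_role, entry_idx,
                     total_warp_roles=4, max_entries_per_warp=8,
                     total_data_points=7):
    """Hypothetical indexing into the flat UInt64 workspace: one chunk
    per SM, one contiguous entry range per warp role within a chunk,
    total_data_points values per entry. Ordering is assumed."""
    entries_per_sm = total_warp_roles * max_entries_per_warp
    entry = (sm_idx * entries_per_sm
             + warp_role * max_entries_per_warp
             + entry_idx)
    return entry * total_data_points
```

Each entry then holds the seven CSV columns named in `header` (start/end timestamps, SM id, block indices, role, entry index).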
--- ## profiler (Profiler)
## `comptime` values ### `MatmulProfileWarp` `comptime MatmulProfileWarp[warp_role: UInt32, max_entries_per_warp: UInt32] = BlackwellProfileWarp[BlackwellWarpProfilingWorkspaceManager[1, 1, 1, 4, max_entries_per_warp](), warp_role]` #### Parameters * ​warp\_role ([`UInt32`](/std/builtin/simd/#uint32)): * ​max\_entries\_per\_warp ([`UInt32`](/std/builtin/simd/#uint32)): ### `MatmulWarpSpecializationWorkSpaceManager` `comptime MatmulWarpSpecializationWorkSpaceManager[max_entries_per_warp: UInt32] = BlackwellWarpProfilingWorkspaceManager[1, 1, 1, 4, max_entries_per_warp]` #### Parameters * ​max\_entries\_per\_warp ([`UInt32`](/std/builtin/simd/#uint32)): ## Structs * [​`BlackwellProfileWarp`](./BlackwellProfileWarp): This struct measures the execution time of one or more warps and writes a single entry to the workspace. * [​`BlackwellWarpProfilingWorkspaceManager`](./BlackwellWarpProfilingWorkspaceManager): This struct manages the profiling workspace. The workspace consists of equal-sized chunks, one per active SM. Each SM chunk consists of sequences of entries, with a maximum number of entries per warp role.
--- ## blockwise_fp8
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `smem_layout_3D` `comptime smem_layout_3D[layout: Layout] = Layout(IntTuple(IntTuple(1), layout.shape[0].owned_copy(), layout.shape[1].owned_copy(), Tuple[]()), IntTuple(IntTuple(0), layout.stride[0].owned_copy(), layout.stride[1].owned_copy(), Tuple[]()))` #### Parameters * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`matmul_sm100_blockwise_scaled_fp8`](./matmul_sm100_blockwise_scaled_fp8): * [​`matmul_sm100_blockwise_scaled_fp8_1d2d_kernel`](./matmul_sm100_blockwise_scaled_fp8_1d2d_kernel): * [​`matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper`](./matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper):
--- ## matmul_sm100_blockwise_scaled_fp8
`matmul_sm100_blockwise_scaled_fp8[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## matmul_sm100_blockwise_scaled_fp8_1d2d_kernel
`matmul_sm100_blockwise_scaled_fp8_1d2d_kernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, a_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, a_scales_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Scalar[DType.uint] = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], num_iters: Scalar[DType.uint])`
--- ## matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper
`matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, a_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, a_scales_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Scalar[DType.uint] = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], num_iters: Scalar[DType.uint])`
--- ## config (Config)
SM100 matmul configuration - re-exports from sm100\_structured. This module re-exports configuration types from sm100\_structured for backward compatibility. New code should import directly from sm100\_structured.config.
--- ## sm100 (Sm100)
Provides the Nvidia Blackwell backend implementations for matmuls. ## Modules * [​`blockwise_fp8`](./blockwise_fp8/): * [​`config`](./config/): SM100 matmul configuration - re-exports from sm100\_structured. * [​`matmul`](./matmul/): SM100 matmul - COMPATIBILITY LAYER for grouped\_matmul. * [​`pipeline`](./pipeline/): SM100 pipeline utilities - re-exports from sm100\_structured. * [​`tile_scheduler`](./tile_scheduler/): * [​`warp_specialized_blockwise_fp8`](./warp_specialized_blockwise_fp8/):
--- ## WarpRole (Matmul)
`@register_passable(trivial)` `struct WarpRole[has_scheduler: Bool = True]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Epilogue` `comptime Epilogue = WarpRole[has_scheduler](3)` ### `MainLoad` `comptime MainLoad = WarpRole[has_scheduler](5) if has_scheduler else WarpRole[has_scheduler](4)` ### `Mma` `comptime Mma = WarpRole[has_scheduler](6) if has_scheduler else WarpRole[has_scheduler](5)` ### `Scheduler` `comptime Scheduler = WarpRole[has_scheduler](4)` ## Methods ### `__eq__` `__eq__(self, other: Scalar[DType.uint]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ge__` `__ge__(self, other: Scalar[DType.uint]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_main_load` `static is_main_load() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_epilogue` `static is_epilogue() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_scheduler` `static is_scheduler() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
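The numeric role values above (Epilogue = 3, Scheduler = 4, MainLoad = 5 or 4, Mma = 6 or 5) suggest a warp-id-based classification like the following Python sketch. This mapping is inferred from the constants and is an assumption, not the library's documented contract:

```python
def classify_warp(warp_id, has_scheduler=True):
    """INFERRED mapping from warp id to role: warps 0-3 run the
    epilogue, then (if present) one scheduler warp, then one
    main-load warp, then the MMA warp."""
    if warp_id <= 3:
        return "Epilogue"
    if has_scheduler and warp_id == 4:
        return "Scheduler"
    if warp_id == (5 if has_scheduler else 4):
        return "MainLoad"
    return "Mma"
```

This would also explain the `__ge__(self, other: Scalar[DType.uint])` overload: role constants can be compared directly against a warp id.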
--- ## accum_arrive
`accum_arrive[cta_group: Int](mma_output_pipeline: ProducerConsumerPipeline[num_stages], mma_output_stage: UInt32)`
--- ## consumer_main_loop (Matmul)
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, a_smem_layout: Layout, b_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), k_group_size: Int = 1](tmem_addr: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)`
--- ## f32_frag_to_smem
`f32_frag_to_smem[swizzle_mode: TensorMapSwizzle, stageN: Scalar[DType.uint]](vec: SIMD[dtype, size], dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## matmul (3)
SM100 matmul compatibility layer for grouped\_matmul. NOTE: This module is maintained for backward compatibility with grouped\_matmul implementations that depend on internal functions (WarpRole, consumer\_main\_loop, stsm\_helper, shared\_memory\_epilogue, register\_epilogue, accum\_arrive). For new code, use sm100\_structured directly: * Import configs from: linalg.matmul.gpu.sm100\_structured.config * Import matmul from: linalg.matmul.gpu.sm100\_structured.matmul ## `comptime` values ### `RLayout32Bits` `comptime RLayout32Bits[layout: Layout] = RuntimeLayout[layout, element_type=DType.uint32, linear_idx_type=DType.uint32]` #### Parameters * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Structs * [​`WarpRole`](./WarpRole): ## Functions * [​`accum_arrive`](./accum_arrive): * [​`consumer_main_loop`](./consumer_main_loop): * [​`f32_frag_to_smem`](./f32_frag_to_smem): * [​`register_epilogue`](./register_epilogue): * [​`shared_memory_epilogue`](./shared_memory_epilogue): * [​`shared_memory_epilogue_transpose`](./shared_memory_epilogue_transpose): * [​`stsm_helper`](./stsm_helper):
--- ## register_epilogue
`register_epilogue[MMA_M: Int, data_paths: Int, num_stages: Int, bits: Int, stage: Int, stageN: Int, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: Int, epilogue_dtype: DType, frag_size: Int, repeats: Int, transpose_c: Bool, cta_group: Int, is_lower_frag_required: Bool](mut upper_frag_casted: SIMD[epilogue_dtype, frag_size], mut lower_frag_casted: SIMD[epilogue_dtype, frag_size], c_row: UInt32, c_col: UInt32, N: UInt32)`
--- ## shared_memory_epilogue
`shared_memory_epilogue[MMA_M: Scalar[DType.uint], data_paths: Scalar[DType.uint], num_stages: Scalar[DType.uint], stage: Scalar[DType.uint], stageN: Scalar[DType.uint], c_type: DType, shared_n: Scalar[DType.uint], simd_size: Scalar[DType.uint], c_smem_upper_layout: Layout, c_smem_lower_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: Int](M: UInt32, N: UInt32, c_col: Scalar[DType.uint], c_row: Scalar[DType.uint], c_smem_warp_tile_upper: LayoutTensor[c_type, c_smem_upper_layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_smem_warp_tile_lower: LayoutTensor[c_type, c_smem_lower_layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## shared_memory_epilogue_transpose
`shared_memory_epilogue_transpose[stage: Scalar[DType.uint], stageN: Scalar[DType.uint], c_type: DType, c_smem_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: Int, warp_dim: Int, MMA_M: Int, BN: Int, cta_group: Int](M: UInt32, N: UInt32, c_col: Scalar[DType.uint], c_row: Scalar[DType.uint], c_smem: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_i: Scalar[DType.uint], warp_j: Scalar[DType.uint])`
--- ## stsm_helper (Matmul)
`stsm_helper[swizzle: Swizzle, stageN: Scalar[DType.uint], transpose_c: Bool = False, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B](vec: SIMD[dtype, size], dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_offset: UInt32 = 0)`
--- ## pipeline (Pipeline)
SM100 pipeline utilities - re-exports from sm100\_structured. This module re-exports pipeline types from sm100\_structured for backward compatibility. New code should import directly from sm100\_structured.pipeline.
--- ## TileScheduler (Tile_scheduler)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8]` ## Fields * ​cluster\_dim (`StaticTuple[Int32, 3]`): * ​log\_cluster\_dim\_m (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_n (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_k (`FastDiv[DType.uint32]`): * ​clc\_response (`LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]`): * ​full\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​empty\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `cluster_size` `comptime cluster_size = ((cluster_shape.__getitem__[3, DType.uint32, Int](0) * cluster_shape.__getitem__[3, DType.uint32, Int](1)) * cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_k` `comptime log_cluster_k = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_m` `comptime log_cluster_m = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](0))` ### `log_cluster_n` `comptime log_cluster_n = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](1))` ## Methods ### `__init__` 
`__init__(cluster_dim: StaticTuple[Int32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `work_info_from_clc_response` `static work_info_from_clc_response(result: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `work_info_from_cluster` `static work_info_from_cluster(work_info: WorkInfo, cluster_dim: StaticTuple[Int32, 3], log_cluster_dim_m: FastDiv[DType.uint32], log_cluster_dim_n: FastDiv[DType.uint32]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState)
--- ## WorkInfo (Tile_scheduler)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## tile_scheduler
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo):
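The `TileScheduler` parameters (`rasterize_order`, `block_swizzle_size`) determine how a linear tile index is rasterized onto the (m, n) output grid. As a hedged sketch of one common along-M rasterization scheme with block swizzling (a generic illustration under assumed semantics, not the actual scheduler logic; it assumes the N grid dimension is a multiple of the swizzle width):

```python
def tile_coord_along_m(linear_idx: int, grid_m: int, grid_n: int,
                       swizzle: int) -> tuple:
    """Map a linear tile index to (m, n): tiles are grouped into panels
    of `swizzle` adjacent N-columns, and M advances fastest inside a
    panel so neighboring CTAs reuse the same B columns."""
    panel = swizzle * grid_m       # number of tiles in one column panel
    p, r = divmod(linear_idx, panel)
    m = r // swizzle               # walk along M fastest
    n = p * swizzle + r % swizzle  # column within the current panel
    return m, n
```

With `grid_m=4`, `grid_n=8`, `swizzle=2`, indices 0..31 visit every tile exactly once, two output columns at a time.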
--- ## blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel
`blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_tile_layout: Layout, a_scales_type: DType, b_scales_type: DType, b_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, a_scales_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], num_pipeline_stages: Int, cluster_shape: StaticTuple[Int32, 3]](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], cluster_dim: StaticTuple[Int32, 3], num_iters: Scalar[DType.uint], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], problem_shape: StaticTuple[Int32, 3])`
--- ## warp_specialized_blockwise_fp8
## Functions * [​`blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel`](./blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel): * [​`load_AB`](./load_AB): * [​`multi_stage_reg_epilogue`](./multi_stage_reg_epilogue): * [​`promote_accumulators`](./promote_accumulators): * [​`sm100_warp_specialized_blockwise_fp8`](./sm100_warp_specialized_blockwise_fp8):
--- ## load_AB (Warp_specialized_blockwise_fp8)
`load_AB[a_type: DType, b_type: DType, a_scales_type: DType, a_layout: Layout, b_layout: Layout, a_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, a_scales_smem_layout: Layout, num_pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], a_scales_smem: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[num_pipeline_stages], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: Scalar[DType.uint], elect_one_cta: Bool)`
--- ## multi_stage_reg_epilogue (Warp_specialized_blockwise_fp8)
`multi_stage_reg_epilogue[c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, accum_type: DType, accum_layout: Layout, /, *, c_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle](c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt], elect_one_warp: Bool)`
--- ## promote_accumulators (Warp_specialized_blockwise_fp8)
`promote_accumulators[pipeline_stages: Int, num_accum_pipeline_stages: Int, accum_type: DType, accum_layout: Layout, a_scales_type: DType, b_scales_type: DType, b_scales_layout: Layout, a_scales_smem_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, CLUSTER_SIZE: Int32, is_lower_frag_required: Bool, num_output_warps: Int](b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], a_scales_smem_iter: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_output_pipeline: ProducerConsumerPipeline[num_accum_pipeline_stages], tmem_addr: UInt32, load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], work_tile_coord: Tuple[UInt, UInt], elect_one_warp: Bool, stage_stride_cols: Scalar[DType.uint], k_iter: Scalar[DType.uint], problem_shape: StaticTuple[Int32, 3])`
--- ## sm100_warp_specialized_blockwise_fp8
`sm100_warp_specialized_blockwise_fp8[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, a_scales_layout: Layout, b_scales_layout: Layout, a_scales_type: DType, b_scales_type: DType, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## blackwell_block_scaled_matmul_tma_umma_warp_specialized (Block_scaled_matmul)
`blackwell_block_scaled_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, sfb_layout: Layout, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: Optional[UInt32] = None](c_tensor: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_tensor: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_tensor: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales_tensor: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], b_scales_tensor: LayoutTensor[sfb_dtype, sfb_layout, MutAnyOrigin], ctx: DeviceContext, alpha: Float32 = 1)` Launch block-scaled FP8 matmul kernel on SM100. Computes C = scale(A) @ scale(B) where A and B are FP8 matrices with per-block scaling factors following MXFP8 conventions. When config.AB\_swapped is True, internally swaps A and B operands (along with their scale factors) and transposes the output for better performance when M is small. **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Output element type. * ​c\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Output tensor layout. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): A matrix element type (FP8). 
* ​a\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A matrix layout. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): B matrix element type (FP8). * ​b\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): B matrix layout. * ​sfa\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): A scaling factor type (F8-UE8M0). * ​sfa\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A scaling factor layout. * ​sfb\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): B scaling factor type (F8-UE8M0). * ​sfb\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): B scaling factor layout. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether B is transposed (must be True). * ​config ([`BlockScaledMatmulConfig`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/config/BlockScaledMatmulConfig)): Block-scaled matmul configuration. * ​elementwise\_compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional epilogue lambda. * ​register\_based\_epilogue ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to use register-based epilogue. * ​pdl\_level ([`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel)): Programmatic dependent launch level. * ​max\_profiled\_tiles\_per\_SM ([`Optional`](/mojo/std/collections/optional/Optional)): Optional profiling tile count. **Args:** * ​c\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor. * ​a\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A matrix tensor. * ​b\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): B matrix tensor. * ​a\_scales\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A scaling factors. * ​b\_scales\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): B scaling factors. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for kernel launch. 
* ​alpha ([`Float32`](/mojo/std/builtin/simd/#float32)): Scalar scale factor. **Raises:** An error if configuration constraints are violated.
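The semantics above (C = scale(A) @ scale(B) with per-block scaling factors) can be written out numerically. The numpy sketch below is a reference for the math only, not the SM100 kernel: the `block_k` granularity and the dequantize-then-matmul formulation are illustrative assumptions (MXFP8 conventions use 32-element blocks with UE8M0 exponent scales), and `alpha` is applied here as a scalar multiplier on the output.

```python
import numpy as np

def block_scaled_matmul(a, b_t, a_scales, b_scales, alpha=1.0, block_k=32):
    """Reference math for C = alpha * scale(A) @ scale(B)^T.

    a: (M, K) values; b_t: (N, K) values (transpose_b layout).
    a_scales: (M, K // block_k) per-block scales along K; b_scales likewise.
    """
    # Broadcast each per-block scale across its block_k elements along K,
    # then dequantize and run an ordinary matmul.
    a_dq = a * np.repeat(a_scales, block_k, axis=1)    # (M, K)
    b_dq = b_t * np.repeat(b_scales, block_k, axis=1)  # (N, K)
    return alpha * (a_dq @ b_dq.T)                     # (M, N)
```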
--- ## block_scaled_matmul (Block_scaled_matmul)
CPU entry points for block-scaled SM100 matmul. Creates TMA descriptors for A, B, C and scaling factors (SFA, SFB), then launches the warp-specialized kernel. ## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`blackwell_block_scaled_matmul_tma_umma_warp_specialized`](./blackwell_block_scaled_matmul_tma_umma_warp_specialized): Launch block-scaled FP8 matmul kernel on SM100.
--- ## BlackwellBlockScaledMatmulKernel
`struct BlackwellBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]` Block-scaled matmul kernel V3 - ported from working legacy kernel. This struct provides the structured interface while internally using the proven legacy kernel logic. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_expected_bytes` `comptime a_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].a_smem_layout.size() * size_of[a_type]())` ### `a_internal_layout` `comptime a_internal_layout = Layout[Coord[ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8)], ComptimeInt[8]], 
Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[Coord[ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, 
c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]]](Coord[ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8)], ComptimeInt[8]](Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8)](), Idx[8]())), Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, 
b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]](Idx[(128 // size_of[a_type]())](), Idx[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, 
sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]](Idx[(128 // size_of[a_type]())](), Idx[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, Self.BK, config.a_swizzle]()`

### `a_tma_load_size`

`comptime a_tma_load_size = a_desc_layout.size()`

### `a_tma_rows`

`comptime a_tma_rows = a_desc_layout.shape[1].value()`

### `accum_pipeline_consumer_arv_count`

`comptime accum_pipeline_consumer_arv_count = (Self.cta_group * Self.EPILOGUE_THREADS)`

### `accum_pipeline_producer_arv_count`

`comptime accum_pipeline_producer_arv_count = 1`

### `accum_type`

`comptime accum_type = DType.float32`

### `b_expected_bytes`

`comptime b_expected_bytes = (Self.b_smem_layout.size() * size_of[b_type]())`

### `b_internal_layout`

`comptime b_internal_layout = Layout[Coord[ComptimeInt[(Self.BN // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BK * size_of[b_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BN // 8) * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[Coord[ComptimeInt[(Self.BN // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BK * size_of[b_type]()) // 128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(Self.BN // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BK * size_of[b_type]()) // 128)]]](Coord[ComptimeInt[(Self.BN // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(Self.BN // 8)], ComptimeInt[8]](Idx[(Self.BN // 8)](), Idx[8]())), Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BK * size_of[b_type]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BK * size_of[b_type]()) // 128)]](Idx[(128 // size_of[b_type]())](), Idx[((Self.BK * size_of[b_type]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BN // 8) * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BN // 8) * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BN // 8) * (128 // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((Self.BN // 8) * (128 // size_of[b_type]()))]](Idx[(128 // size_of[b_type]())](), Idx[((Self.BN // 8) * (128 // size_of[b_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))`

### `b_smem_layout`

`comptime b_smem_layout = tile_layout_k_major[b_type, Self.BN, Self.BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, Self.BN, Self.BK, config.b_swizzle]()`

### `b_tma_load_size`

`comptime b_tma_load_size = b_desc_layout.size()`

### `b_tma_rows`

`comptime b_tma_rows = b_desc_layout.shape[1].value()`

### `BK`

`comptime
BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)`

### `BM`

`comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)`

### `BN`

`comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)`

### `c_smem_layout`

`comptime c_smem_layout = Layout.row_major(Self.OutputM, Self.OutputN)`

### `clc_consumer_arv_count`

`comptime clc_consumer_arv_count = (Self.SCHEDULER_THREADS + (Self.CLUSTER_SIZE * ((Self.TMA_LOAD_THREADS + Self.MMA_THREADS) + Self.EPILOGUE_THREADS)))`

### `clc_producer_arv_count`

`comptime clc_producer_arv_count = 1`

### `clc_throttle_consumer_arv_count`

`comptime clc_throttle_consumer_arv_count = Self.SCHEDULER_THREADS`

### `clc_throttle_producer_arv_count`

`comptime clc_throttle_producer_arv_count = Self.TMA_LOAD_THREADS`

### `CLUSTER_M`

`comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)`

### `CLUSTER_N`

`comptime CLUSTER_N =
config.cluster_shape.__getitem__[3, DType.int64, Int](1)`

### `CLUSTER_SIZE`

`comptime CLUSTER_SIZE = (Self.CLUSTER_M * Self.CLUSTER_N)`

### `Context`

`comptime Context = KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N]`

### `cta_group`

`comptime cta_group = config.cta_group`

### `EPILOGUE_THREADS`

`comptime EPILOGUE_THREADS = (4 * WARP_SIZE)`

### `EpilogueCtx`

`comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group, Self.MMA_THREADS, Self.EPILOGUE_THREADS]`

### `input_expected_bytes`

`comptime input_expected_bytes =
((Self.cta_group * (((Self.a_expected_bytes + Self.b_expected_bytes) + Self.sfa_expected_bytes) + Self.sfb_expected_bytes)) * config)`

### `InputTilePipeline`

`comptime InputTilePipeline = InputTilePipeline[Self.TilePayload, Self.SmemType.num_group_pipeline_stages, config.k_group_size]`

### `MMA_K`

`comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)`

### `MMA_M`

`comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)`

### `MMA_N`

`comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)`

### `MMA_THREADS`

`comptime MMA_THREADS = WARP_SIZE`

### `MmaCtx`

`comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group, Self.MMA_THREADS, Self.EPILOGUE_THREADS]`

### `MmaEpilogueSync`

`comptime MmaEpilogueSync = WarpGroupBarrier[(Self.MMA_THREADS + Self.EPILOGUE_THREADS), 1]`

### `MmaOp`

`comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Self.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]`

### `num_accum_pipeline_stages`

`comptime num_accum_pipeline_stages =
config.num_accum_pipeline_stages`

### `num_clc_pipeline_stages`

`comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages`

### `num_group_pipeline_stages`

`comptime num_group_pipeline_stages = (Self.num_pipeline_stages // config)`

### `num_output_stages`

`comptime num_output_stages = config.num_output_stages`

### `num_output_warps`

`comptime num_output_warps = 4`

### `num_pipeline_stages`

`comptime num_pipeline_stages = config.num_pipeline_stages`

### `NUM_THREADS`

`comptime NUM_THREADS = (((Self.SCHEDULER_THREADS + Self.TMA_LOAD_THREADS) + Self.MMA_THREADS) + Self.EPILOGUE_THREADS)`

### `NUM_TMEM_COLS`

`comptime NUM_TMEM_COLS = 512`

### `OutputM`

`comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)`

### `OutputN`

`comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)`

### `OutputPipeline`

`comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group]`

### `Scheduler`

`comptime Scheduler = TileScheduler[Self.num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), config.raster_order, config.block_swizzle_size]`

### `SCHEDULER_THREADS`

`comptime SCHEDULER_THREADS = WARP_SIZE`

### `SF_K_GROUP_SIZE`

`comptime SF_K_GROUP_SIZE = (4 * config)`

### `sfa_expected_bytes`

`comptime sfa_expected_bytes = (Self.sfa_smem_layout.size() * size_of[sfa_dtype]())`

### `SFA_NUM_COLS`

`comptime SFA_NUM_COLS = (config * (Self.BM // 32))`

### `sfa_smem_layout`

`comptime sfa_smem_layout = tile_sf_layout_k_major[Self.BM, (Self.SF_K_GROUP_SIZE * config), config.vec_sf_size]()`

### `sfb_expected_bytes`

`comptime sfb_expected_bytes = (Self.sfb_smem_layout.size() * size_of[sfb_dtype]())`

### `SFB_NUM_COLS`

`comptime SFB_NUM_COLS = (config * (Self.MMA_N // 32))`

### `sfb_smem_layout`

`comptime sfb_smem_layout = tile_sf_layout_k_major[Self.MMA_N, (Self.SF_K_GROUP_SIZE * config), config.vec_sf_size]()`

### `SmemType`

`comptime SmemType = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]`

### `stage_stride_cols`

`comptime stage_stride_cols = Self.MMA_N`

### `TilePayload`

`comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, Self.SmemType.BM, Self.SmemType.BK, Self.SmemType.BN, Self.SmemType.BK, Self.SmemType.SFA_DIM0, Self.SmemType.SFA_DIM1, Self.SmemType.SFB_DIM0, Self.SmemType.SFB_DIM1, Self.SmemType.num_pipeline_stages]`

### `TileWriterType`

`comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, Self.cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, Self.SmemType.OutputM, Self.SmemType.OutputN, config.num_output_stages, Self.stage_stride_cols, 4, batched=True]`

### `TMA_LOAD_THREADS`

`comptime TMA_LOAD_THREADS = WARP_SIZE`

### `Tmem`

`comptime Tmem = TmemAllocation[Self.cta_group]`

### `TmemDealloc`

`comptime TmemDealloc = TmemDeallocBarrier[Self.cta_group]`

## Methods

### `load_input_tiles`

`static
load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], tiles: InputProducerStage[tiles_origin, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)` Load A, B, SFA, SFB tiles using TMA with InputProducerStage. This method uses the structured ProducerStage pattern from matmul\_kernels.mojo, with tiles and barrier encapsulated in the stage. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for A matrix. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for B matrix. * ​sfa\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for A scaling factors. * ​sfb\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for B scaling factors. 
* ​tiles ([`InputProducerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputProducerStage)): ProducerStage context with encapsulated tile access. * ​peer\_cta\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (rank\_n, rank\_m, peer\_m\_rank) for peer CTA slicing. * ​work\_tile\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (m, n, k\_start) coordinates of the work tile. * ​a\_multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for A tiles. * ​b\_multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for B tiles. * ​iter\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): K iteration index (base index for k\_group). * ​elect\_one\_cta ([`Bool`](/mojo/std/builtin/bool/Bool)): True if this CTA should call expect\_bytes. ### `mma` `static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, iter_idx: UInt32, k_start: UInt32)` Execute MMA operations using InputConsumerStage. This method uses the structured ConsumerStage pattern from matmul\_kernels.mojo, with tiles and barrier encapsulated in the stage. **Args:** * ​tiles ([`InputConsumerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputConsumerStage)): ConsumerStage context with encapsulated tile access. * ​mma\_op ([`MmaOpSM100_BlockScaled_SS`](/mojo/kernels/linalg/arch/sm100/mma/MmaOpSM100_BlockScaled_SS)): Block-scaled MMA operation instance. * ​tmem\_addr ([`UInt32`](/mojo/std/builtin/simd/#uint32)): TMEM address for accumulators. * ​sfa\_tmem ([`UInt32`](/mojo/std/builtin/simd/#uint32)): TMEM base address for A scaling factors. * ​sfb\_tmem ([`UInt32`](/mojo/std/builtin/simd/#uint32)): TMEM base address for B scaling factors. * ​iter\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): K iteration index. * ​k\_start ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Starting K iteration (for init\_c determination). 
### `epilogue` `static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], stage: OutputStage[config.num_accum_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32)` Execute epilogue to store accumulated results to global memory. Uses TileWriter which encapsulates: * TmemArrayType.load\_fragments() for TMEM load * AccumBarrier.arrive() for barrier signaling * TMEMToSMemWriter.write\_fragments() for SMEM write * 3D TMA store (M, N, Batch coordinates) * tma\_wait\_pipelined() for TMA wait Barrier synchronization (wait/step) is handled by caller via consumer() context. **Args:** * ​c\_tiles (`SMemTileArray2DRowMajor`): SMEM tile array for C output. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for C matrix. * ​stage (`OutputStage`): OutputStage from consumer() context with pipeline, index, and TMEM. 
* ​work\_tile\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (m, n, k\_start) coordinates. * ​M ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Problem M dimension. * ​N ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Problem N dimension. * ​alpha ([`Float32`](/mojo/std/builtin/simd/#float32)): Tensor scale factor (scalar). ### `validate_config` `static validate_config()` Validate configuration constraints at compile time. ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], alpha: Float32, cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])` Kernel entry point - ported from legacy kernel.
--- ## block_scaled_matmul_kernel
Block-scaled SM100 matmul kernel - Structured kernel using tile pipelines. Uses patterns from matmul\_kernels.mojo with typed SMEM accessors and context manager-based pipeline synchronization for MXFP8 and NVFP4 block-scaled matrix multiplication. Architecture: * Uses Self.SmemType (BlockScaledSmem) with typed tile/barrier accessors * Uses Self.InputTilePipeline (BlockScaledTilePipeline) for producer/consumer sync * Load warp: with input\_pipeline.producer() as stage -> Self.load\_input\_tiles() * MMA warp: with input\_pipeline.consumer() as stage -> Self.mma() * Epilogue warp: Uses structured building blocks from epilogue\_components.mojo Epilogue Building Blocks (from epilogue\_components.mojo): * TmemArrayType / load\_fragments() for TMEM load * AccumBarrier.arrive() for barrier signaling * TMEMToSMemWriter.write\_fragments() for SMEM write * tma\_wait\_pipelined() for TMA wait * TMA store remains inline (3D batch coordinates) Key structured patterns: * Context manager pattern for pipeline synchronization * ProducerStage/ConsumerStage encapsulate tiles and barriers * stage.get\_tiles(j) returns (a, b, sfa, sfb) tuple * Automatic wait/step in context manager `__enter__`/`__exit__` ## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`BlackwellBlockScaledMatmulKernel`](./BlackwellBlockScaledMatmulKernel): Block-scaled matmul kernel V3 - ported from working legacy kernel.
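The producer/consumer context-manager pattern above can be sketched outside of Mojo. The following is a minimal Python stand-in for the tile pipeline: the `producer()`/`consumer()` context managers wait on a stage barrier on entry and signal/step on exit, so the load warp and MMA warp hand tiles off in order without explicit barrier calls at each use site. All names here (`TilePipeline`, `put_tiles`, `get_tiles`) are illustrative, not the actual Mojo API.

```python
import threading
from contextlib import contextmanager

class ProducerStage:
    """Illustrative stand-in for InputProducerStage: writes one stage's tiles."""
    def __init__(self, pipe, idx):
        self.pipe, self.idx = pipe, idx
    def put_tiles(self, payload):
        self.pipe.tiles[self.idx] = payload

class ConsumerStage:
    """Illustrative stand-in for InputConsumerStage: reads one stage's tiles."""
    def __init__(self, pipe, idx):
        self.pipe, self.idx = pipe, idx
    def get_tiles(self):
        return self.pipe.tiles[self.idx]

class TilePipeline:
    """Illustrative stand-in for BlockScaledTilePipeline with N SMEM stages."""
    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.tiles = [None] * num_stages
        # "empty" barriers gate the producer; "full" barriers gate the consumer.
        self.empty = [threading.Semaphore(1) for _ in range(num_stages)]
        self.full = [threading.Semaphore(0) for _ in range(num_stages)]
        self.p_idx = self.c_idx = 0

    @contextmanager
    def producer(self):
        i = self.p_idx % self.num_stages
        self.empty[i].acquire()      # wait in __enter__: stage must be free
        try:
            yield ProducerStage(self, i)
        finally:
            self.full[i].release()   # step in __exit__: signal the consumer
            self.p_idx += 1

    @contextmanager
    def consumer(self):
        i = self.c_idx % self.num_stages
        self.full[i].acquire()       # wait in __enter__: tiles must be loaded
        try:
            yield ConsumerStage(self, i)
        finally:
            self.empty[i].release()  # step in __exit__: stage is reusable
            self.c_idx += 1

# Load "warp" produces tile payloads; MMA "warp" consumes them in order.
pipe = TilePipeline(num_stages=2)
consumed = []

def load_warp():
    for k in range(4):
        with pipe.producer() as stage:
            stage.put_tiles(("a", "b", "sfa", "sfb", k))

loader = threading.Thread(target=load_warp)
loader.start()
for k in range(4):
    with pipe.consumer() as stage:
        consumed.append(stage.get_tiles())
loader.join()
```

The point of the pattern is that wait/signal/step live entirely in `__enter__`/`__exit__`, so a missing barrier arrive cannot be forgotten at a call site.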
--- ## BlockScaledSmem
`struct BlockScaledSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]` SMEM struct containing A/B tiles, scaling factors, C output, and barriers. ## Fields * ​tiles (`BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles`): * ​pipelines (`BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Pipelines`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.a_smem_layout` ### `ATileArray` `comptime ATileArray = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.ATileArray` ### `b_smem_layout` `comptime b_smem_layout = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.b_smem_layout` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTileArray` `comptime BTileArray = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.BTileArray` ### `c_smem_layout` `comptime c_smem_layout = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.c_smem_layout` ### `CTileArray` `comptime CTileArray = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.CTileArray` ### `Layouts` `comptime Layouts = SmemLayouts[a_type, b_type, 
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `Pipelines` `comptime Pipelines = SmemPipelineBundle[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config=config].BM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages]]` ### `SF_BK` `comptime SF_BK = sf_bk[config]()` ### `SF_K_GROUP_SIZE` `comptime SF_K_GROUP_SIZE = sf_k_group_size[config]()` ### `SFA_DIM0` `comptime SFA_DIM0 = sfa_dim0[config]()` ### `SFA_DIM1` `comptime SFA_DIM1 = sfa_dim1[config]()` ### `sfa_smem_layout` `comptime sfa_smem_layout = tile_sf_layout_k_major[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SFATileArray` `comptime SFATileArray = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFATileArray` ### `SFB_DIM0` `comptime SFB_DIM0 = sfb_dim0[config]()` ### `SFB_DIM1` `comptime SFB_DIM1 = sfb_dim1[config]()` ### `sfb_smem_layout` `comptime sfb_smem_layout = tile_sf_layout_k_major[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SFBTileArray` `comptime SFBTileArray = 
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFBTileArray` ### `Tiles` `comptime Tiles = BlockScaledTileStorage[a_type, b_type, c_type, sfa_dtype, sfb_dtype, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray` Get A tile array accessor (TileTensor-based). **Returns:** `BlockScaledSmem` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray` Get B tile array accessor (TileTensor-based). 
**Returns:** `BlockScaledSmem` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray` Get C tile array accessor (TileTensor-based). **Returns:** `BlockScaledSmem` ### `sfa_tiles` `sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray` Get SFA tile array accessor (TileTensor-based). **Returns:** `BlockScaledSmem` ### `sfb_tiles` `sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray` Get SFB tile array accessor (TileTensor-based). **Returns:** `BlockScaledSmem` ### `ab_pipeline_size` `static ab_pipeline_size() -> Int` Total size of A+B tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `sf_pipeline_size` `static sf_pipeline_size() -> Int` Total size of SFA+SFB tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `c_output_size` `static c_output_size() -> Int` Size of C tiles for all output stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Total tile storage size (A+B+SFA+SFB+C) in elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
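The size accessors above (`ab_pipeline_size`, `sf_pipeline_size`, `c_output_size`, `total_tile_size`) are documented as element counts over all pipeline/output stages. A hedged sketch of that bookkeeping, using the tile shapes implied by the comptime members (A is BM x BK, B is BN x BK, C is OutputM x OutputN); the Python function signatures are illustrative, not the Mojo API:

```python
def ab_pipeline_size(BM, BN, BK, num_pipeline_stages):
    # One A tile (BM x BK) plus one B tile (BN x BK) per pipeline stage.
    return num_pipeline_stages * (BM * BK + BN * BK)

def sf_pipeline_size(SFA_DIM0, SFA_DIM1, SFB_DIM0, SFB_DIM1, num_pipeline_stages):
    # One SFA tile plus one SFB tile per pipeline stage.
    return num_pipeline_stages * (SFA_DIM0 * SFA_DIM1 + SFB_DIM0 * SFB_DIM1)

def c_output_size(OutputM, OutputN, num_output_stages):
    # One C tile (OutputM x OutputN) per output stage.
    return num_output_stages * OutputM * OutputN

def total_tile_size(BM, BN, BK, SFA_DIM0, SFA_DIM1, SFB_DIM0, SFB_DIM1,
                    OutputM, OutputN, num_pipeline_stages, num_output_stages):
    # Total SMEM tile storage (A + B + SFA + SFB + C), in elements.
    return (ab_pipeline_size(BM, BN, BK, num_pipeline_stages)
            + sf_pipeline_size(SFA_DIM0, SFA_DIM1, SFB_DIM0, SFB_DIM1,
                               num_pipeline_stages)
            + c_output_size(OutputM, OutputN, num_output_stages))
```

Note these are element counts, not bytes: A, B, C, and the scale factors can all have different dtypes, so byte totals are computed per array by multiplying with the respective `size_of[...]()`.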
--- ## block_scaled_smem
Shared memory layout for block-scaled SM100 matmul. Extends standard SMEM with scaling factor tile storage (SFA, SFB) following MXFP8 layout conventions. Also includes all pipeline barriers and TMEM state. ## Structs * [​`BlockScaledSmem`](./BlockScaledSmem): SMEM struct containing A/B tiles, scaling factors, C output, and barriers. ## Functions * [​`sf_bk`](./sf_bk): Compute SF\_BK from config. * [​`sf_k_group_size`](./sf_k_group_size): Compute SF\_K\_GROUP\_SIZE from config. * [​`sfa_dim0`](./sfa_dim0): Compute SFA first dimension from config. * [​`sfa_dim1`](./sfa_dim1): Compute SFA second dimension from config. * [​`sfb_dim0`](./sfb_dim0): Compute SFB first dimension from config. * [​`sfb_dim1`](./sfb_dim1): Compute SFB second dimension from config.
--- ## sf_bk
`sf_bk[config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int` Compute SF\_BK from config. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## sf_k_group_size
`sf_k_group_size[config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int` Compute SF\_K\_GROUP\_SIZE from config. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## sfa_dim0
`sfa_dim0[config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int` Compute SFA first dimension from config. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## sfa_dim1
`sfa_dim1[config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int` Compute SFA second dimension from config. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## sfb_dim0
`sfb_dim0[config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int` Compute SFB first dimension from config. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## sfb_dim1
`sfb_dim1[config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int` Compute SFB second dimension from config. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## block_scaled
Block-scaled matmul kernel for SM100. ## Modules * [​`block_scaled_matmul`](./block_scaled_matmul/): CPU entry points for block-scaled SM100 matmul. * [​`block_scaled_matmul_kernel`](./block_scaled_matmul_kernel/): Block-scaled SM100 matmul kernel - Structured kernel using tile pipelines. * [​`block_scaled_smem`](./block_scaled_smem/): Shared memory layout for block-scaled SM100 matmul.
--- ## BlockwiseFP8Accumulator
`struct BlockwiseFP8Accumulator[accum_type: DType, accum_layout: Layout, is_lower_required: Bool, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_size: Int]` Register-based accumulator for blockwise FP8 matmul. Manages upper and lower fragment tiles in registers for per-K accumulation. Unlike TMEM-based accumulation, this allows scaling in CUDA cores. ## Parameters * ​accum\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Accumulator data type (typically float32). * ​accum\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): 2D layout (num\_stages, num\_elements) for register tiles. * ​is\_lower\_required ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether lower fragment is needed (based on cta\_group/MMA\_M). * ​block\_tile\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): Block tile dimensions (BM, BN, BK). * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): MMA operation dimensions (MMA\_M, MMA\_N, MMA\_K). * ​cluster\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of CTAs in the cluster. 
## Fields * ​upper (`BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].UpperTile`): * ​lower (`BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].LowerTile`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `bits` `comptime bits = 256` ### `BK` `comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `data_paths` `comptime data_paths = 16` ### `fragment_size` `comptime fragment_size = (128 // WARP_SIZE)` ### `Fragments` `comptime Fragments = TmemFragments[accum_type, BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size, is_lower_required=is_lower_required]` ### `LowerTile` `comptime LowerTile = LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_elements` `comptime num_elements = accum_layout.shape[1].value()` ### `num_elements_per_load` `comptime num_elements_per_load = 8` ### `num_stages` `comptime num_stages = accum_layout.shape[0].value()` ### `rep_frag_size` `comptime rep_frag_size = (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].repeats * BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size)` ### `repeats` `comptime repeats = (BlockwiseFP8Accumulator[accum_type, accum_layout, 
is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements // BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size)` ### `stageN` `comptime stageN = (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].repeats * 8)` ### `UpperTile` `comptime UpperTile = LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ## Methods ### `__init__` `__init__(out self)` Create accumulator with zero-initialized register tiles. ### `promote` `promote[num_pipeline_stages: Int, num_accum_pipeline_stages: Int, stage_stride_cols: Int, cta_group: Int, num_input_stages: Int, b_scales_dtype: DType, b_scales_layout: Layout, a_scales_dtype: DType, a_scales_dim0: Int, a_scales_dim1: Int](mut self, b_scales: LayoutTensor[b_scales_dtype, b_scales_layout, MutAnyOrigin], a_scales_tiles: SMemTileArray2DRowMajor[a_scales_dtype, a_scales_dim0, a_scales_dim1, num_pipeline_stages], epi_stage: EpilogueKStage[num_accum_pipeline_stages, stage_stride_cols, cta_group, num_input_stages], work_tile_coord: Tuple[UInt, UInt], k_iter: Scalar[DType.uint], problem_shape: StaticTuple[Int32, 3])` Load partial from TMEM, apply scales, accumulate into registers. Core blockwise FP8 scaling: loads MMA partial from TMEM, reads A-scale from SMEM and B-scale from global memory, applies scaling, and accumulates into register tiles. Called within `with epi_ctx.per_k_stage(input_pipeline) as epi_stage:`.
--- ## get_accumulator_layout
`get_accumulator_layout[*, c_smem_dim1: Int, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int]() -> Layout` Compute the register accumulator layout for blockwise FP8. Returns a 2D layout (num\_stages, num\_elements) for the register tiles. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## blockwise_fp8_accumulator
Register-based accumulator for blockwise FP8 matmul. Unlike standard SM100 matmul which accumulates directly in TMEM, blockwise FP8 requires per-K-iteration scaling in CUDA cores: ``` for k in K_iterations: partial = TMEM load (MMA result) scaled = partial * a_scale * b_scale accum += scaled # in registers result = accum # write to SMEM → GMEM ``` ## Structs * [​`BlockwiseFP8Accumulator`](./BlockwiseFP8Accumulator): Register-based accumulator for blockwise FP8 matmul. ## Functions * [​`get_accumulator_layout`](./get_accumulator_layout): Compute the register accumulator layout for blockwise FP8. * [​`is_lower_fragment_required`](./is_lower_fragment_required): Determine if lower TMEM fragment is needed based on config.
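The per-K-iteration loop above can be written as a small NumPy reference, using the scale shapes documented for `blockwise_fp8_matmul` (a\_scales: M x ceil(K/128), b\_scales: ceil(N/128) x ceil(K/128)). This is a numerical sketch only: the function name and the plain-float accumulator are illustrative, FP8 quantization itself is omitted, and a 128-wide K block stands in for one MMA K iteration.

```python
import numpy as np

def blockwise_scaled_matmul_ref(a, b, a_scales, b_scales, block=128):
    # a: (M, K) FP8 values dequantized to float; b: (N, K) (transpose_b layout).
    M, K = a.shape
    N = b.shape[0]
    accum = np.zeros((M, N))
    for ki, k0 in enumerate(range(0, K, block)):
        # "partial = TMEM load (MMA result)": unscaled MMA partial for this K block.
        partial = a[:, k0:k0 + block] @ b[:, k0:k0 + block].T
        a_s = a_scales[:, ki][:, None]                      # one scale per row of A
        b_s = b_scales[np.arange(N) // block, ki][None, :]  # one scale per N-block of B
        accum += partial * a_s * b_s                        # "accum += scaled" in registers
    return accum
```

With all scales equal to 1 this reduces to `a @ b.T`; the kernel performs the same scaling per K iteration in CUDA cores because a single TMEM accumulator cannot apply a different scale to each K block after the fact.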
--- ## is_lower_fragment_required
`is_lower_fragment_required[cta_group: Int, block_tile_shape: IndexList[3]]() -> Bool` Determine if lower TMEM fragment is needed based on config. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## blockwise_fp8_matmul
`blockwise_fp8_matmul[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, a_scales_layout: Layout, b_scales_layout: Layout, a_scales_type: DType, b_scales_type: DType, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`

Launch blockwise FP8 matmul kernel.

Environment variables:

* `USE_LEGACY_BLOCKWISE_FP8`: If True, use the legacy kernel instead of the structured one.

**Args:**

* ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix (M x N).
* ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix A (M x K), FP8.
* ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix B (K x N, or N x K if transposed), FP8.
* ​a\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scaling factors for A (M x ceil(K/128)), FP32.
* ​b\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scaling factors for B (ceil(N/128) x ceil(K/128)), FP32.
* ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for kernel launch.
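The scale shapes above imply a 128-element blockwise quantization granularity: each row of A carries one FP32 scale per 128-wide K block, and B carries one FP32 scale per 128 x 128 tile. The following NumPy sketch shows the intended dequantize-and-accumulate semantics for the `transpose_b=True` layout; it is a reference for the math only, not the kernel's implementation, and float32 stands in for FP8 storage:

```python
import numpy as np

BLOCK = 128  # scaling granularity, per the ceil(K/128) shapes above

def blockwise_fp8_matmul_ref(a_q, b_q, a_scales, b_scales):
    """Reference semantics for transpose_b=True (b_q is N x K).

    a_q:      (M, K) quantized A values (float32 stands in for FP8)
    b_q:      (N, K) quantized B values
    a_scales: (M, ceil(K/BLOCK)) per-row, per-K-block scales, FP32
    b_scales: (ceil(N/BLOCK), ceil(K/BLOCK)) per-tile scales, FP32
    """
    M, K = a_q.shape
    N = b_q.shape[0]
    c = np.zeros((M, N), dtype=np.float32)
    for kb in range((K + BLOCK - 1) // BLOCK):          # loop over K blocks
        k0, k1 = kb * BLOCK, min((kb + 1) * BLOCK, K)
        a_blk = a_q[:, k0:k1] * a_scales[:, kb:kb + 1]  # dequantize A's K block
        for nb in range((N + BLOCK - 1) // BLOCK):      # loop over B's N blocks
            n0, n1 = nb * BLOCK, min((nb + 1) * BLOCK, N)
            b_blk = b_q[n0:n1, k0:k1] * b_scales[nb, kb]  # dequantize B's tile
            c[:, n0:n1] += a_blk @ b_blk.T              # accumulate partial product
    return c
```

This is equivalent to dequantizing both matrices up front and doing one full-precision matmul; the kernel instead applies the scales per K iteration so that FP8 tensor-core partials are corrected before they are accumulated.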
--- ## blockwise_fp8_matmul (Blockwise_fp8_matmul)
CPU entry points for blockwise FP8 SM100 matmul. Creates TMA descriptors for A, B, C, and A-scales, then launches the warp-specialized blockwise FP8 kernel with register-based accumulation.

## Functions

* [​`blockwise_fp8_matmul`](./blockwise_fp8_matmul): Launch blockwise FP8 matmul kernel.
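The register-based accumulation mentioned above matters because the scales differ between K groups: each group's unscaled partial must be corrected before it is folded into the running sum. A minimal sketch of that ordering (plain Python, illustrative only; `scaled_accumulate` and its arguments are hypothetical names, not kernel APIs):

```python
def scaled_accumulate(partials, a_scales, b_scales):
    """Combine per-K-group matmul partials, scaling each before accumulating.

    partials[i]: unscaled dot-product partial for K group i
    a_scales[i], b_scales[i]: the FP32 scales that apply to that K group
    """
    acc = 0.0  # full-precision register accumulator
    for p, sa, sb in zip(partials, a_scales, b_scales):
        acc += p * sa * sb  # scale this group's partial, then accumulate
    return acc

# Scaling once at the end would be wrong when scales differ across K groups:
# (p_0 + p_1) * s_0 != p_0 * s_0 + p_1 * s_1 in general.
```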
--- ## BlackwellBlockwiseFP8MatmulKernel
`struct BlackwellBlockwiseFP8MatmulKernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, a_scales_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1)]`

Blockwise FP8 matmul kernel with register-based accumulation. This kernel implements per-K-iteration scaling in CUDA cores:

1. Load warp: TMA loads A, B, A-scales to SMEM
2. MMA warp: Standard MMA (partial to TMEM)
3. Epilogue warp: TMEM read → scale → register accumulate → output

## Implemented traits

[`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible)

## `comptime` members

### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_expected_bytes` `comptime a_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].a_smem_layout.size() * size_of[a_type]())` ### `a_scales_expected_bytes` `comptime a_scales_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].a_scales_smem_layout.size() * size_of[a_scales_type]())` ### `a_scales_smem_layout` `comptime a_scales_smem_layout = Layout.row_major(1, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout,
a_scales_desc_layout, transpose_b, config, cluster_shape].BM)` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].BK, config.a_swizzle]()` ### `a_tma_load_size` `comptime a_tma_load_size = a_desc_layout.size()` ### `a_tma_rows` `comptime a_tma_rows = a_desc_layout.shape[0].value()` ### `accum_layout` `comptime accum_layout = get_accumulator_layout[c_smem_dim1=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].OutputN, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]()` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, 
a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)` ### `accum_pipeline_producer_arv_count` `comptime accum_pipeline_producer_arv_count = 1` ### `accum_type` `comptime accum_type = DType.float32` ### `AccumTensor` `comptime AccumTensor = TmemTensor[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].tmem_accum_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ### `Accumulator` `comptime Accumulator = BlockwiseFP8Accumulator[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].accum_layout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].is_lower_required, config.block_tile_shape, config.mma_shape, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE]` ### `AScalesLoaderType` `comptime AScalesLoaderType = ScalesTileLoader[?, ?, ?, ?, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, 
c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ### `AScalesTileLT` `comptime AScalesTileLT = LayoutTensor[a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.a_scales_smem_layout, ?, address_space=AddressSpace.SHARED, alignment=128]` ### `ATileLoaderType` `comptime ATileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ### `ATileLT` `comptime ATileLT = LayoutTensor[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.a_smem_layout, ?, address_space=AddressSpace.SHARED, alignment=128]` ### `b_expected_bytes` `comptime b_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].b_smem_layout.size() * size_of[b_type]())` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
a_scales_desc_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].BK, config.b_swizzle]()` ### `b_tma_load_size` `comptime b_tma_load_size = b_desc_layout.size()` ### `b_tma_rows` `comptime b_tma_rows = b_desc_layout.shape[0].value()` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTileLoaderType` `comptime BTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ### `BTileLT` `comptime BTileLT = LayoutTensor[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, 
cluster_shape].SmemType.b_smem_layout, ?, address_space=AddressSpace.SHARED, alignment=128]` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].OutputN)` ### `clc_consumer_arv_count` `comptime clc_consumer_arv_count = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS + (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE * ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, 
a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)))` ### `clc_producer_arv_count` `comptime clc_producer_arv_count = 1` ### `clc_throttle_consumer_arv_count` `comptime clc_throttle_consumer_arv_count = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS` ### `clc_throttle_producer_arv_count` `comptime clc_throttle_producer_arv_count = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS` ### `CLUSTER_M` `comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)` ### `CLUSTER_N` `comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_SIZE` `comptime CLUSTER_SIZE = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].CLUSTER_M * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].CLUSTER_N)` ### `Context` `comptime Context = KernelContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, 
cluster_shape].num_clc_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].CLUSTER_N]` ### `cta_group` `comptime cta_group = config.cta_group` ### `EPILOGUE_THREADS` `comptime EPILOGUE_THREADS = (4 * WARP_SIZE)` ### `EpilogueCtx` `comptime EpilogueCtx = EpilogueWarpContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, 
a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]` ### `EpilogueHandle` `comptime EpilogueHandle = EpilogueWarp[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]` ### `input_expected_bytes` `comptime input_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, 
a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group * ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].a_expected_bytes + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].b_expected_bytes) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].a_scales_expected_bytes))` ### `InputTilePipeline` `comptime InputTilePipeline = InputTilePipeline[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size]` ### `is_lower_required` `comptime is_lower_required = is_lower_fragment_required[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, 
config.block_tile_shape]()` ### `max_tmem_cols` `comptime max_tmem_cols = 512` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MMA_THREADS` `comptime MMA_THREADS = WARP_SIZE` ### `MmaCtx` `comptime MmaCtx = MmaWarpContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]` ### `MmaHandle` `comptime MmaHandle = MmaWarp[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, 
a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]` ### `MmaOp` `comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = 
(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_output_warps` `comptime num_output_warps = 4` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `NUM_THREADS` `comptime NUM_THREADS = (((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)` ### `NUM_TMEM_COLS` `comptime NUM_TMEM_COLS = 512` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputPipeline` `comptime OutputPipeline = OutputTilePipeline[BlackwellBlockwiseFP8MatmulKernel[a_type, 
b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ### `Scheduler` `comptime Scheduler = TileScheduler[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), config.raster_order, config.block_swizzle_size]` ### `SCHEDULER_THREADS` `comptime SCHEDULER_THREADS = WARP_SIZE` ### `SmemType` `comptime SmemType = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config]` ### `stage_stride_cols` `comptime stage_stride_cols = (512 // config)` ### `TilePayload` `comptime TilePayload = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.BM, 
BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.BK, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.BK, 1, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.num_pipeline_stages]` ### `TileWriterType` `comptime TileWriterType = BlockwiseFP8TileWriter[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].OutputN, DType.float32, 
BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].accum_layout, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, is_lower_frag_required=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].is_lower_required, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, num_output_stages=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_output_stages, num_output_warps=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].num_output_warps, c_swizzle=config.c_swizzle]` ### `TMA_LOAD_THREADS` `comptime TMA_LOAD_THREADS = WARP_SIZE` ### `Tmem` `comptime Tmem = TmemAllocation[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ### `tmem_accum_layout` `comptime tmem_accum_layout = 
Layout.row_major(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].MMA_N)` ### `TmemDealloc` `comptime TmemDealloc = TmemDeallocBarrier[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group]` ## Methods ### `load_input_tiles` `static load_input_tiles[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, a_scales_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoaderTMA[a_tma_origin, a_type, a_layout, a_desc_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group], b_loader: TileLoaderTMA[b_tma_origin, b_type, b_layout, b_desc_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group], a_scales_loader: ScalesTileLoader[a_scales_tma_origin, a_scales_type, a_scales_layout, a_scales_desc_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, 
b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group], tiles: InputProducerStage[tiles_origin, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], iter_idx: Scalar[DType.uint], elect_one_cta: Bool)` Load A, B, and A-scales tiles using TMA. **Args:** * ​a\_loader ([`TileLoaderTMA`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_loader/TileLoaderTMA)): TileLoaderTMA for A matrix. * ​b\_loader ([`TileLoaderTMA`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_loader/TileLoaderTMA)): TileLoaderTMA for B matrix. * ​a\_scales\_loader ([`ScalesTileLoader`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_loader/ScalesTileLoader)): ScalesTileLoader for A-scales. * ​tiles ([`InputProducerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputProducerStage)): InputProducerStage context with encapsulated tile access. * ​peer\_cta\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Peer CTA coordinates for multicast. * ​work\_tile\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Current work tile M/N coordinates. * ​iter\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K iteration index. 
* ​elect\_one\_cta ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether this is the elected CTA in the cluster. ### `mma` `static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], accum_tensor: TmemTensor[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].tmem_accum_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, c_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, c_desc_layout, a_scales_desc_layout, transpose_b, config, cluster_shape].cta_group])` Execute standard MMA operations (partial results to TMEM). For blockwise FP8, each K iteration writes a fresh partial to TMEM. 
The epilogue accumulates across K in registers, not TMEM. Therefore init\_c is always True (unlike standard matmul). **Args:** * ​tiles ([`InputConsumerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputConsumerStage)): Input consumer stage with A, B, A-scales tiles. * ​mma\_op ([`MmaOpSM100_SS`](/mojo/kernels/linalg/arch/sm100/mma/MmaOpSM100_SS)): The MMA operator. * ​accum\_tensor ([`TmemTensor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemTensor)): Typed TMEM tensor view for the accumulator stage. ### `validate_config` `static validate_config()` Validate configuration constraints at compile time. ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], cluster_dim: StaticTuple[Int32, 3], num_iters: Scalar[DType.uint], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], problem_shape: StaticTuple[Int32, 3])` Kernel entry point for blockwise FP8 matmul.
--- ## blockwise_fp8_matmul_kernel
Blockwise FP8 SM100 matmul kernel - Structured kernel with register accumulation. Unlike standard SM100 matmul which accumulates in TMEM, blockwise FP8 applies scaling factors per-K-iteration in CUDA cores, accumulating in registers. Architecture: * Load warp: TMA loads A, B, and A-scales into SMEM * MMA warp: Standard MMA operations (partial results to TMEM) * Epilogue warp: Per-K TMEM read → scale → register accumulate → final output Key differences from standard/block-scaled kernels: * Uses MmaOpSM100\_SS (not block-scaled MMA) * A-scales loaded via TMA, B-scales from global memory * BlockwiseFP8Accumulator for register-based K-loop accumulation * BlockwiseFP8TileWriter for final register → SMEM → GMEM flow ## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`BlackwellBlockwiseFP8MatmulKernel`](./BlackwellBlockwiseFP8MatmulKernel): Blockwise FP8 matmul kernel with register-based accumulation.
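The register-accumulation scheme above can be modeled with a small NumPy reference. This is a sketch of the math only, not the kernel's code; the 128-element scale granularity, the `transpose_b` layout, and the scale-tensor shapes are assumptions taken from the surrounding docs:

```python
import numpy as np

def blockwise_fp8_matmul_ref(a, b, a_scales, b_scales, blk=128):
    """Reference for per-K-iteration scaling: each K block's partial product
    (the "TMEM partial") is scaled in fp32 and accumulated in a register-like
    fp32 buffer, matching the init_c=True-per-iteration behavior."""
    m, k = a.shape
    n, _ = b.shape                          # transpose_b layout: C = A @ B.T
    c = np.zeros((m, n), dtype=np.float32)  # "register" accumulator
    col_block = np.arange(n) // blk         # which N block each column is in
    for ki in range(k // blk):
        ks = slice(ki * blk, (ki + 1) * blk)
        partial = a[:, ks].astype(np.float32) @ b[:, ks].astype(np.float32).T
        # a_scales: (K//blk, M) per-token; b_scales: (N//blk, K//blk) per-block
        scale = a_scales[ki][:, None] * b_scales[col_block, ki][None, :]
        c += partial * scale                # scaling applied in "CUDA cores"
    return c
```

Note that the scale is applied to each fresh partial before accumulation, which is why the kernel never accumulates across K in TMEM.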
--- ## BlockwiseFP8TileWriter
`struct BlockwiseFP8TileWriter[c_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, accum_type: DType, accum_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_stages: Int, num_output_warps: Scalar[DType.uint], c_swizzle: TensorMapSwizzle]` Write register accumulators to GMEM via SMEM and TMA. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `bits` `comptime bits = 256` ### `BM` `comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)` ### `CTileArray` `comptime CTileArray = SMemTileArray[c_type, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].c_smem_layout, num_output_stages, 128]` ### `CTileArrayTT` `comptime CTileArrayTT = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]` ### `data_paths` `comptime data_paths = 16` ### `fragment_size` `comptime fragment_size = (128 // WARP_SIZE)` ### `fragments_per_stage` `comptime fragments_per_stage = (BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size * BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, 
block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats)` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_elements` `comptime num_elements = accum_layout.shape[1].value()` ### `num_elements_per_load` `comptime num_elements_per_load = 8` ### `num_stages` `comptime num_stages = accum_layout.shape[0].value()` ### `repeats` `comptime repeats = (BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].num_elements // BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size)` ### `SMEMWriter` `comptime SMEMWriter = TMEMToSMemWriter[c_type, accum_type, c_smem_dim0, c_smem_dim1, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].BM, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].BN, 
BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].MMA_M, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].MMA_N, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].stageN, cta_group, Int.__init__[Scalar[DType.uint]](num_output_warps), c_swizzle]` ### `stageN` `comptime stageN = (BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats * 8)` ## Methods ### `write` `static write[c_layout: Layout, c_desc_layout: Layout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt])` Write accumulated register tiles to GMEM via double-buffered SMEM. 
### `write_absolute_with_bounds_check` `static write_absolute_with_bounds_check[c_tensor_layout: Layout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin])` Write accumulated register tiles to GMEM with bounds checking. For 1D-1D grouped kernels where M coordinate is absolute in contiguous token space. Applies expert\_scale to fragments before store. Handles partial tiles that cross expert boundaries by using element-by-element stores for rows that would exceed m\_end. **Args:** * ​accum ([`BlockwiseFP8Accumulator`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/blockwise_fp8/blockwise_fp8_accumulator/BlockwiseFP8Accumulator)): Blockwise FP8 accumulator with upper/lower register tiles. * ​c\_tiles ([`SMemTileArray2DRowMajor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2DRowMajor)): SMEM tile array for C output (TileTensor-based). * ​m\_abs ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Absolute M coordinate (start of tile in token space). * ​n\_abs ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Absolute N coordinate (start of tile). * ​m\_end ([`UInt32`](/mojo/std/builtin/simd/#uint32)): End offset for bounds checking (exclusive). * ​expert\_scale ([`Float32`](/mojo/std/builtin/simd/#float32)): Per-expert output scaling factor. * ​c\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): C tensor in GMEM (for bounds-checked stores).
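The boundary handling described above can be sketched in plain Python. `store_tile_with_bounds` is a hypothetical helper that models only the row-masking and `expert_scale` logic; the real method stages the tile through SMEM and TMA rather than writing directly:

```python
import numpy as np

def store_tile_with_bounds(c, tile, m_abs, n_abs, m_end, expert_scale):
    """Store `tile` at absolute coords (m_abs, n_abs), scaled by expert_scale,
    dropping rows at or beyond m_end (the expert boundary, exclusive)."""
    bm, bn = tile.shape
    rows = min(int(m_end) - int(m_abs), bm)  # rows still inside this expert
    if rows <= 0:
        return
    if rows == bm:
        # full tile: bulk store (the kernel would use a TMA store here)
        c[m_abs:m_abs + bm, n_abs:n_abs + bn] = tile * expert_scale
    else:
        # partial tile crossing an expert boundary: per-row fallback stores
        for r in range(rows):
            c[m_abs + r, n_abs:n_abs + bn] = tile[r] * expert_scale
```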
--- ## blockwise_fp8_output_writer
Output writer for blockwise FP8 SM100 matmul. Handles Register → SMEM → GMEM (via TMA) flow. Unlike standard matmul which reads from TMEM, blockwise FP8 accumulators are already in registers. Supports two write modes: * write(): TMA store for standard non-grouped matmul * write\_absolute\_with\_bounds\_check(): Element-by-element store for 1D2D grouped matmul with expert boundary bounds checking ## Structs * [​`BlockwiseFP8TileWriter`](./BlockwiseFP8TileWriter): Write register accumulators to GMEM via SMEM and TMA.
--- ## BlockwiseFP8Smem
`struct BlockwiseFP8Smem[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]]` SMEM struct for blockwise FP8 matmul: A/B tiles, A-scales, C output, barriers. Key differences from BlockScaledSmem: * A-scales stored in SMEM (1D: 1 x BM per pipeline stage) * No B-scales in SMEM (read from global memory during epilogue) * Used with register-based accumulation pattern ## Fields * ​tiles (`BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles`): * ​pipelines (`BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Pipelines`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_scales_smem_layout` `comptime a_scales_smem_layout = Layout.row_major(1, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM)` ### `a_smem_layout` `comptime a_smem_layout = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Layouts.a_smem_layout` ### `AScalesTileArray` `comptime AScalesTileArray = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.AScalesTileArray` ### `ATileArray` `comptime ATileArray = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.ATileArray` ### `b_smem_layout` `comptime b_smem_layout = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Layouts.b_smem_layout` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTileArray` `comptime BTileArray = 
BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.BTileArray` ### `c_smem_layout` `comptime c_smem_layout = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Layouts.c_smem_layout` ### `CTileArray` `comptime CTileArray = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.CTileArray` ### `Layouts` `comptime Layouts = SmemLayouts[a_type, b_type, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BN, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_pipeline_stages // config.k_group_size)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `Pipelines` `comptime Pipelines = SmemPipelineBundle[BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, 
config=config].num_group_pipeline_stages, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_accum_pipeline_stages, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_clc_pipeline_stages, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BN, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, 1, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_pipeline_stages]]` ### `Tiles` `comptime Tiles = BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BN, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputN, 1, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_pipeline_stages, BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_output_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].ATileArray` Get A tile array accessor 
(TileTensor-based). **Returns:** `BlockwiseFP8Smem` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BTileArray` Get B tile array accessor (TileTensor-based). **Returns:** `BlockwiseFP8Smem` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].CTileArray` Get C tile array accessor (TileTensor-based). **Returns:** `BlockwiseFP8Smem` ### `a_scales_tiles` `a_scales_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].AScalesTileArray` Get A-scales tile array accessor (TileTensor-based). **Returns:** `BlockwiseFP8Smem` ### `ab_pipeline_size` `static ab_pipeline_size() -> Int` Total size of A+B tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `a_scales_pipeline_size` `static a_scales_pipeline_size() -> Int` Total size of A-scales tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `c_output_size` `static c_output_size() -> Int` Size of C tiles for all output stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Total tile storage size (A+B+A-scales+C) in elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
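As a rough illustration of how the size helpers above compose, here is a hedged Python sketch of the SMEM tile budget. The A/B/A-scales/C layouts follow the docs above (A-scales are 1 x BM per pipeline stage); the stage counts and element byte widths in the example are illustrative assumptions, not fixed by the API:

```python
def smem_tile_bytes(BM, BN, BK, OutputM, OutputN,
                    num_pipeline_stages, num_output_stages,
                    ab_elt_bytes=1,      # FP8 A/B elements
                    scale_elt_bytes=4,   # fp32 A-scales
                    c_elt_bytes=2):      # e.g. bf16 C output
    # A and B tiles, one pair per pipeline stage (ab_pipeline_size analogue)
    ab = num_pipeline_stages * (BM * BK + BN * BK) * ab_elt_bytes
    # A-scales: a 1 x BM vector per pipeline stage (a_scales_pipeline_size)
    a_scales = num_pipeline_stages * 1 * BM * scale_elt_bytes
    # C output tiles, multi-buffered across output stages (c_output_size)
    c_out = num_output_stages * OutputM * OutputN * c_elt_bytes
    return ab + a_scales + c_out             # total_tile_size analogue
```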
--- ## blockwise_fp8_smem
Shared memory layout for blockwise FP8 SM100 matmul. This module provides the SMEM struct for blockwise FP8 matmul kernels where: * A-scales are loaded via TMA and stored in SMEM (1D: 1 x BM per stage) * B-scales are read directly from global memory (not stored in SMEM) * Scaling is applied post-MMA in CUDA cores, not within the MMA unit Unlike block-scaled matmul, blockwise FP8 uses register-based accumulation across K iterations, with scales applied per-iteration. ## Structs * [​`BlockwiseFP8Smem`](./BlockwiseFP8Smem): SMEM struct for blockwise FP8 matmul: A/B tiles, A-scales, C output, barriers.
--- ## blockwise_fp8
Blockwise FP8 matmul kernel for SM100. ## Modules * [​`blockwise_fp8_accumulator`](./blockwise_fp8_accumulator/): Register-based accumulator for blockwise FP8 matmul. * [​`blockwise_fp8_matmul`](./blockwise_fp8_matmul/): CPU entry points for blockwise FP8 SM100 matmul. * [​`blockwise_fp8_matmul_kernel`](./blockwise_fp8_matmul_kernel/): Blockwise FP8 SM100 matmul kernel - Structured kernel with register accumulation. * [​`blockwise_fp8_output_writer`](./blockwise_fp8_output_writer/): Output writer for blockwise FP8 SM100 matmul. * [​`blockwise_fp8_smem`](./blockwise_fp8_smem/): Shared memory layout for blockwise FP8 SM100 matmul.
--- ## grouped_matmul_1d2d_blockwise_fp8
`grouped_matmul_1d2d_blockwise_fp8[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, a_scales_type: DType, b_scales_type: DType, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, transpose_b: Bool, //, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_scales: LayoutTensor[DType.float32, 
expert_scales_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)` Launch grouped 1D-1D blockwise FP8 matmul kernel. This function sets up TMA descriptors and launches the kernel with the proper configuration for 1D-1D tensor layout with blockwise FP8 scaling. **Args:** * ​c\_device ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor (total\_tokens, N). * ​a\_device ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input A tensor (total\_tokens, K). * ​b\_device ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Weight tensor B (num\_experts, N, K). * ​a\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scaling factors for A (K//128 x total\_tokens), FP32. * ​b\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scaling factors for B (num\_experts x N//128 x K//128), FP32. * ​a\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert offsets (num\_active\_experts + 1). * ​expert\_ids ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Active expert IDs (num\_active\_experts). * ​expert\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert output scaling (num\_experts). * ​num\_active\_experts ([`Int`](/mojo/std/builtin/int/Int)): Number of active experts. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context.
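The offset-based work distribution this entry point implements can be expressed as a NumPy reference. This is a sketch of the semantics under the documented tensor shapes, not the actual TMA-descriptor setup or launch path:

```python
import numpy as np

def grouped_matmul_ref(a, b, a_offsets, expert_ids, expert_scales,
                       num_active_experts):
    """a: (total_tokens, K); b: (num_experts, N, K) with transpose_b;
    a_offsets: (num_active_experts + 1,) token-range boundaries;
    expert_scales: (num_experts,) per-expert output scaling."""
    total_tokens = a.shape[0]
    n = b.shape[1]
    c = np.zeros((total_tokens, n), dtype=np.float32)
    for g in range(num_active_experts):
        lo, hi = int(a_offsets[g]), int(a_offsets[g + 1])  # this group's rows
        e = int(expert_ids[g])               # which expert's weights to use
        c[lo:hi] = (a[lo:hi] @ b[e].T) * expert_scales[e]
    return c
```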
--- ## grouped_matmul_dynamic_scaled_fp8_1d2d
`grouped_matmul_dynamic_scaled_fp8_1d2d[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, a_scales_type: DType, b_scales_type: DType, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, //, transpose_b: Bool = True](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)` Compatibility wrapper that matches the existing dispatch API. Creates the default config and calls the new structured kernel.
--- ## blockwise_fp8_1d2d_matmul
CPU entrypoint for grouped 1D-1D blockwise FP8 SM100 matmul. This module provides the public API for launching the grouped 1D-1D blockwise FP8 matmul kernel for Mixture of Experts (MoE) layers. Usage: `grouped_matmul_1d2d_blockwise_fp8[transpose_b=True, config=config](c_tensor, a_tensor, b_tensor, a_scales, b_scales, a_offsets, expert_ids, expert_scales, num_active_experts, ctx)` ## Functions * [​`grouped_matmul_1d2d_blockwise_fp8`](./grouped_matmul_1d2d_blockwise_fp8): Launch grouped 1D-1D blockwise FP8 matmul kernel. * [​`grouped_matmul_dynamic_scaled_fp8_1d2d`](./grouped_matmul_dynamic_scaled_fp8_1d2d): Compatibility wrapper that matches the existing dispatch API.
--- ## BlockwiseFP8_1D2DMatmulKernel
`struct BlockwiseFP8_1D2DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, a_layout: Layout, b_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, c_device_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], static_N: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1)]` Blockwise FP8 1D2D matmul kernel with register-based accumulation. Combines blockwise FP8 scaling (per-K in CUDA cores) with 1D-1D offset-based work distribution for grouped GEMM in MoE layers. Uses 3-warp specialization (Load, MMA, Epilogue) with grid-constant TMAs. Work distribution via GroupedWorkIterator1D1D using offset-based addressing. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_expected_bytes` `comptime a_expected_bytes = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.a_smem_layout.size() * size_of[a_type]())` ### `a_scales_expected_bytes` `comptime a_scales_expected_bytes = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.a_scales_smem_layout.size() * size_of[a_scales_type]())` ### 
`a_tma_load_size` `comptime a_tma_load_size = a_desc_layout.size()` ### `a_tma_rows` `comptime a_tma_rows = a_desc_layout.shape[0].value()` ### `accum_layout` `comptime accum_layout = get_accumulator_layout[c_smem_dim1=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].OutputN, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, cta_group=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group]()` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group * 128)` ### `accum_pipeline_producer_arv_count` `comptime accum_pipeline_producer_arv_count = 1` ### `accum_type` `comptime accum_type = DType.float32` ### `Accumulator` `comptime Accumulator = BlockwiseFP8Accumulator[DType.float32, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].accum_layout, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, 
b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].is_lower_required, config.block_tile_shape, config.mma_shape, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].CLUSTER_SIZE]` ### `AScalesTileLT` `comptime AScalesTileLT = LayoutTensor[a_scales_type, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.a_scales_smem_layout, ?, address_space=AddressSpace.SHARED, alignment=128]` ### `b_expected_bytes` `comptime b_expected_bytes = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.b_smem_layout.size() * size_of[b_type]())` ### `b_tma_load_size` `comptime b_tma_load_size = b_desc_layout.size()` ### `b_tma_rows` `comptime b_tma_rows = b_desc_layout.shape[0].value()` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_M` `comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)` ### `CLUSTER_N` 
`comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_SIZE` `comptime CLUSTER_SIZE = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].CLUSTER_M * BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].CLUSTER_N)` ### `cta_group` `comptime cta_group = config.cta_group` ### `EpilogueCtx` `comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].stage_stride_cols, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group, 32, 128]` ### `input_expected_bytes` `comptime input_expected_bytes = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group * 
((BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].a_expected_bytes + BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].b_expected_bytes) + BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].a_scales_expected_bytes))` ### `InputTilePipelineType` `comptime InputTilePipelineType = InputTilePipeline[BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].TilePayload, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size]` ### `is_lower_required` `comptime is_lower_required = is_lower_fragment_required[BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, 
a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group, config.block_tile_shape]()` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MmaCtx` `comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].stage_stride_cols, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group, 32, 128]` ### `MmaEpilogueSync` `comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]` ### `MmaOp` `comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_group_pipeline_stages` `comptime 
num_group_pipeline_stages = (BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_output_warps` `comptime num_output_warps = 4` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `NUM_THREADS` `comptime NUM_THREADS = WarpRole1D1D.TOTAL_THREADS` ### `NUM_TMEM_COLS` `comptime NUM_TMEM_COLS = 512` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputPipeline` `comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].stage_stride_cols, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group]` ### `SmemType` `comptime SmemType = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config]` ### `stage_stride_cols` `comptime stage_stride_cols = BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, 
b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].MMA_N` ### `TilePayload` `comptime TilePayload = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.BM, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.BK, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.BN, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.BK, 1, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.BM, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, 
b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.num_pipeline_stages]` ### `TileWriterType` `comptime TileWriterType = BlockwiseFP8TileWriter[c_type, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].OutputM, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].OutputN, DType.float32, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].accum_layout, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, is_lower_frag_required=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].is_lower_required, cta_group=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, 
static_N, cluster_shape].cta_group, num_output_stages=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].num_output_stages, num_output_warps=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].num_output_warps, c_swizzle=config.c_swizzle]` ### `Tmem` `comptime Tmem = TmemAllocation[BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group]` ### `TmemDealloc` `comptime TmemDealloc = TmemDeallocBarrier[BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group]` ### `WorkIterator` `comptime WorkIterator = GroupedWorkIterator1D1D[offsets_layout, expert_ids_layout, expert_scales_layout, static_N, config.block_tile_shape, config.cluster_shape, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, 
cluster_shape].cta_group]` ## Methods ### `validate_config` `static validate_config()` Compile-time validation of kernel configuration. ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], a_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], num_active_experts: Int, K: UInt32)` Grouped 1D-1D blockwise FP8 GEMM kernel entry point. Uses grid-constant TMAs with offset-based addressing for 1D-1D layout. Accumulates in registers with per-K scaling in CUDA cores. ### `load_input_tiles` `static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], tiles: InputProducerStage[tiles_origin, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].TilePayload, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_ctx: 
GroupedWorkContext1D1D, iter_idx: Int, elect_one_cta: Bool)` Load A, B, and A-scales tiles using TMA. ### `mma` `static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].TilePayload, BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlockwiseFP8_1D2DMatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, a_layout, b_layout, a_scales_layout, b_scales_layout, a_desc_layout, b_desc_layout, a_scales_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, c_device_layout, transpose_b, config, static_N, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32)` Execute standard MMA operations (partial results to TMEM). For blockwise FP8, each K iteration writes a fresh partial to TMEM. The epilogue accumulates across K in registers, not TMEM. Therefore init\_c is always True.
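The per-K rescale-and-accumulate scheme described above (each K iteration produces a fresh partial; the epilogue scales it and accumulates in fp32 registers) can be modeled in plain Python. This is a sketch of the arithmetic only; `blockwise_fp8_dot` and all names below are hypothetical and do not mirror the Mojo API:

```python
def blockwise_fp8_dot(a_q, b_q, a_scales, b_scales, bk):
    """Dot product of quantized vectors a_q, b_q (length K), where each
    contiguous block of `bk` elements along K shares one scale factor."""
    assert len(a_q) == len(b_q) and len(a_q) % bk == 0
    acc = 0.0                      # fp32 register accumulator
    for kb in range(0, len(a_q), bk):
        # Partial product over one K block (what the MMA writes to TMEM,
        # with init_c=True so the partial is not accumulated in TMEM).
        partial = sum(a_q[i] * b_q[i] for i in range(kb, kb + bk))
        # Epilogue: rescale the partial and accumulate across K in registers.
        acc += partial * a_scales[kb // bk] * b_scales[kb // bk]
    return acc

# Tiny check: K=4, block size BK=2, two scale pairs.
print(blockwise_fp8_dot([1.0, 2.0, 3.0, 4.0],
                        [1.0, 1.0, 1.0, 1.0],
                        [0.5, 2.0], [1.0, 1.0], 2))  # → 15.5
```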
--- ## blockwise_fp8_1d2d_matmul_kernel
Blockwise FP8 1D2D SM100 matmul kernel. This kernel combines: * Accumulation pattern from blockwise\_fp8/ (register-based per-K scaling via BlockwiseFP8Accumulator, standard MMA, A-scales in SMEM, B-scales from GMEM) * 1D-1D work distribution from grouped\_block\_scaled\_1d1d/ (GroupedWorkIterator1D1D, offset-based A tensor addressing, bounds-checked output, 3-warp specialization, SmemPipelineBundleNoClc) Architecture: * TMA warp: Loads A, B, and A-scales tiles using grid-constant TMAs * MMA warp: Standard MMA (partial results to TMEM, init\_c=True every K iter) * Epilogue warps: Per-K TMEM read → scale → register accumulate → final output with bounds checking ## Structs * [`BlockwiseFP8_1D2DMatmulKernel`](./BlockwiseFP8_1D2DMatmulKernel): Blockwise FP8 1D2D matmul kernel with register-based accumulation.
--- ## BlockwiseFP8_1D2DSmem
`struct BlockwiseFP8_1D2DSmem[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]]` SMEM struct for blockwise FP8 1D2D matmul: A/B tiles, A-scales, C output, barriers. Uses SmemPipelineBundleNoClc (no CLC scheduler) for 3-warp specialization. Otherwise identical to BlockwiseFP8Smem in tile storage layout. ## Fields * ​tiles (`BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles`): * ​pipelines (`BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Pipelines`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_scales_smem_layout` `comptime a_scales_smem_layout = Layout.row_major(1, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM)` ### `a_smem_layout` `comptime a_smem_layout = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Layouts.a_smem_layout` ### `AScalesTileArray` `comptime AScalesTileArray = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.AScalesTileArray` ### `ATileArray` `comptime ATileArray = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.ATileArray` ### `b_smem_layout` `comptime b_smem_layout = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Layouts.b_smem_layout` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTileArray` `comptime BTileArray = BlockwiseFP8_1D2DSmem[a_type, 
b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.BTileArray` ### `c_smem_layout` `comptime c_smem_layout = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Layouts.c_smem_layout` ### `CTileArray` `comptime CTileArray = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Tiles.CTileArray` ### `Layouts` `comptime Layouts = SmemLayouts[a_type, b_type, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BN, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `Pipelines` `comptime Pipelines = SmemPipelineBundleNoClc[BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_group_pipeline_stages, BlockwiseFP8_1D2DSmem[a_type, 
b_type, c_type, a_scales_type, transpose_b, config=config].num_accum_pipeline_stages, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BN, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, 1, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_pipeline_stages]]` ### `Tiles` `comptime Tiles = BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BN, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BK, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].OutputN, 1, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BM, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_pipeline_stages, BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].num_output_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].ATileArray` Get A tile array accessor (TileTensor-based). 
**Returns:** `ATileArray` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].BTileArray` Get B tile array accessor (TileTensor-based). **Returns:** `BTileArray` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].CTileArray` Get C tile array accessor (TileTensor-based). **Returns:** `CTileArray` ### `a_scales_tiles` `a_scales_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].AScalesTileArray` Get A-scales tile array accessor (TileTensor-based). **Returns:** `AScalesTileArray` ### `ab_pipeline_size` `static ab_pipeline_size() -> Int` Total size of A+B tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `a_scales_pipeline_size` `static a_scales_pipeline_size() -> Int` Total size of A-scales tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `c_output_size` `static c_output_size() -> Int` Size of C tiles for all output stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Total tile storage size (A+B+A-scales+C) in elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
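The size accounting behind `ab_pipeline_size`, `a_scales_pipeline_size`, and `c_output_size` can be sketched as tile elements multiplied by pipeline stages, using the documented A-scales layout of `row_major(1, BM)`. A minimal model, assuming these shapes (the function and parameter names are hypothetical, not the Mojo API):

```python
def smem_budget(bm, bn, bk, num_pipeline_stages, out_m, out_n, num_output_stages):
    """Total tile storage in elements, mirroring total_tile_size =
    A + B + A-scales + C across all stages (a sketch, not the real layout)."""
    ab = (bm * bk + bn * bk) * num_pipeline_stages   # A and B tiles per input stage
    a_scales = (1 * bm) * num_pipeline_stages        # row_major(1, BM) per input stage
    c = out_m * out_n * num_output_stages            # C output tiles
    return ab + a_scales + c

# Example shapes: BM=BN=128, BK=64, 4 input stages, 128x128 output, 2 stages.
total = smem_budget(128, 128, 64, 4, 128, 128, 2)
```

Multiply by `size_of[dtype]()` per tensor to get bytes, as the `*_expected_bytes` members do.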
--- ## blockwise_fp8_1d2d_smem
Shared memory layout for blockwise FP8 1D2D SM100 matmul. This is a simplified SMEM structure for the 1D2D blockwise FP8 kernel that uses offset-based addressing (GroupedWorkIterator1D1D). Key differences from the standard BlockwiseFP8Smem: 1. No CLC pipeline storage - uses 3-warp specialization (no scheduler warp) 2. Uses SmemPipelineBundleNoClc instead of SmemPipelineBundle 3. Otherwise identical tile storage (A, B, C, A-scales) The 1D-1D layout uses: * A tensor: Contiguous (total\_tokens, K) with a\_offsets for per-group access * B tensor: Batched (num\_experts \* N, K) weights * C tensor: Contiguous (total\_tokens, N) output ## Structs * [​`BlockwiseFP8_1D2DSmem`](./BlockwiseFP8_1D2DSmem): SMEM struct for blockwise FP8 1D2D matmul: A/B tiles, A-scales, C output, barriers.
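The 1D-1D layout above (contiguous A with per-group offsets, per-expert B weights, contiguous C) can be illustrated with a plain-Python reference loop. All names here are hypothetical; this models the addressing scheme, not the kernel:

```python
def grouped_matmul_1d1d(a, b_experts, a_offsets, expert_ids):
    """Reference grouped GEMM: rows a_offsets[g]:a_offsets[g+1] of the packed
    A tensor are multiplied by the (N, K) weights of expert expert_ids[g]."""
    total_tokens = len(a)
    n = len(b_experts[0])                      # N, rows of each expert's weights
    c = [[0.0] * n for _ in range(total_tokens)]
    for g in range(len(expert_ids)):
        w = b_experts[expert_ids[g]]           # (N, K), transpose_b-style weights
        for row in range(a_offsets[g], a_offsets[g + 1]):
            for j in range(n):
                c[row][j] = sum(a[row][k] * w[j][k] for k in range(len(a[row])))
    return c

# Three tokens in two groups, routed to experts 1 and 0 (K=2, N=1).
a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b_experts = [[[1.0, 0.0]], [[0.0, 2.0]]]
c = grouped_matmul_1d1d(a, b_experts, [0, 2, 3], [1, 0])  # → [[0.0], [2.0], [1.0]]
```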
--- ## blockwise_fp8_1d2d
Blockwise FP8 1D2D grouped matmul kernel for SM100. This module provides a structured kernel implementation for grouped blockwise FP8 GEMM using the 1D-1D tensor layout with offset-based addressing. It combines: * Accumulation pattern from blockwise\_fp8/ (register-based per-K scaling) * 1D-1D work distribution from grouped\_block\_scaled\_1d1d/ (offset-based A tensor addressing, bounds-checked output, 3-warp specialization) ## Modules * [`blockwise_fp8_1d2d_matmul`](./blockwise_fp8_1d2d_matmul/): CPU entrypoint for grouped 1D-1D blockwise FP8 SM100 matmul. * [`blockwise_fp8_1d2d_matmul_kernel`](./blockwise_fp8_1d2d_matmul_kernel/): Blockwise FP8 1D2D SM100 matmul kernel. * [`blockwise_fp8_1d2d_smem`](./blockwise_fp8_1d2d_smem/): Shared memory layout for blockwise FP8 1D2D SM100 matmul.
--- ## heuristic_and_outliers_dispatch
`heuristic_and_outliers_dispatch[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## dispatch
## `comptime` values

### `DISPATCH_HIT`

`comptime DISPATCH_HIT = 1`

### `DISPATCH_MISS`

`comptime DISPATCH_MISS = 0`

### `logger`

`comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)`

## Functions

* [`heuristic_and_outliers_dispatch`](./heuristic_and_outliers_dispatch):
* [`matmul_dispatch_sm100`](./matmul_dispatch_sm100):
* [`matmul_dispatch_sm100_bf16`](./matmul_dispatch_sm100_bf16):
* [`matmul_dispatch_sm100_fp8`](./matmul_dispatch_sm100_fp8):
--- ## matmul_dispatch_sm100
`matmul_dispatch_sm100[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_lambda_wrapper: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext)`
--- ## matmul_dispatch_sm100_bf16
`matmul_dispatch_sm100_bf16[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## matmul_dispatch_sm100_fp8
`matmul_dispatch_sm100_fp8[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## default (Default)
Default SM100 matmul kernel - Standard FP8/BF16 warp-specialized implementation.

## Modules

* [`dispatch`](./dispatch/):
* [`matmul`](./matmul/): SM100 Matmul CPU entry points - TMA setup and kernel launch wrappers.
* [`matmul_kernels`](./matmul_kernels/): SM100 Default Matmul Kernel - Standard FP8/BF16 warp-specialized kernel.
* [`tuning_configs`](./tuning_configs/):
--- ## blackwell_matmul_tma_umma_warp_specialized
`blackwell_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: Optional[UInt32] = None](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## matmul (4)
SM100 Matmul CPU entry points - TMA setup and kernel launch wrappers.

This module contains the CPU-side code for SM100 matrix multiplication:

* TMA descriptor creation
* Kernel instantiation and launch via ctx.enqueue\_function

All GPU code (kernel structs, runtime functions) is in matmul\_kernels.mojo.

## `comptime` values

### `UnsafePointer`

`comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]`

## Functions

* [`blackwell_matmul_tma_umma_warp_specialized`](./blackwell_matmul_tma_umma_warp_specialized):
* [`matmul_sm100_fallback`](./matmul_sm100_fallback):
--- ## matmul_sm100_fallback
`matmul_sm100_fallback[a_layout: Layout, b_layout: Layout, c_layout: Layout, c_type: DType, a_type: DType, b_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## B200MatmulSmem
`struct B200MatmulSmem[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]]` Shared memory layout for B200 SM100 matrix multiplication kernel. This struct manages the shared memory allocation for: * Input tiles (A and B matrices) with multi-stage pipelining * Output tile (C matrix) for accumulation * Synchronization barriers for producer-consumer coordination * CLC (Cluster Launch Control) barriers and response storage * TMEM (Tensor Memory) address and deallocation barrier The memory is organized to support asynchronous TMA loads and efficient bank-conflict-free access patterns for tensor core operations. Type aliases are provided for tile types (ATile, BTile, CTile) to enable cleaner function signatures without verbose LayoutTensor declarations. ## Fields * ​input\_tiles (`B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].InputTiles`): * ​output\_tiles (`B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputTiles`): * ​pipelines (`B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].Pipelines`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].Layouts.a_smem_layout` ### `ATileArray` `comptime ATileArray = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].InputTiles.ATileArray` ### `b_smem_layout` `comptime b_smem_layout = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].Layouts.b_smem_layout` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, 
DType.int64, Int](1)` ### `BTileArray` `comptime BTileArray = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].InputTiles.BTileArray` ### `c_smem_layout` `comptime c_smem_layout = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].Layouts.c_smem_layout` ### `CTileArray` `comptime CTileArray = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputTiles.CTileArray` ### `InputTiles` `comptime InputTiles = StandardTileStorage[a_type, b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BN, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages]` ### `Layouts` `comptime Layouts = SmemLayouts[a_type, b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BN, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` 
`comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputTiles` `comptime OutputTiles = OutputTileStorage[c_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputN, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_output_stages]` ### `Pipelines` `comptime Pipelines = SmemPipelineBundle[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_group_pipeline_stages, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_accum_pipeline_stages, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages, StandardTilePayload[a_type, b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BN, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages]]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].ATileArray` **Returns:** `B200MatmulSmem` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BTileArray` **Returns:** `B200MatmulSmem` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].CTileArray` **Returns:** `B200MatmulSmem` ### `ab_pipeline_size` `static ab_pipeline_size() -> Int` Total size of A+B tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `c_output_size` `static c_output_size() -> Int` Size of C tiles for all output stages (in elements). 
**Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Total tile storage size (A+B+C) in elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
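The `ab_pipeline_size`, `c_output_size`, and `total_tile_size` methods above report tile storage in elements. A back-of-envelope model, assuming dense tiles (the real sizes come from the swizzled `SmemLayouts`, though swizzling permutes rather than grows a tile) and a purely hypothetical tile configuration:

```python
def smem_tile_elements(BM, BN, BK, out_m, out_n,
                       num_pipeline_stages, num_output_stages):
    """Sketch of B200MatmulSmem's tile accounting, assuming dense tiles.

    Mirrors ab_pipeline_size / c_output_size / total_tile_size, in elements.
    """
    ab = (BM * BK + BN * BK) * num_pipeline_stages   # A + B tiles, all input stages
    c = out_m * out_n * num_output_stages            # C tiles, all output stages
    return ab, c, ab + c

# Hypothetical numbers (not a tuned config): 128x256x64 block tile,
# 4 input pipeline stages, 2 output stages.
ab, c, total = smem_tile_elements(128, 256, 64, 128, 256, 4, 2)
```

Multiplying the element counts by the element sizes of the A/B and C dtypes gives the shared-memory budget the pipeline depth must fit within.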
--- ## BlackwellMatmulSM100FallbackKernel
`struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None]` Simple fallback matmul kernel for SM100 (B200). This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations. Unlike the main BlackwellMatmulSM100Kernel, this uses: * Single warp approach (no warp specialization) * Basic barrier synchronization (no CLC scheduling) * Direct LayoutTensor output (no TMA for C) * Simpler pipeline with single buffer ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_size` `comptime a_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout.size()` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, 
b_swizzle, num_threads, elementwise_lambda_fn].BK, a_swizzle]()` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `ATile` `comptime ATile = LayoutTensor[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` ### `b_size` `comptime b_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout.size()` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]()` ### `BK` `comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = 
block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTile` `comptime BTile = LayoutTensor[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` ### `c_frag_size` `comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // num_threads)` ### `max_tmem_cols` `comptime max_tmem_cols = 512` ### `MMA_K` `comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_k_mmas` `comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, 
b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)` ## Methods ### `validate_constraints` `static validate_constraints()` Validate compile-time constraints for this kernel configuration. ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: Scalar[DType.uint])` Run the fallback matmul kernel. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix A. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix B. * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor C (LayoutTensor, not TMA). * ​num\_iters ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of K-dimension iterations.
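The fallback kernel's derived `comptime` members follow directly from the block tile and MMA shapes: `num_m_mmas = BM // MMA_M`, `num_n_mmas = BN // MMA_N`, `num_k_mmas = BK // MMA_K`, and `c_frag_size = (MMA_M * MMA_N) // num_threads`. A sketch of that arithmetic, using illustrative shapes rather than a validated configuration:

```python
def fallback_mma_counts(block_tile_shape, mma_shape, num_threads=128):
    """Derived tile counts per BlackwellMatmulSM100FallbackKernel's
    comptime members (shapes here are illustrative only)."""
    BM, BN, BK = block_tile_shape
    MMA_M, MMA_N, MMA_K = mma_shape
    return {
        "num_m_mmas": BM // MMA_M,     # MMA tiles along M per block tile
        "num_n_mmas": BN // MMA_N,     # MMA tiles along N per block tile
        "num_k_mmas": BK // MMA_K,     # MMA steps per K iteration
        "c_frag_size": (MMA_M * MMA_N) // num_threads,  # accumulator elems/thread
    }

counts = fallback_mma_counts((128, 128, 64), (64, 128, 16))
```

These quantities must divide evenly, which is the kind of compile-time constraint `validate_constraints` checks before the kernel is instantiated.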
--- ## BlackwellMatmulSM100Kernel
`struct BlackwellMatmulSM100Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]` Blackwell SM100 GEMM kernel with warp specialization. This struct unifies all parameters and derived types for the SM100 matmul kernel, providing: * Compile-time parameter validation * Centralized derived type computation * Factory methods for kernel components * Multiple kernel entry points (standard, split-k) The SM100 kernel uses: * Tensor Memory (TMEM) for MMA accumulators * Cluster Launch Control (CLC) for dynamic tile scheduling * Warp specialization: Scheduler, TMA Load, MMA, Epilogue warps * Software pipelining for overlapping compute and memory operations ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_expected_bytes` `comptime a_expected_bytes = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.a_smem_layout.size() * size_of[a_type]())` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, 
pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.a_swizzle]()` ### `a_tma_load_size` `comptime a_tma_load_size = a_desc_layout.size()` ### `a_tma_rows` `comptime a_tma_rows = a_desc_layout.shape[0].value()` ### `accum_layout` `comptime accum_layout = Layout.row_major(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N)` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)` ### `accum_pipeline_producer_arv_count` `comptime accum_pipeline_producer_arv_count = 1` ### `accum_type` `comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type` ### `AccumTensor` `comptime AccumTensor = TmemTensor[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, 
c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_layout, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ### `ATileLoaderType` `comptime ATileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ### `ATileLT` `comptime ATileLT = LayoutTensor[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.a_smem_layout, ?, address_space=AddressSpace.SHARED, alignment=128]` ### `b_expected_bytes` `comptime b_expected_bytes = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.b_smem_layout.size() * size_of[b_type]())` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, 
c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]()` ### `b_tma_load_size` `comptime b_tma_load_size = b_desc_layout.size()` ### `b_tma_rows` `comptime b_tma_rows = b_desc_layout.shape[0].value()` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTileLoaderType` `comptime BTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ### `BTileLT` `comptime BTileLT = LayoutTensor[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.b_smem_layout, ?, address_space=AddressSpace.SHARED, alignment=128]` ### `clc_consumer_arv_count` `comptime clc_consumer_arv_count = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE * ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)))` ### `clc_producer_arv_count` `comptime clc_producer_arv_count = 1` ### `clc_throttle_consumer_arv_count` `comptime clc_throttle_consumer_arv_count = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS` ### `clc_throttle_producer_arv_count` `comptime clc_throttle_producer_arv_count = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS` ### `CLUSTER_M` `comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)` ### `CLUSTER_N` `comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_SIZE` `comptime CLUSTER_SIZE = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)` ### `Context` `comptime Context = KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N]` ### `cta_group` `comptime cta_group = config.cta_group` ### `EPILOGUE_THREADS` `comptime EPILOGUE_THREADS = (4 * WARP_SIZE)` ### `EpilogueConf` `comptime EpilogueConf = EpilogueConfig[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.c_smem_layout.shape[1].value(), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, False]` ### `EpilogueCtx` `comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, 
BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]` ### `input_expected_bytes` `comptime input_expected_bytes = ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group * (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].a_expected_bytes + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].b_expected_bytes)) * config)` ### `InputTilePipeline` `comptime InputTilePipeline = InputTilePipeline[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, 
pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MMA_THREADS` `comptime MMA_THREADS = WARP_SIZE` ### `MmaCtx` `comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]` ### `MmaEpilogueSync` `comptime MmaEpilogueSync = WarpGroupBarrier[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, 
a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS), 1]` ### `MmaOp` `comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, accum_type=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // config)` ### `num_k_mmas` `comptime num_k_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, 
config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group))` ### `num_n_mmas` `comptime num_n_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group))` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_output_warps` `comptime num_output_warps = 4` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `NUM_THREADS` `comptime NUM_THREADS = (((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)` ### `NUM_TMEM_COLS` `comptime NUM_TMEM_COLS = 512` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputPipeline` `comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ### `Scheduler` `comptime Scheduler = TileScheduler[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), config.raster_order, config.block_swizzle_size]` ### `SCHEDULER_THREADS` `comptime SCHEDULER_THREADS = WARP_SIZE` ### `SmemType` `comptime SmemType = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config]` ### `stage_stride_cols` `comptime stage_stride_cols = (512 // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages)` ### `TilePayload` `comptime TilePayload = StandardTilePayload[a_type, b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages]` ### `TileWriterType` `comptime TileWriterType = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_output_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, 4, elementwise_compute_lambda_fn, register_based_epilogue]` ### `TMA_LOAD_THREADS` `comptime TMA_LOAD_THREADS = WARP_SIZE` ### `Tmem` `comptime Tmem = TmemAllocation[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ### `TmemDealloc` `comptime TmemDealloc = TmemDeallocBarrier[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ## Methods ### `validate_constraints` `static validate_constraints()` Validate parameter constraints at compile time. 
### `init_barriers` `static init_barriers(ctx: KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N], a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], input_barriers: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, 
config=config].num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1])` Initialize barriers and prefetch TMA descriptors. Called by elect\_one\_warp && elect\_one\_thread. ### `mma` `static mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[config.num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group], tiles: InputConsumerStage[tiles_origin, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)` Execute MMA operations for one pipeline stage. This is the core MMA function designed to be called within a consumer stage context: ``` with consumer.acquire() as tiles: Self.mma(stage.tmem, tiles, mma_op, ...) 
``` **Args:** * ​tmem\_stage (`TmemStage`): TMEM stage for accumulators. * ​tiles ([`InputConsumerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputConsumerStage)): InputConsumerStage context with encapsulated tile access. * ​mma\_op ([`MmaOpSM100_SS`](/mojo/kernels/linalg/arch/sm100/mma/MmaOpSM100_SS)): The MMA operation instance. * ​elect\_one\_warp ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether this warp should execute. * ​iter\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): K iteration index. * ​k\_start ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Starting K iteration (for init\_c determination). ### `load_input_tiles` `static load_input_tiles[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoaderTMA[a_tma_origin, a_type, a_layout, a_desc_layout, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group], b_loader: TileLoaderTMA[b_tma_origin, b_type, b_layout, b_desc_layout, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group], tiles: InputProducerStage[tiles_origin, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], iter_idx: UInt32, work_m_coord: Scalar[DType.uint], work_n_coord: Scalar[DType.uint], peer_cta_coord: Tuple[UInt, UInt, UInt], elect_one_cta: Bool)` Load k\_group\_size A and B tiles using TMA. Orchestrates the tile loading operation including: * expect\_bytes signaling * k-group iteration * Peer CTA slicing for 2-SM MMA **Args:** * ​a\_loader ([`TileLoaderTMA`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_loader/TileLoaderTMA)): TileLoaderTMA for A matrix. * ​b\_loader ([`TileLoaderTMA`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_loader/TileLoaderTMA)): TileLoaderTMA for B matrix. * ​tiles ([`InputProducerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputProducerStage)): InputProducerStage context with encapsulated tile access. * ​iter\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): K iteration index (base index). * ​work\_m\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): M coordinate of the output tile. * ​work\_n\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): N coordinate of the output tile. * ​peer\_cta\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Peer CTA coordinates (rank\_n, rank\_m, peer\_m\_rank). * ​elect\_one\_cta ([`Bool`](/mojo/std/builtin/bool/Bool)): True if this CTA should call expect\_bytes. ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])` Main kernel entry point for SM100 matrix multiplication. 
### `run_splitk` `static run_splitk[reduction_layout: Layout](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], reduction_tensor: LayoutTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: LegacyUnsafePointer[UInt8], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])` Split-K kernel entry point for better parallelism on small problems. Split-K divides the K dimension across multiple CTAs, with each CTA computing a partial result that is then reduced. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix A. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix B. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix C. * ​reduction\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Workspace for partial results from each split. * ​lock\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Synchronization locks for reduction coordination. * ​cluster\_dim ([`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple)): Cluster dimensions. * ​mnk ([`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple)): Problem dimensions (M, N, K). * ​workspace ([`Span`](/mojo/std/memory/span/Span)): Workspace buffer for profiling/scheduling.
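The split-K scheme described above can be modeled numerically. The sketch below is an illustrative Python/NumPy model (not the Mojo kernel): the K dimension is partitioned across `num_split_k` workers, each computes a partial product (analogous to one CTA writing into `reduction_tensor`), and the partials are summed in a final reduction (coordinated via `lock_ptr` in the real kernel). The even `linspace` partitioning is an assumption for illustration.

```python
# Illustrative model of split-K matmul: partition K, accumulate partials, reduce.
import numpy as np

def splitk_matmul(A, B, num_split_k):
    """Compute A @ B by splitting the K dimension into partial products."""
    M, K = A.shape
    # Each "CTA" handles one contiguous K slice (slice sizes may differ by one).
    bounds = np.linspace(0, K, num_split_k + 1, dtype=int)
    # Partial results, analogous to the reduction_tensor workspace.
    partials = [A[:, s:e] @ B[s:e, :] for s, e in zip(bounds[:-1], bounds[1:])]
    # Reduction step over all partial results.
    return sum(partials)

A = np.arange(12.0).reshape(3, 4)
B = np.arange(8.0).reshape(4, 2)
assert np.allclose(splitk_matmul(A, B, num_split_k=3), A @ B)
```

Splitting K trades a small reduction cost for more CTAs in flight, which is why the entry point above targets small problems that would otherwise underutilize the GPU.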
--- ## matmul_kernels
SM100 Default Matmul Kernel - Standard FP8/BF16 warp-specialized kernel. This module contains the default SM100 matmul kernel implementation: * B200MatmulSmem: Shared memory layout for the kernel * BlackwellMatmulSM100Kernel: Main kernel struct with run() and run\_splitk() * BlackwellMatmulSM100FallbackKernel: Simple fallback kernel Shared components (WarpRole, KernelContext) are in kernel\_common.mojo. Output pipeline (TileWriter, copy\_accum\_to\_gmem) is in output\_writer.mojo. Low-level epilogue components (TMAStoreExecutor, etc.) are in epilogue\_components.mojo. The kernel implements a warp-specialized architecture: * Scheduler warp: CLC-based tile scheduling * TMA Load warp: Async memory transfers * MMA warp: Tensor core operations with TMEM accumulators * Epilogue warps: Output from TMEM to GMEM via TileWriter ## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`B200MatmulSmem`](./B200MatmulSmem): Shared memory layout for B200 SM100 matrix multiplication kernel. * [​`BlackwellMatmulSM100FallbackKernel`](./BlackwellMatmulSM100FallbackKernel): Simple fallback matmul kernel for SM100 (B200). * [​`BlackwellMatmulSM100Kernel`](./BlackwellMatmulSM100Kernel): Blackwell SM100 GEMM kernel with warp specialization.
--- ## TuningConfigSM100
`@register_passable(trivial)` `struct TuningConfigSM100` ## Fields * ​M (`Int`): * ​M\_end (`Int`): * ​N (`Int`): * ​K (`Int`): * ​mma\_shape (`IndexList[3]`): * ​block\_tile\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​block\_swizzle\_size (`UInt`): * ​rasterize\_order (`RasterOrder`): * ​cta\_group (`Int`): * ​swapAB (`Bool`): * ​k\_group\_size (`UInt`): * ​num\_accum\_pipeline\_stages (`UInt`): * ​num\_clc\_pipeline\_stages (`UInt`): * ​num\_split\_k (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`TuningConfig`](/mojo/kernels/internal_utils/dispatch_utils/TuningConfig) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(M: Int, N: Int, K: Int, mma_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[3], block_swizzle_size: Scalar[DType.uint], rasterize_order: RasterOrder, cta_group: Int = 2, swapAB: Bool = False, k_group_size: Scalar[DType.uint] = 1, num_accum_pipeline_stages: Scalar[DType.uint] = 2, num_clc_pipeline_stages: Scalar[DType.uint] = 2, num_split_k: Int = 1) -> Self` `__init__(M: Int, M_end: Int, N: Int, K: Int, mma_shape: IndexList[3], cta_group: Int, cluster_shape: IndexList[3], block_swizzle_size: Scalar[DType.uint], rasterize_order: RasterOrder, swapAB: Bool = False, k_group_size: Scalar[DType.uint] = 1, num_accum_pipeline_stages: 
Scalar[DType.uint] = 2, num_clc_pipeline_stages: Scalar[DType.uint] = 2, num_split_k: Int = 1) -> Self` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String)
--- ## tuning_configs
## Structs * [​`TuningConfigSM100`](./TuningConfigSM100): Tuning configuration parameters for SM100 matmul kernel dispatch.

--- ## compute_total_tiles
`compute_total_tiles[tile_m: Int, tile_n: Int, max_groups: Int](problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int) -> Int` Compute total number of tiles across all groups. **Args:** * ​problem\_sizes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (num\_groups, 4) tensor with \[M, N, K, L] per group. * ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): Number of GEMM problems. **Returns:** [`Int`](/mojo/std/builtin/int/Int): Total tile count.
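The tile count can be sketched as a plain Python model of the formula implied above: each group contributes one `tile_m` x `tile_n` output tile per (ceil-divided) M/N block. Treating the fourth entry `L` as a batch multiplier is an assumption here, as is representing the `problem_sizes` tensor as a list of tuples.

```python
# Illustrative model of compute_total_tiles: sum of per-group output tile counts.
from math import ceil

def compute_total_tiles(problem_sizes, num_groups, tile_m, tile_n):
    total = 0
    for g in range(num_groups):
        M, N, K, L = problem_sizes[g]
        # ceil-divide M and N by the tile shape; L assumed to batch-multiply.
        total += ceil(M / tile_m) * ceil(N / tile_n) * L
    return total

# Two groups with different shapes, 128x128 output tiles:
sizes = [(256, 256, 512, 1), (300, 100, 512, 2)]
print(compute_total_tiles(sizes, num_groups=2, tile_m=128, tile_n=128))  # → 10
```

The caller passes this total to `grouped_block_scaled_matmul` so the scheduler can map a flat tile index back to a (group, tile) pair.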
--- ## grouped_block_scaled_matmul
`grouped_block_scaled_matmul[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, sfb_layout: Layout, transpose_b: Bool, max_groups: Int, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True](a_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], b_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], c_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], sfa_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], sfb_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int, total_tiles: Int, a_template: LayoutTensor[a_type, a_layout, MutAnyOrigin], b_template: LayoutTensor[b_type, b_layout, MutAnyOrigin], c_template: LayoutTensor[c_type, c_layout, MutAnyOrigin], sfa_template: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], sfb_template: LayoutTensor[sfb_dtype, sfb_layout, MutAnyOrigin], ctx: DeviceContext)`

Launch grouped block-scaled FP8 matmul kernel on SM100.

Computes C\[g] = scale(A\[g]) @ scale(B\[g]) for g in range(num\_groups), where each group can have different M, N, K dimensions.

**Parameters:**

* ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Output element type.
* ​c\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Output tensor layout.
* ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): A matrix element type (FP8).
* ​a\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A tensor layout.
* ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): B matrix element type (FP8).
* ​b\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): B tensor layout.
* ​sfa\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): A scaling factor type (F8-UE8M0).
* ​sfa\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A scaling factor tensor layout.
* ​sfb\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): B scaling factor type (F8-UE8M0).
* ​sfb\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): B scaling factor tensor layout.
* ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether B is transposed (must be True).
* ​max\_groups ([`Int`](/mojo/std/builtin/int/Int)): Maximum number of groups (compile-time bound).
* ​config ([`BlockScaledMatmulConfig`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/config/BlockScaledMatmulConfig)): Block-scaled matmul configuration.
* ​elementwise\_compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional epilogue lambda for element-wise operations on output. Applied after matmul, before writing to global memory.
* ​register\_based\_epilogue ([`Bool`](/mojo/std/builtin/bool/Bool)): If True (default), apply epilogue in registers. If False, use SMEM-based epilogue path.

**Args:**

* ​a\_ptrs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-group A matrix pointers (max\_groups, 1).
* ​b\_ptrs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-group B matrix pointers (max\_groups, 1).
* ​c\_ptrs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-group C matrix pointers (max\_groups, 1).
* ​sfa\_ptrs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-group A scaling factor pointers (max\_groups, 1).
* ​sfb\_ptrs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-group B scaling factor pointers (max\_groups, 1).
* ​problem\_sizes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-group problem sizes (max\_groups, 4) as \[M, N, K, L].
* ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): Actual number of groups (runtime value <= max\_groups).
* ​total\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Total tiles across all groups (computed by caller).
* ​a\_template ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Template A tensor for TMA descriptor creation.
* ​b\_template ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Template B tensor for TMA descriptor creation.
* ​c\_template ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Template C tensor for TMA descriptor creation.
* ​sfa\_template ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Template SFA tensor for TMA descriptor creation.
* ​sfb\_template ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Template SFB tensor for TMA descriptor creation.
* ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for kernel launch.

**Raises:**

If configuration constraints are violated.
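The per-group semantics C\[g] = scale(A\[g]) @ scale(B\[g]) can be sketched as a plain-Python reference model. This is illustrative only, not the kernel's implementation: it assumes transpose\_b=True (B stored as N x K) and one scaling factor per `sf_vec` contiguous K elements (hardware MX-style F8-UE8M0 scaling uses 32-element blocks; a small `sf_vec` is used here for readability).

```python
def block_scaled_matmul(a, b, sfa, sfb, sf_vec=4):
    """Reference: C = scale(A) @ scale(B).T, one scale per sf_vec K-elements.

    a: M x K, b: N x K (transpose_b=True layout),
    sfa: M x (K // sf_vec), sfb: N x (K // sf_vec).
    sf_vec=4 is purely illustrative; MXFP8 hardware scaling uses 32.
    """
    m, k = len(a), len(a[0])
    n = len(b)
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                blk = p // sf_vec  # which K-block this element's scale covers
                acc += (a[i][p] * sfa[i][blk]) * (b[j][p] * sfb[j][blk])
            c[i][j] = acc
    return c


def grouped_matmul_ref(a_list, b_list, sfa_list, sfb_list):
    # Each group g may have different (M, N, K), matching problem_sizes rows.
    return [block_scaled_matmul(a, b, sfa, sfb)
            for a, b, sfa, sfb in zip(a_list, b_list, sfa_list, sfb_list)]


# One group, M=1, N=1, K=4: scales 2.0 (A) and 0.5 (B) cancel to 1.0 per term.
print(block_scaled_matmul([[1.0] * 4], [[1.0] * 4], [[2.0]], [[0.5]]))
```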
---

## grouped_smem_size
`grouped_smem_size[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]() -> Int`

Calculate shared memory size for grouped block-scaled kernel.

**Returns:**

[`Int`](/mojo/std/builtin/int/Int): SMEM size in bytes, including tensormap descriptor storage.
---

## grouped_block_scaled_matmul (Grouped_block_scaled_matmul)
CPU entry points for grouped block-scaled SM100 matmul. Supports multiple GEMM operations with variable problem sizes per group. Uses TMATensorTileArray for per-block updatable TMA descriptors.

This module implements grouped block-scaled GEMM following the architecture of NVIDIA CuTe DSL grouped\_blockscaled\_gemm.py:

1. Creates template TMA descriptors from the first group
2. Creates TMATensorTileArray with one tensormap per block
3. Launches GroupedBlockScaledMatmulKernel with per-group pointers

Usage:

```mojo
# Per-group pointers (device addresses)
var a_ptrs = ...    # (num_groups, 1) with uint64 addresses
var b_ptrs = ...    # (num_groups, 1)
var c_ptrs = ...    # (num_groups, 1)
var sfa_ptrs = ...  # (num_groups, 1)
var sfb_ptrs = ...  # (num_groups, 1)

# Problem sizes per group
var problem_sizes = ...  # (num_groups, 4) with [M, N, K, L]

grouped_block_scaled_matmul[...](
    a_ptrs, b_ptrs, c_ptrs, sfa_ptrs, sfb_ptrs,
    problem_sizes, num_groups, ctx
)
```

## Functions

* [​`compute_total_tiles`](./compute_total_tiles): Compute total number of tiles across all groups.
* [​`grouped_block_scaled_matmul`](./grouped_block_scaled_matmul): Launch grouped block-scaled FP8 matmul kernel on SM100.
* [​`grouped_smem_size`](./grouped_smem_size): Calculate shared memory size for grouped block-scaled kernel.
* [​`validate_grouped_gemm_constraints`](./validate_grouped_gemm_constraints): Validate grouped GEMM configuration constraints.
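The scheduler walks a flat tile index across every group, so the caller supplies `total_tiles` up front. A plausible sketch of that count, under the assumption that it is simply the sum over groups of per-group output-tile counts (the real `compute_total_tiles` may additionally account for cluster shape and CTA grouping):

```python
import math

def total_tiles_sketch(problem_sizes, bm, bn):
    """Sum of ceil(M / BM) * ceil(N / BN) over all groups.

    problem_sizes: list of (M, N, K, L) rows, mirroring the
    (num_groups, 4) problem_sizes tensor. Illustrative only.
    """
    total = 0
    for m, n, _k, _l in problem_sizes:
        total += math.ceil(m / bm) * math.ceil(n / bn)
    return total

# Two groups, BM = BN = 128: group 0 gives 2*2 tiles, group 1 gives 1*3.
print(total_tiles_sketch([(256, 256, 512, 1), (100, 300, 512, 1)], 128, 128))
```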
---

## validate_grouped_gemm_constraints
`validate_grouped_gemm_constraints[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]()`

Validate grouped GEMM configuration constraints.

Constraints from NVIDIA CuTe DSL grouped\_blockscaled\_gemm.py:

* MMA tiler M: 128 or 256
* MMA tiler N: 128 or 256
* Cluster M/N: Power of 2, <=4 per axis (for SF multicast)
* Total cluster size: <=16
* 16-byte alignment on contiguous dimensions
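The constraints above can be expressed as a short runtime check. This is a Python sketch of the rules as listed, not the actual compile-time validation over `BlockScaledMatmulConfig`; the parameter names here are illustrative.

```python
def check_grouped_gemm_constraints(mma_m, mma_n, cluster_m, cluster_n,
                                   contiguous_byte_stride):
    """Assert the grouped block-scaled GEMM constraints listed above."""
    def is_pow2(x):
        return x > 0 and (x & (x - 1)) == 0

    assert mma_m in (128, 256), "MMA tiler M must be 128 or 256"
    assert mma_n in (128, 256), "MMA tiler N must be 128 or 256"
    assert is_pow2(cluster_m) and cluster_m <= 4, \
        "cluster M must be a power of 2, <= 4 (SF multicast)"
    assert is_pow2(cluster_n) and cluster_n <= 4, \
        "cluster N must be a power of 2, <= 4 (SF multicast)"
    assert cluster_m * cluster_n <= 16, "total cluster size must be <= 16"
    assert contiguous_byte_stride % 16 == 0, \
        "contiguous dimension must be 16-byte aligned"

check_grouped_gemm_constraints(128, 256, 2, 2, 512)  # a valid configuration
```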
---

## GroupedBlockScaledMatmulKernel
`struct GroupedBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], max_groups: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True]`

Grouped block-scaled matmul kernel with dynamic tensormap updates.

This kernel extends BlackwellBlockScaledMatmulKernel to support grouped GEMM:

* Uses GroupedTileScheduler for linear tile iteration across groups
* Uses GroupedTensormapManager for per-block updatable TMA descriptors
* Updates tensormaps when transitioning between groups

Architecture (aligned with NVIDIA CuTe DSL grouped\_blockscaled\_gemm.py):

* TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions
* MMA warp: Waits for tensormap init, consumes tiles, performs block-scaled MMA
* Epilogue warps: Initialize C tensormap, handle C group transitions

## Implemented traits

[`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible)

## `comptime` members

### `__del__is_trivial`

`comptime __del__is_trivial = True`

### `a_expected_bytes`

`comptime a_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_smem_layout.size() * size_of[a_type]())`

### `a_smem_layout`

`comptime a_smem_layout = tile_layout_k_major[a_type,
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.a_swizzle]()` ### `a_tma_load_size` `comptime a_tma_load_size = a_desc_layout.size()` ### `a_tma_rows` `comptime a_tma_rows = a_desc_layout.shape[1].value()` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS)` ### `accum_pipeline_producer_arv_count` `comptime accum_pipeline_producer_arv_count = 1` ### `accum_type` `comptime accum_type = DType.float32` ### `b_expected_bytes` `comptime b_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, 
cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_smem_layout.size() * size_of[b_type]())` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.b_swizzle]()` ### `b_tma_load_size` `comptime b_tma_load_size = b_desc_layout.size()` ### `b_tma_rows` `comptime b_tma_rows = b_desc_layout.shape[1].value()` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `c_smem_layout` `comptime c_smem_layout = 
Layout.row_major(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputN)` ### `clc_consumer_arv_count` `comptime clc_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group)` ### `clc_producer_arv_count` `comptime clc_producer_arv_count = 1` ### `clc_throttle_consumer_arv_count` `comptime clc_throttle_consumer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS` ### `clc_throttle_producer_arv_count` `comptime clc_throttle_producer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS` ### `CLUSTER_M` `comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)` ### `CLUSTER_N` `comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_SIZE` `comptime CLUSTER_SIZE = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)` ### `Context` `comptime Context = KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, 
sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N]` ### `cta_group` `comptime cta_group = config.cta_group` ### `EPILOGUE_THREADS` `comptime EPILOGUE_THREADS = (4 * WARP_SIZE)` ### `EpilogueCtx` `comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS]` ### `GroupPtrLayout` `comptime GroupPtrLayout = Layout.row_major(max_groups, 1)` ### `input_expected_bytes` `comptime input_expected_bytes = ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, 
sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_expected_bytes + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_expected_bytes)) * config)` ### `InputTilePipelineType` `comptime InputTilePipelineType = InputTilePipeline[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, 
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size]` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MMA_THREADS` `comptime MMA_THREADS = WARP_SIZE` ### `MmaCtx` `comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, 
cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS]` ### `MmaEpilogueSync` `comptime MmaEpilogueSync = WarpGroupBarrier[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS), 1]` ### `MmaOp` `comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages_2sm` `comptime num_clc_pipeline_stages_2sm = 2` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_output_warps` `comptime num_output_warps = 4` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `NUM_THREADS` `comptime NUM_THREADS = (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS)` ### `NUM_TMEM_COLS` `comptime NUM_TMEM_COLS = 512` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputPipeline` `comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, 
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ### `ProblemSizesLayout` `comptime ProblemSizesLayout = Layout.row_major(max_groups, 4)` ### `SCHEDULER_THREADS` `comptime SCHEDULER_THREADS = WARP_SIZE` ### `SchedulerType` `comptime SchedulerType = GroupedTileScheduler[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, max_groups]` ### `SF_K_GROUP_SIZE` `comptime SF_K_GROUP_SIZE = (4 * config)` ### `sfa_expected_bytes` `comptime sfa_expected_bytes = 
(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_smem_layout.size() * size_of[sfa_dtype]())` ### `SFA_NUM_COLS` `comptime SFA_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // 32))` ### `sfa_smem_layout` `comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `sfb_expected_bytes` `comptime sfb_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_smem_layout.size() * size_of[sfb_dtype]())` ### `SFB_NUM_COLS` `comptime SFB_NUM_COLS = (config * 
(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // 32))` ### `sfb_smem_layout` `comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SmemType` `comptime SmemType = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]` ### `stage_stride_cols` `comptime stage_stride_cols = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N` ### `TENSORMAP_AB_INIT_BARRIER_ID` `comptime TENSORMAP_AB_INIT_BARRIER_ID = 3` ### `TENSORMAP_AB_INIT_THREADS` `comptime TENSORMAP_AB_INIT_THREADS = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, 
cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS)` ### `TensormapAbInitBarrier` `comptime TensormapAbInitBarrier = WarpGroupBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TENSORMAP_AB_INIT_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TENSORMAP_AB_INIT_BARRIER_ID]` ### `TensormapManagerType` `comptime TensormapManagerType = GroupedTensormapManager` ### `TilePayload` `comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue].SmemType.BK, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM1, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue].SmemType.SFB_DIM1, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_pipeline_stages]` ### `TileWriterType` `comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputN, config.num_output_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, 4, elementwise_compute_lambda_fn, register_based_epilogue, True]` 
### `TMA_LOAD_THREADS` `comptime TMA_LOAD_THREADS = WARP_SIZE` ### `TMATensorTileArrayA` `comptime TMATensorTileArrayA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, a_layout, a_desc_layout]` ### `TMATensorTileArrayB` `comptime TMATensorTileArrayB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, b_layout, b_desc_layout]` ### `TMATensorTileArrayC` `comptime TMATensorTileArrayC = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, c_layout, c_desc_layout]` ### `TMATensorTileArraySFA` `comptime TMATensorTileArraySFA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, sfa_layout, sfa_desc_layout]` ### `TMATensorTileArraySFB` `comptime TMATensorTileArraySFB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, 
b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, sfb_layout, sfb_desc_layout]` ### `Tmem` `comptime Tmem = TmemAllocation[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ### `TmemDealloc` `comptime TmemDealloc = TmemDeallocBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ### `TmemRegion` `comptime TmemRegion = BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, 
cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]` ## Methods ### `validate_config` `static validate_config()` Compile-time validation of kernel configuration. 
### `run` `static run(a_tma_template: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_template: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_template: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_template: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_template: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, a_layout, a_desc_layout], device_tma_b: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, b_layout, b_desc_layout], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, sfa_layout, sfa_desc_layout], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, sfb_layout, sfb_desc_layout], device_tma_c: 
TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, c_layout, c_desc_layout], group_a_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_b_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_c_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_sfa_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_sfb_ptrs: LayoutTensor[DType.uint64, 
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], problem_sizes: LayoutTensor[DType.int32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ProblemSizesLayout, MutAnyOrigin], num_groups: Int)` Grouped block-scaled GEMM kernel entry point. This kernel processes multiple GEMM problems (groups) with dynamic tensormap updates at group boundaries. ### `load_input_tiles` `static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], tiles: InputProducerStage[tiles_origin, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, 
UInt, UInt], work_tile_coord: Tuple[UInt, UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)` Load A, B, SFA, SFB tiles using TMA with InputProducerStage. ### `mma` `static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, 
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)` Execute MMA operations using InputConsumerStage. 
### `epilogue` `static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], stage: OutputStage[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32 = 1)` Execute epilogue to store accumulated results. 
### `run_2sm` `static run_2sm(a_tma_template: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_template: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_template: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_template: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_template: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, a_layout, a_desc_layout], device_tma_b: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, b_layout, b_desc_layout], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, sfa_layout, sfa_desc_layout], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, sfb_layout, sfb_desc_layout], device_tma_c: 
TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, c_layout, c_desc_layout], group_a_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_b_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_c_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_sfa_ptrs: LayoutTensor[DType.uint64, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], group_sfb_ptrs: LayoutTensor[DType.uint64, 
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].GroupPtrLayout, MutAnyOrigin], problem_sizes: LayoutTensor[DType.int32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ProblemSizesLayout, MutAnyOrigin], num_groups: Int)` Grouped block-scaled GEMM kernel with 2SM (cta\_group=2) support. This entry point uses CLC-based work distribution for proper 2SM synchronization between CTAs in a cluster. Both CTAs cooperate on each tile, with one CTA doing MMA work and both doing TMA loads. Architecture matches the working block\_scaled\_matmul\_kernel: * Scheduler warp: Produces work items via CLC barriers * TMA warp: Loads tiles with tensormap updates on group change * MMA warp: Waits on CLC, executes MMA (elected CTA only) * Epilogue warps: Stores results with tensormap updates
--- ## GroupedTensormapManager
`@register_passable(trivial)` `struct GroupedTensormapManager` Manages tensormap SMEM state and updates for grouped GEMM. Handles the 4-step CuTe DSL update pattern: 1. tensormap\_fence\_acquire() - Acquire fence on block's GMEM tensormap 2. replace\_tensormap\_global\_address\_in\_shared\_mem() - Update SMEM descriptor 3. tensormap\_cp\_fence\_release() - Copy SMEM -> block's GMEM tensormap 4. syncwarp() - Sync before using updated tensormap TMA descriptor arrays are passed by reference (as UnsafePointer from TMATensorTileArray\[blk]) to methods rather than stored by value. This ensures PTX tensormap operations receive valid GMEM addresses with correct address space semantics. The manager stores only SMEM descriptor pointers, which are shared across all warps within a CTA. ## Fields * ​smem (`GroupedTensormapSmem`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `init_ab_tensormaps` `init_ab_tensormaps[a_dtype: DType, a_layout: Layout, a_desc: Layout, b_dtype: DType, b_layout: Layout, b_desc: Layout, sfa_dtype: DType, sfa_layout: Layout, sfa_desc: Layout, sfb_dtype: DType, sfb_layout: Layout, sfb_desc: Layout](self, template_a: TMATensorTile[a_dtype, a_layout, a_desc], template_b: TMATensorTile[b_dtype, b_layout, b_desc], template_sfa: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc], template_sfb: TMATensorTile[sfb_dtype, sfb_layout, 
sfb_desc])` Initialize A/B/SFA/SFB tensormaps in SMEM from grid-constant templates. Called by MMA warp (lane 0). Copies template descriptors to SMEM. Templates must be kernel parameters with nvvm.grid\_constant metadata. ### `init_c_tensormap` `init_c_tensormap[c_dtype: DType, c_layout: Layout, c_desc: Layout](self, template_c: TMATensorTile[c_dtype, c_layout, c_desc])` Initialize C tensormap in SMEM from grid-constant template. Called by epilogue warp (lane 0). Copies template descriptor to SMEM. ### `update_ab_for_group` `update_ab_for_group[a_dtype: DType, a_layout: Layout, a_desc: Layout, b_dtype: DType, b_layout: Layout, b_desc: Layout, sfa_dtype: DType, sfa_layout: Layout, sfa_desc: Layout, sfb_dtype: DType, sfb_layout: Layout, sfb_desc: Layout, max_groups: Int](self, group_idx: UInt32, group_a_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_b_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_sfa_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_sfb_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], tma_a: UnsafePointer[TMATensorTile[a_dtype, a_layout, a_desc], MutAnyOrigin], tma_b: UnsafePointer[TMATensorTile[b_dtype, b_layout, b_desc], MutAnyOrigin], tma_sfa: UnsafePointer[TMATensorTile[sfa_dtype, sfa_layout, sfa_desc], MutAnyOrigin], tma_sfb: UnsafePointer[TMATensorTile[sfb_dtype, sfb_layout, sfb_desc], MutAnyOrigin])` Update A/B/SFA/SFB tensormaps for the specified group. Called when group\_changed=True in TMA load warp. TMA pointers must be from TMATensorTileArray\[block\_idx.x] (GMEM). 
### `update_c_for_group` `update_c_for_group[c_dtype: DType, c_layout: Layout, c_desc: Layout, max_groups: Int](self, group_idx: UInt32, group_c_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], tma_c: UnsafePointer[TMATensorTile[c_dtype, c_layout, c_desc], MutAnyOrigin])` Update C tensormap for the specified group. Called when group\_changed=True in epilogue warp. TMA pointer must be from TMATensorTileArray\[block\_idx.x] (GMEM).
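The 4-step CuTe DSL update pattern described above is an ordering contract: the acquire fence, SMEM patch, release copy, and warp sync must happen in exactly that sequence before any thread uses the updated tensormap. The sketch below is an illustrative Python model of that ordering only; the class and recorder are hypothetical stand-ins, and the real steps are PTX-level fences executed on the GPU, not Python calls.

```python
# Illustrative model of the 4-step tensormap update ordering used by
# GroupedTensormapManager. Everything here is a hypothetical stand-in;
# the real implementation issues PTX tensormap fences on the device.

class TensormapUpdateModel:
    """Records the update steps to make the required ordering explicit."""

    def __init__(self):
        self.steps = []

    def tensormap_fence_acquire(self):
        # Step 1: acquire fence on this block's GMEM tensormap.
        self.steps.append("acquire")

    def replace_global_address_in_smem(self, new_ptr):
        # Step 2: patch the global address inside the SMEM descriptor copy.
        self.steps.append(f"patch:{new_ptr:#x}")

    def tensormap_cp_fence_release(self):
        # Step 3: copy the SMEM descriptor back to GMEM with a release fence.
        self.steps.append("release")

    def syncwarp(self):
        # Step 4: warp-level sync before any thread uses the updated map.
        self.steps.append("sync")

    def update_for_group(self, group_ptr):
        # The full sequence, in the only legal order.
        self.tensormap_fence_acquire()
        self.replace_global_address_in_smem(group_ptr)
        self.tensormap_cp_fence_release()
        self.syncwarp()
        return self.steps

m = TensormapUpdateModel()
steps = m.update_for_group(0x1000)
```

The same sequence is run by the TMA warp for A/B/SFA/SFB and by the epilogue warp for C, each against its own SMEM descriptor slot.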
--- ## GroupedTensormapSmem
`@register_passable(trivial)` `struct GroupedTensormapSmem` Shared memory pointers for tensormap descriptors. Points to 5 TMA descriptors (128 bytes each) in SMEM for dynamic updates: * A, B, SFA, SFB for input loading * C for output storing These pointers should come from the main SMEM struct (GroupedBlockScaledSmem) to ensure all warps access the same SMEM locations. ## Fields * ​desc\_a (`UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED]`): * ​desc\_b (`UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED]`): * ​desc\_sfa (`UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED]`): * ​desc\_sfb (`UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED]`): * ​desc\_c (`UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `from_smem` `static from_smem(ptr_a: UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED], ptr_b: UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED], ptr_sfa: UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED], ptr_sfb: UnsafePointer[TMADescriptor, MutAnyOrigin, address_space=AddressSpace.SHARED], ptr_c: UnsafePointer[TMADescriptor, MutAnyOrigin, 
address_space=AddressSpace.SHARED]) -> Self` Create tensormap pointers from explicit SMEM pointers. **Args:** * ​ptr\_a ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to A tensormap in SMEM. * ​ptr\_b ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to B tensormap in SMEM. * ​ptr\_sfa ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to SFA tensormap in SMEM. * ​ptr\_sfb ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to SFB tensormap in SMEM. * ​ptr\_c ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to C tensormap in SMEM. **Returns:** `Self`: GroupedTensormapSmem with the provided pointers.
--- ## grouped_block_scaled_matmul_kernel
Grouped block-scaled SM100 matmul kernel for multiple GEMM problems. This kernel extends the block\_scaled\_matmul\_kernel to support grouped GEMM with variable problem sizes per group. It uses: 1. GroupedTileScheduler: For linear tile iteration across groups 2. TMATensorTileArray: For per-block updatable TMA descriptors 3. Dynamic tensormap updates: When transitioning between groups Architecture (aligned with NVIDIA CuTe DSL grouped\_blockscaled\_gemm.py): * TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions * MMA warp: Consumes input tiles, performs block-scaled MMA * Epilogue warps: Initializes C tensormap, handles C group transitions * Named barrier synchronization between warps for tensormap init Key differences from block\_scaled\_matmul\_kernel.mojo: 1. TMA descriptors are per-block (TMATensorTileArray) not grid constants 2. SMEM tensormap buffers for dynamic updates (5 x 128 bytes) 3. GroupedWorkInfo provides group\_idx, k\_tile\_count, group\_changed 4. When group\_changed=True, tensormaps are updated before loading tiles 5. K-loop uses per-group k\_tile\_count instead of global K dimension ## `comptime` values ### `NUM_TENSORMAPS` `comptime NUM_TENSORMAPS = 5` ### `TMA_DESCRIPTOR_SIZE` `comptime TMA_DESCRIPTOR_SIZE = 128` ## Structs * [​`GroupedBlockScaledMatmulKernel`](./GroupedBlockScaledMatmulKernel): Grouped block-scaled matmul kernel with dynamic tensormap updates. * [​`GroupedTensormapManager`](./GroupedTensormapManager): Manages tensormap SMEM state and updates for grouped GEMM. * [​`GroupedTensormapSmem`](./GroupedTensormapSmem): Shared memory pointers for tensormap descriptors. ## Functions * [​`is_valid_dtypes_and_scale_factor_vec_size`](./is_valid_dtypes_and_scale_factor_vec_size): Check if dtypes and sf\_vec\_size are valid combinations. * [​`is_valid_mma_tiler_and_cluster_shape`](./is_valid_mma_tiler_and_cluster_shape): Check if MMA tiler and cluster shape are valid.
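The linear tile iteration and `group_changed` signal described above can be sketched as a small scheduling model. This is an illustrative Python sketch, not the GroupedTileScheduler implementation: the `GroupedWorkInfo` dataclass and the tile sizes are hypothetical stand-ins chosen to show when tensormap updates would be triggered.

```python
# Illustrative model of grouped tile scheduling: tiles from all groups
# are iterated linearly, and group_changed marks the first tile of each
# new group, where tensormaps must be updated before loading.
from dataclasses import dataclass

@dataclass
class GroupedWorkInfo:
    group_idx: int
    tile_idx: int        # tile index within the group
    k_tile_count: int    # per-group K extent for the inner loop
    group_changed: bool  # True when this tile starts a new group

def linear_tile_schedule(problem_sizes, tile_m=128, tile_n=128, tile_k=64):
    """Yield work items across all groups in linear order."""
    prev_group = -1
    for group_idx, (m, n, k) in enumerate(problem_sizes):
        tiles_m = -(-m // tile_m)  # ceiling division
        tiles_n = -(-n // tile_n)
        k_tiles = -(-k // tile_k)
        for t in range(tiles_m * tiles_n):
            yield GroupedWorkInfo(group_idx, t, k_tiles,
                                  group_idx != prev_group)
            prev_group = group_idx

# Two groups with different problem sizes (M, N, K).
work = list(linear_tile_schedule([(256, 128, 128), (128, 128, 64)]))
```

Note how the K-loop bound (`k_tile_count`) is per-group rather than a single global K dimension, which is one of the key differences from the non-grouped kernel listed above.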
--- ## is_valid_dtypes_and_scale_factor_vec_size
`is_valid_dtypes_and_scale_factor_vec_size(ab_dtype: DType, sf_dtype: DType, sf_vec_size: Int, c_dtype: DType) -> Bool` Check if dtypes and sf\_vec\_size are valid combinations. Valid combinations (from NVIDIA CuTe DSL grouped\_blockscaled\_gemm.py): * MXF8: Float8E5M2/Float8E4M3FN + Float8E8M0FNU + sf\_vec\_size=32 * MXF4: Float4E2M1FN + Float8E8M0FNU + sf\_vec\_size=32 * NVF4: Float4E2M1FN + Float8E8M0FNU/Float8E4M3FN + sf\_vec\_size=16 **Args:** * ​ab\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of A and B matrices. * ​sf\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of scale factors. * ​sf\_vec\_size ([`Int`](/mojo/std/builtin/int/Int)): The vector size of scale factors (16 or 32). * ​c\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data type of the output matrix. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the combination is valid.
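As a cross-check of the combination table above, the same rules can be written as a small Python predicate. This is an illustrative sketch using dtype names as strings; the real function takes `DType` values and also receives `c_dtype`, whose checks are omitted here:

```python
def is_valid_dtypes_and_sf_vec_size(ab_dtype: str, sf_dtype: str,
                                    sf_vec_size: int) -> bool:
    """Mirror of the documented MXF8 / MXF4 / NVF4 combinations."""
    # MXF8: FP8 inputs, E8M0 scale factors, vector size 32
    if ab_dtype in ("float8_e5m2", "float8_e4m3fn"):
        return sf_dtype == "float8_e8m0fnu" and sf_vec_size == 32
    if ab_dtype == "float4_e2m1fn":
        # MXF4: E8M0 scale factors with vector size 32
        # NVF4: E8M0 or E4M3 scale factors with vector size 16
        if sf_dtype == "float8_e8m0fnu":
            return sf_vec_size in (16, 32)
        if sf_dtype == "float8_e4m3fn":
            return sf_vec_size == 16
    return False
```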
--- ## is_valid_mma_tiler_and_cluster_shape
`is_valid_mma_tiler_and_cluster_shape(mma_tiler_m: Int, mma_tiler_n: Int, cluster_m: Int, cluster_n: Int) -> Bool` Check if MMA tiler and cluster shape are valid. Constraints (from NVIDIA CuTe DSL): * MMA tiler M: 128 or 256 * MMA tiler N: 128 or 256 * Cluster M must be multiple of 2 if MMA tiler M is 256 * Cluster M/N: Power of 2, <=4 per axis (for SF multicast) * Total cluster size: <=16 **Args:** * ​mma\_tiler\_m ([`Int`](/mojo/std/builtin/int/Int)): M dimension of the MMA tiler. * ​mma\_tiler\_n ([`Int`](/mojo/std/builtin/int/Int)): N dimension of the MMA tiler. * ​cluster\_m ([`Int`](/mojo/std/builtin/int/Int)): Cluster shape along M. * ​cluster\_n ([`Int`](/mojo/std/builtin/int/Int)): Cluster shape along N. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the combination is valid.
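The constraints above compose straightforwardly; here is an illustrative Python sketch of the same checks (hypothetical names, not the Mojo implementation):

```python
def is_valid_mma_tiler_and_cluster_shape(mma_m: int, mma_n: int,
                                         cluster_m: int, cluster_n: int) -> bool:
    def is_pow2_le4(x: int) -> bool:
        # Power of 2 and at most 4 per axis (required for SF multicast).
        return x in (1, 2, 4)

    if mma_m not in (128, 256) or mma_n not in (128, 256):
        return False
    if mma_m == 256 and cluster_m % 2 != 0:
        return False  # MMA tiler M of 256 needs an even cluster along M
    if not (is_pow2_le4(cluster_m) and is_pow2_le4(cluster_n)):
        return False
    return cluster_m * cluster_n <= 16
```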
--- ## GroupedBlockScaledSmem
`struct GroupedBlockScaledSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]` SMEM struct for grouped block-scaled GEMM. Extends standard BlockScaledSmem with: * 5 TMA descriptor slots for dynamic tensormap updates (A, B, SFA, SFB, C) * Each descriptor is 128 bytes with 128-byte alignment Layout in SMEM: 1. Tensormap descriptors (5 x 128 bytes = 640 bytes) 2. A tiles 3. B tiles 4. C tiles 5. SFA tiles 6. SFB tiles 7. Pipeline barriers 8. CLC barriers 9. TMEM state ## Fields * ​tiles (`GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles`): * ​pipelines (`GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Pipelines`): * ​tensormap\_a (`TMADescriptor`): * ​tensormap\_b (`TMADescriptor`): * ​tensormap\_sfa (`TMADescriptor`): * ​tensormap\_sfb (`TMADescriptor`): * ​tensormap\_c (`TMADescriptor`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.a_smem_layout` ### `ATileArray` `comptime ATileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.ATileArray` ### `b_smem_layout` `comptime b_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.b_smem_layout` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, 
DType.int64, Int](1)` ### `BTileArray` `comptime BTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.BTileArray` ### `c_smem_layout` `comptime c_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.c_smem_layout` ### `CTileArray` `comptime CTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.CTileArray` ### `Layouts` `comptime Layouts = SmemLayouts[a_type, b_type, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = 
config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `Pipelines` `comptime Pipelines = SmemPipelineBundle[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages]]` ### `SF_BK` `comptime SF_BK = sf_bk[config]()` ### `SF_K_GROUP_SIZE` `comptime SF_K_GROUP_SIZE = sf_k_group_size[config]()` ### `SFA_DIM0` `comptime SFA_DIM0 = sfa_dim0[config]()` ### `SFA_DIM1` `comptime SFA_DIM1 = sfa_dim1[config]()` ### `sfa_smem_layout` `comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SFATileArray` `comptime SFATileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFATileArray` ### `SFB_DIM0` `comptime SFB_DIM0 = sfb_dim0[config]()` ### `SFB_DIM1` `comptime SFB_DIM1 = sfb_dim1[config]()` ### `sfb_smem_layout` `comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SFBTileArray` `comptime SFBTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFBTileArray` ### `Tiles` `comptime Tiles = BlockScaledTileStorage[a_type, b_type, c_type, sfa_dtype, sfb_dtype, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, GroupedBlockScaledSmem[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray` Get A tile array accessor (TileTensor-based). **Returns:** `GroupedBlockScaledSmem` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray` Get B tile array accessor (TileTensor-based). **Returns:** `GroupedBlockScaledSmem` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray` Get C tile array accessor (TileTensor-based). **Returns:** `GroupedBlockScaledSmem` ### `sfa_tiles` `sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray` Get SFA tile array accessor (TileTensor-based). **Returns:** `GroupedBlockScaledSmem` ### `sfb_tiles` `sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray` Get SFB tile array accessor (TileTensor-based). **Returns:** `GroupedBlockScaledSmem` ### `tensormap_storage_size` `static tensormap_storage_size() -> Int` Size of tensormap storage in bytes (5 x 128 = 640 bytes). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Total tile storage size (A+B+SFA+SFB+C) in elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## grouped_block_scaled_smem
Shared memory layout for grouped block-scaled SM100 matmul. Extends BlockScaledSmem with tensormap descriptor storage for dynamic updates. Used by GroupedBlockScaledMatmulKernel for grouped GEMM with variable problem sizes. Additional SMEM allocations: * 5 TMA descriptors (A, B, SFA, SFB, C) at 128 bytes each = 640 bytes total * Aligned to 128 bytes for TMA descriptor requirements ## `comptime` values ### `NUM_GROUPED_TENSORMAPS` `comptime NUM_GROUPED_TENSORMAPS = 5` ### `TMA_DESCRIPTOR_BYTES` `comptime TMA_DESCRIPTOR_BYTES = 128` ## Structs * [​`GroupedBlockScaledSmem`](./GroupedBlockScaledSmem): SMEM struct for grouped block-scaled GEMM.
--- ## GroupedAdvanceContext
`@register_passable(trivial)` `struct GroupedAdvanceContext[work_origin: MutOrigin, idx_origin: MutOrigin]` Context manager that returns current work and advances on exit. This follows the same pattern as the working kernel's WaitAndAdvanceContext: * Pre-compute next work during construction * `__enter__` returns current work for processing * `__exit__` assigns pre-computed next work and updates linear index Usage: with work\_iter.next() as current: \# Process current work \# After: work\_iter.work\_info updated to next work ## Fields * ​work\_info\_ptr (`Pointer[GroupedWorkInfo, work_origin]`): * ​linear\_idx\_ptr (`Pointer[UInt32, idx_origin]`): * ​next\_work (`GroupedWorkInfo`): * ​next\_linear\_idx (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(work_info_ptr: Pointer[GroupedWorkInfo, work_origin], linear_idx_ptr: Pointer[UInt32, idx_origin], next_work: GroupedWorkInfo, next_linear_idx: UInt32) -> Self` ### `__enter__` `__enter__(self) -> GroupedWorkInfo` **Returns:** [`GroupedWorkInfo`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/grouped_block_scaled/grouped_tile_scheduler/GroupedWorkInfo) ### `__exit__` `__exit__(mut self)`
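The advance-after-work pattern (enter returns the current item, exit commits the pre-computed next item) maps directly onto a Python context manager. This is an illustrative analogue with hypothetical names, not the Mojo types:

```python
class AdvanceContext:
    """On __enter__, hand back the CURRENT work item; on __exit__, commit
    the pre-computed NEXT item into the iterator (advance-after-work)."""
    def __init__(self, iterator, next_work):
        self._it = iterator
        self._next = next_work
    def __enter__(self):
        return self._it.work_info
    def __exit__(self, *exc):
        self._it.work_info = self._next
        return False

class WorkIter:
    def __init__(self, items):
        self._items = list(items)
        self._idx = 0
        self.work_info = self._items[0] if self._items else None
    def has_work(self):
        return self.work_info is not None
    def next(self):
        # Pre-compute the next item now; it is committed only on __exit__,
        # so the body of the `with` block still sees the current item.
        self._idx += 1
        nxt = self._items[self._idx] if self._idx < len(self._items) else None
        return AdvanceContext(self, nxt)
```

Driven as `while it.has_work(): with it.next() as cur: ...`, each loop iteration processes one item and the iterator advances when the `with` block closes.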
--- ## GroupedCLCSchedulerIterator
`@register_passable(trivial)` `struct GroupedCLCSchedulerIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]` Scheduler warp iterator for grouped GEMM with CLC. The scheduler warp produces work items for other warps via CLC. It iterates through all tiles across all groups and signals CLC barriers. Usage: var sched\_iter = scheduler.scheduler\_iterator() while sched\_iter.has\_work(): with sched\_iter.next(): sched\_iter.signal\_and\_advance() sched\_iter.drain() ## Fields * ​work\_info (`GroupedWorkInfo`): Current work item. * ​linear\_tile\_idx (`UInt32`): Current linear tile index. * ​consumer\_state (`PipelineState[num_clc_stages]`): * ​producer\_state (`PipelineState[num_clc_stages]`): * ​throttle\_pipeline (`GroupedCLCSchedulerIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline`): * ​full\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​empty\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​clc\_response (`LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]`): * ​cumulative\_tiles (`StaticTuple[UInt32, (max_groups + 1)]`): * ​problem\_m (`StaticTuple[UInt32, max_groups]`): * ​problem\_n (`StaticTuple[UInt32, max_groups]`): * ​problem\_k (`StaticTuple[UInt32, max_groups]`): * ​num\_groups (`UInt32`): * ​total\_tiles (`UInt32`): * ​signal\_count (`UInt32`): Number of signals sent (for pipeline fill tracking). 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ThrottlePipeline` `comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]` ## Methods ### `__init__` `__init__(problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int, full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_response: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo) -> Self` Initialize scheduler iterator. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedAdvanceContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.linear_tile_idx)]` Get context manager for advance-after-work pattern. **Returns:** `GroupedAdvanceContext` ### `signal_and_advance` `signal_and_advance(mut self)` Signal CLC throttle and produce next work request. This is called inside the work loop after processing current work. It signals that we've consumed the throttle and produces the next work item for all CTAs. 
NOTE: We skip the throttle\_pipeline.consumer\_signal\_and\_step() call that the hardware CLC version uses. For software CLC simulation, the clc\_full/clc\_empty barriers provide sufficient synchronization. The throttle pattern causes a deadlock because: * Scheduler waits for TMA Load via throttle full barrier * TMA Load waits for Scheduler via throttle empty barrier * Both block on first iteration since barriers start at phase 0 ### `drain` `drain(mut self)` Drain all pending CLC requests before kernel exit. Only waits for slots that were actually signaled to avoid deadlock when workload is smaller than pipeline depth. Note: After signaling, producer\_state has stepped to the NEXT stage. We need to wait on stages 0..slots\_to\_drain-1, not from producer\_state.
--- ## GroupedCLCWaitAndAdvanceContext
`@register_passable(trivial)` `struct GroupedCLCWaitAndAdvanceContext[work_origin: MutOrigin]` Context for waiting on CLC barrier and advancing work iterator. Encapsulates CLC response barrier synchronization: * Construction: Waits for CLC response, fetches next work * `__enter__`: Returns current work\_info for processing * `__exit__`: Assigns fetched work as current Usage: with work\_iter.wait\_and\_advance() as current: \# current is the work item to process NOW process(current) \# After exit, work\_iter.work\_info is the NEXT work item ## Fields * ​work\_info\_ptr (`Pointer[GroupedWorkInfo, work_origin]`): * ​next\_work (`GroupedWorkInfo`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(work_info_ptr: Pointer[GroupedWorkInfo, work_origin], next_work: GroupedWorkInfo) -> Self` ### `__enter__` `__enter__(self) -> GroupedWorkInfo` **Returns:** [`GroupedWorkInfo`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/grouped_block_scaled/grouped_tile_scheduler/GroupedWorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## GroupedCLCWorkIterator
`@register_passable(trivial)` `struct GroupedCLCWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]` Per-warp work iterator for grouped GEMM with CLC barrier support. This iterator combines grouped GEMM features with CLC-based synchronization for 2SM support. It uses CLC barriers to ensure both CTAs in a cluster process the same tile at the same time. Key features: * Uses CLC barriers for inter-CTA synchronization (like working kernel) * Tracks group\_idx, k\_tile\_count, group\_changed (like grouped scheduler) * wait\_and\_advance() actually waits on CLC barriers Usage: var work\_iter = scheduler.clc\_work\_iterator() while work\_iter.has\_work(): with work\_iter.wait\_and\_advance() as current: if current.group\_changed: update\_tensormaps(current.group\_idx) process\_tile(current) ## Fields * ​work\_info (`GroupedWorkInfo`): Current work item. * ​consumer\_state (`PipelineState[num_clc_stages]`): CLC consumer pipeline state. * ​throttle\_pipeline (`GroupedCLCWorkIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline`): Throttle pipeline for load/scheduler sync. * ​full\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): CLC full barriers (signaled by scheduler when work is ready). * ​empty\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): CLC empty barriers (signaled by workers when done). * ​clc\_response (`LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]`): CLC response storage (contains work info). * ​cumulative\_tiles (`StaticTuple[UInt32, (max_groups + 1)]`): Cumulative tile count at the start of each group. * ​problem\_m (`StaticTuple[UInt32, max_groups]`): M dimension for each group. * ​problem\_n (`StaticTuple[UInt32, max_groups]`): N dimension for each group. * ​problem\_k (`StaticTuple[UInt32, max_groups]`): K dimension for each group. 
* ​num\_groups (`UInt32`): Number of active groups. * ​total\_tiles (`UInt32`): Total tiles across all groups. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ThrottlePipeline` `comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]` ## Methods ### `__init__` `__init__(problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int, full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_response: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo) -> Self` Initialize CLC work iterator. **Args:** * ​problem\_sizes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (num\_groups, 4) tensor with \[M, N, K, L] per group. * ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): Number of active groups. * ​full\_mbar ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): CLC full barrier pointer. * ​empty\_mbar ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): CLC empty barrier pointer. 
* ​clc\_response ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): CLC response storage pointer. * ​throttle\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Throttle pipeline barrier pointer. * ​initial\_work ([`GroupedWorkInfo`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/grouped_block_scaled/grouped_tile_scheduler/GroupedWorkInfo)): Initial work item (first tile). ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `wait_and_advance` `wait_and_advance[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedCLCWaitAndAdvanceContext[origin_of(state_origin._mlir_origin.work_info)]` Wait for next work from CLC and advance iterator. This method waits on CLC full barriers to synchronize all CTAs in the cluster before advancing to the next work item. Usage: with work\_iter.wait\_and\_advance() as current: \# Process current work item \# After exit, work\_iter points to next work **Returns:** `GroupedCLCWaitAndAdvanceContext` ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedAdvanceContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.total_tiles)]` Get context manager for advance-after-work pattern. Does NOT wait on CLC - use wait\_and\_advance() for MMA warp. **Returns:** `GroupedAdvanceContext` ### `throttle_signal` `throttle_signal(mut self, is_first_cta_in_cluster: Bool)` Signal CLC throttle if this is the first CTA in cluster. NOTE: For software CLC simulation, this is a no-op. The throttle pattern causes a deadlock because both Scheduler and TMA Load wait on each other's barriers on the first iteration. The CLC full/empty barriers provide sufficient synchronization without the throttle. **Args:** * ​is\_first\_cta\_in\_cluster ([`Bool`](/mojo/std/builtin/bool/Bool)): Only first CTA signals to avoid duplicates.
--- ## GroupedTileScheduler
`@register_passable(trivial)` `struct GroupedTileScheduler[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_stages: Int = 0, cta_group: Int = 1]` Tile scheduler for grouped block-scaled GEMM. Uses linear tile iteration to map tiles across groups. Does not use CLC (Cluster Launch Control) since work distribution is deterministic. ## Parameters * ​tile\_m ([`Int`](/mojo/std/builtin/int/Int)): M dimension of output tiles. * ​tile\_n ([`Int`](/mojo/std/builtin/int/Int)): N dimension of output tiles. * ​tile\_k ([`Int`](/mojo/std/builtin/int/Int)): K dimension of input tiles. * ​max\_groups ([`Int`](/mojo/std/builtin/int/Int)): Maximum number of groups. * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stages (0 = single wave). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): Number of CTAs cooperating per tile (1 or 2 for 2SM). ## Fields * ​num\_groups (`Int`): Number of active groups. * ​problem\_sizes (`LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin]`): Problem sizes tensor (num\_groups, 4) with \[M, N, K, L] per group. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int) -> Self` Initialize scheduler with problem sizes. 
**Args:** * ​problem\_sizes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (num\_groups, 4) tensor with \[M, N, K, L] per group. * ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): Number of active groups. ### `work_iterator` `work_iterator(self) -> GroupedWorkIterator[tile_m, tile_n, tile_k, max_groups, cta_group]` Create a per-warp work iterator. Each warp should create its own work iterator. The iterator owns work\_info and cumulative tile counts internally. For 2SM (cta\_group=2), the iterator uses cluster-based indexing. **Returns:** `GroupedWorkIterator` ### `total_tiles` `total_tiles(self) -> Int` Compute total number of tiles across all groups. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
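The `total_tiles` computation is just a sum of per-group output-tile grids, with M and N each ceil-divided by the tile shape. A minimal Python sketch (hypothetical names; rows follow the documented \[M, N, K, L] layout):

```python
def total_tiles(problem_sizes, tile_m, tile_n):
    """Sum of per-group output-tile grid sizes (ceil-divided M and N)."""
    return sum(-(-m // tile_m) * -(-n // tile_n)
               for m, n, _k, _l in problem_sizes)
```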
--- ## GroupedWorkInfo
`@register_passable(trivial)` `struct GroupedWorkInfo` Work info for grouped GEMM with group-specific metadata. Extends the base WorkInfo with: * group\_idx: Current group index * k\_tile\_count: Number of K tiles for this group * group\_changed: True if group changed since last tile (triggers tensormap update) ## Fields * ​m (`UInt32`): M-coordinate of tile within current group. * ​n (`UInt32`): N-coordinate of tile within current group. * ​k\_start (`UInt32`): Starting K index (always 0 for grouped GEMM). * ​is\_valid\_tile (`Bool`): Whether this work tile is valid (not OOB). * ​group\_idx (`UInt32`): Current group index. * ​k\_tile\_count (`UInt32`): Number of K tiles for this group. * ​group\_changed (`Bool`): True if group changed since last tile (triggers tensormap update). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Create an invalid/empty work info. ### `is_valid` `is_valid(self) -> Bool` Check if this work tile is valid. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `coord` `coord(self) -> Tuple[UInt, UInt]` Get (m, n) tile coordinates as a tuple. 
**Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## GroupedWorkIterator
`@register_passable(trivial)` `struct GroupedWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, cta_group: Int = 1]` Per-warp work iterator for grouped GEMM. This iterator traverses tiles across all groups, tracking when groups change to trigger tensormap updates. It uses linear iteration instead of CLC. For 2SM (cta\_group=2), both CTAs in a cluster work on the same logical tile. The cluster index (block\_idx.x // cta\_group) is used for tile assignment, and advance step is grid\_dim.x // cta\_group (number of clusters). Usage: var work\_iter = scheduler.work\_iterator() while work\_iter.has\_work(): var current = work\_iter.current() if current.group\_changed: update\_tensormaps(current.group\_idx) process\_tile(current) work\_iter.advance() ## Fields * ​work\_info (`GroupedWorkInfo`): Current work item. * ​linear\_tile\_idx (`UInt32`): Current linear tile index (across all groups). * ​total\_tiles (`UInt32`): Total number of tiles across all groups. * ​prev\_group\_idx (`UInt32`): Previous group index for detecting group changes. * ​cumulative\_tiles (`StaticTuple[UInt32, (max_groups + 1)]`): Cumulative tile count at the start of each group. * ​problem\_m (`StaticTuple[UInt32, max_groups]`): M dimension for each group. * ​problem\_n (`StaticTuple[UInt32, max_groups]`): N dimension for each group. * ​problem\_k (`StaticTuple[UInt32, max_groups]`): K dimension for each group. * ​num\_groups (`UInt32`): Number of active groups. 
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int, grid_size: UInt32) -> Self` Initialize work iterator with problem sizes. **Args:** * ​problem\_sizes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (num\_groups, 4) tensor with \[M, N, K, L] per group. * ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): Number of active groups. * ​grid\_size ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Number of blocks in the grid. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `current` `current(self) -> GroupedWorkInfo` Get current work item. **Returns:** [`GroupedWorkInfo`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/grouped_block_scaled/grouped_tile_scheduler/GroupedWorkInfo) ### `advance` `advance(mut self)` Advance to next tile. ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedAdvanceContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.linear_tile_idx)]` Get context manager that returns current work and advances on exit. 
Compatible with the working kernel's pattern:

```mojo
with work_iter.next() as current:
    process_tile(current)
# After: work_iter.work_info updated to next work
```

Pre-computes the next state, then on **exit** updates work\_info and linear\_idx. **Returns:** `GroupedAdvanceContext` ### `wait_and_advance` `wait_and_advance[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedAdvanceContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.linear_tile_idx)]` Same as `next()`; there is no CLC waiting for grouped GEMM. Provided for compatibility with the MMA warp pattern: since CLC is not used, this behaves identically to `next()`. **Returns:** `GroupedAdvanceContext`
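The linear traversal that `has_work()`/`current()`/`advance()` describe can be modeled in plain Python. This is a sketch of the iteration pattern only, not the Mojo implementation; the function name and the example tile counts are illustrative:

```python
# Model of the GroupedWorkIterator traversal: each block starts at its own
# linear tile index and strides by the grid size, flagging group changes.
def iterate_tiles(tiles_per_group, block_idx, grid_size):
    # Cumulative tile counts, analogous to the `cumulative_tiles` field.
    cum = [0]
    for t in tiles_per_group:
        cum.append(cum[-1] + t)
    total = cum[-1]                      # analogous to `total_tiles`

    prev_group = None                    # analogous to `prev_group_idx`
    out = []
    idx = block_idx                      # this block's starting linear index
    while idx < total:                   # has_work()
        # Find which group this linear tile index falls into.
        g = next(i for i in range(len(tiles_per_group)) if idx < cum[i + 1])
        local = idx - cum[g]             # tile index within the group
        out.append((g, local, g != prev_group))  # third item: group_changed
        prev_group = g
        idx += grid_size                 # advance()
    return out

# Three groups with 3, 2, and 4 tiles; block 0 of a 2-block grid.
work = iterate_tiles([3, 2, 4], block_idx=0, grid_size=2)
```

A `group_changed` flag of `True` is exactly the point where a real kernel would refresh its tensormaps before processing the tile.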
--- ## grouped_tile_scheduler
Grouped tile scheduler for SM100 structured block-scaled GEMM. This scheduler extends the SM100 TileScheduler to support grouped GEMM with variable problem sizes per group. It uses linear tile iteration instead of CLC (Cluster Launch Control) to map a global linear tile index to group-specific coordinates. Key features: * GroupedWorkInfo: Extends WorkInfo with group\_idx, k\_tile\_count, group\_changed * delinearize\_to\_group(): Maps linear tile index to group + local coordinates * Supports variable M, N, K per group * Compatible with dynamic tensormap updates Usage:

```mojo
var scheduler = GroupedTileScheduler[...](problem_sizes, tile_shape)
var work_iter = scheduler.work_iterator()
while work_iter.has_work():
    with work_iter.next() as current:
        if current.group_changed:
            update_tensormaps(current.group_idx)
        process_tile(current)
```

## Structs * [​`GroupedAdvanceContext`](./GroupedAdvanceContext): Context manager that returns current work and advances on exit. * [​`GroupedCLCSchedulerIterator`](./GroupedCLCSchedulerIterator): Scheduler warp iterator for grouped GEMM with CLC. * [​`GroupedCLCWaitAndAdvanceContext`](./GroupedCLCWaitAndAdvanceContext): Context for waiting on CLC barrier and advancing work iterator. * [​`GroupedCLCWorkIterator`](./GroupedCLCWorkIterator): Per-warp work iterator for grouped GEMM with CLC barrier support. * [​`GroupedTileScheduler`](./GroupedTileScheduler): Tile scheduler for grouped block-scaled GEMM. * [​`GroupedWorkInfo`](./GroupedWorkInfo): Work info for grouped GEMM with group-specific metadata. * [​`GroupedWorkIterator`](./GroupedWorkIterator): Per-warp work iterator for grouped GEMM.
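The delinearization step the module describes can be sketched in Python. This is a hedged model of the mapping only (assuming row-major tile order within each group); the Mojo `delinearize_to_group()` signature may differ:

```python
import math

# Sketch: map a global linear tile index to (group, tile_m, tile_n)
# coordinates, given per-group problem sizes and a fixed tile shape.
def delinearize_to_group(linear_idx, problem_sizes, tile_m, tile_n):
    offset = 0
    for group, (M, N, K) in enumerate(problem_sizes):
        tiles_m = math.ceil(M / tile_m)
        tiles_n = math.ceil(N / tile_n)
        tiles = tiles_m * tiles_n
        if linear_idx < offset + tiles:
            local = linear_idx - offset
            # Row-major tile order within the group (an assumption).
            return group, local // tiles_n, local % tiles_n
        offset += tiles
    return None  # past the last group: no valid work tile

# Two groups with different (M, N, K), using 128x128 tiles:
# group 0 has 2x2 = 4 tiles (indices 0-3), group 1 has 1x3 = 3 (indices 4-6).
sizes = [(256, 256, 512), (128, 384, 512)]
coords = delinearize_to_group(5, sizes, 128, 128)
```

Here linear index 5 lands one tile into group 1, which is what allows variable M, N, K per group without a CLC-driven schedule.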
--- ## grouped_block_scaled
Grouped block-scaled matmul kernel for SM100. ## Modules * [​`grouped_block_scaled_matmul`](./grouped_block_scaled_matmul/): CPU entry points for grouped block-scaled SM100 matmul. * [​`grouped_block_scaled_matmul_kernel`](./grouped_block_scaled_matmul_kernel/): Grouped block-scaled SM100 matmul kernel for multiple GEMM problems. * [​`grouped_block_scaled_smem`](./grouped_block_scaled_smem/): Shared memory layout for grouped block-scaled SM100 matmul. * [​`grouped_tile_scheduler`](./grouped_tile_scheduler/): Grouped tile scheduler for SM100 structured block-scaled GEMM.
--- ## grouped_matmul_1d1d_nvfp4
`grouped_matmul_1d1d_nvfp4[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, b_type: DType, b_layout: Layout, expert_ids_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, _sfb_layout: Layout, expert_scale_layout: Layout, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], _b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], _b_scales: LayoutTensor[sfb_dtype, _sfb_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scale_layout, MutAnyOrigin], num_active_experts: Int, ctx: 
DeviceContext)` Launch grouped 1D-1D block-scaled matmul kernel. This function sets up TMA descriptors and launches the kernel with the proper configuration for 1D-1D tensor layout. **Args:** * ​c\_device ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor (total\_tokens, N). * ​a\_device ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input A tensor (total\_tokens, K). * ​a\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert offsets (num\_active\_experts + 1). * ​a\_scale\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert scale offsets (num\_active\_experts). * ​\_b\_device ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Weight tensor B (num\_experts, N, K). * ​expert\_ids ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Active expert IDs (num\_active\_experts). * ​a\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale factors for A (5D). * ​\_b\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale factors for B (6D). * ​expert\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert output scaling (num\_experts). * ​num\_active\_experts ([`Int`](/mojo/std/builtin/int/Int)): Number of active experts. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context.
--- ## grouped_matmul_dynamic_scaled_nvfp4 (Grouped_1d1d_matmul)
`grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, scales_type: DType, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)` Performs grouped matrix multiplication with NVFP4 quantization. This is a compatibility wrapper that matches the old API from grouped\_matmul\_sm100\_1d1d.mojo. It creates the default config and calls the new structured kernel implementation. Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. Inputs A and B are NVFP4 quantized (4-bit floating point), packed as uint8 (2 values per byte), with float8\_e4m3fn scale factors. **Parameters:** * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Output tensor dtype. * ​c\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Output tensor layout. * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Input A dtype (must be uint8 for packed NVFP4). * ​a\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Input A layout. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Input B dtype (must be uint8 for packed NVFP4). 
* ​b\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Input B layout. * ​scales\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Scale factor dtype. * ​a\_scales\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A scales layout. * ​b\_scales\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): B scales layout. * ​a\_offsets\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A offsets layout. * ​a\_scale\_offsets\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): A scale offsets layout. * ​expert\_ids\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Expert IDs layout. * ​expert\_scales\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Expert scales layout. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether B is transposed (must be True). * ​target ([`StringSlice`](/mojo/std/collections/string/string_slice/StringSlice)): Target device (ignored, always runs on GPU). **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor (total\_tokens, N). * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input A tensor (total\_tokens, K). * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Weight tensor B (num\_experts, N, K). * ​a\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale factors for A. * ​b\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale factors for B. * ​a\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert token offsets. * ​a\_scale\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert scale offsets. * ​expert\_ids ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Active expert IDs. * ​expert\_scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Per-expert output scaling. 
* ​num\_active\_experts ([`Int`](/mojo/std/builtin/int/Int)): Number of active experts. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context.
--- ## grouped_1d1d_matmul
CPU entry point for grouped 1D-1D block-scaled SM100 matmul. This module provides the public API for launching the grouped 1D-1D matmul kernel for Mixture of Experts (MoE) layers. Usage:

```mojo
grouped_matmul_1d1d_nvfp4[transpose_b=True, config=config](
    c_tensor,         # Output: (total_tokens, N)
    a_tensor,         # Input A: (total_tokens, K)
    a_offsets,        # Per-expert offsets into A
    a_scale_offsets,  # Per-expert scale offsets
    b_tensor,         # Weights B: (num_experts, N, K)
    expert_ids,       # Active expert IDs
    a_scales,         # Scale factors for A
    b_scales,         # Scale factors for B
    expert_scales,    # Per-expert output scaling
    num_active_experts,
    ctx,
)
```

## Functions * [​`grouped_matmul_1d1d_nvfp4`](./grouped_matmul_1d1d_nvfp4): Launch grouped 1D-1D block-scaled matmul kernel. * [​`grouped_matmul_dynamic_scaled_nvfp4`](./grouped_matmul_dynamic_scaled_nvfp4): Performs grouped matrix multiplication with NVFP4 quantization.
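The offset-based (1D-1D) addressing above concatenates all experts' token rows into one A tensor, with `a_offsets` holding `num_active_experts + 1` boundaries. A short Python sketch of the convention (illustrative values, not the kernel's API):

```python
# Sketch of offset-based expert addressing: expert i owns rows
# a_offsets[i] .. a_offsets[i+1] of the concatenated A tensor.
def expert_row_ranges(a_offsets):
    """Return (start, end) row ranges into A for each active expert."""
    return [(a_offsets[i], a_offsets[i + 1]) for i in range(len(a_offsets) - 1)]

# Three active experts owning 4, 0, and 5 token rows respectively.
offsets = [0, 4, 4, 9]
ranges = expert_row_ranges(offsets)
```

An empty range (here expert 1 with `(4, 4)`) is legal: an active expert may have been routed zero tokens in this batch.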
--- ## Grouped1D1DMatmulKernel
`struct Grouped1D1DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, a_scale_offsets_layout: Layout, c_device_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], static_N: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True]` Grouped 1D-1D block-scaled matmul kernel. Uses 3-warp specialization (Load, MMA, Epilogue) with grid-constant TMAs. Work distribution via GroupedWorkIterator1D1D using offset-based addressing. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_expected_bytes` `comptime a_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.a_smem_layout.size() * size_of[a_type]())` ### `a_tma_load_size` `comptime a_tma_load_size = a_desc_layout.size()` ### `a_tma_rows` `comptime a_tma_rows = a_desc_layout.shape[1].value()` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * 128)` ### `accum_pipeline_producer_arv_count` `comptime accum_pipeline_producer_arv_count = 1` ### `accum_type` `comptime accum_type = DType.float32` ### `b_expected_bytes` `comptime b_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.b_smem_layout.size() * size_of[b_type]())` ### `b_tma_load_size` `comptime b_tma_load_size = b_desc_layout.size()` ### `b_tma_rows` `comptime b_tma_rows = b_desc_layout.shape[1].value()` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_M` `comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)` ### `CLUSTER_N` `comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)` ### `CLUSTER_SIZE` `comptime CLUSTER_SIZE = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, 
elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)` ### `cta_group` `comptime cta_group = config.cta_group` ### `EpilogueCtx` `comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, 32, 128]` ### `input_expected_bytes` `comptime input_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * (((Grouped1D1DMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_expected_bytes + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_expected_bytes)) * config)` ### `InputTilePipelineType` `comptime InputTilePipelineType = InputTilePipeline[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, 
expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size]` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MmaCtx` `comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, 32, 128]` ### `MmaEpilogueSync` `comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]` ### `MmaOp` `comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, 
config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_output_warps` `comptime num_output_warps = 4` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `NUM_THREADS` `comptime NUM_THREADS = WarpRole1D1D.TOTAL_THREADS` ### `NUM_TMEM_COLS` `comptime NUM_TMEM_COLS = 512` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputPipeline` `comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, 
b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ### `sfa_expected_bytes` `comptime sfa_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.sfa_smem_layout.size() * size_of[sfa_dtype]())` ### `SFA_NUM_COLS` `comptime SFA_NUM_COLS = (config * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // 32))` ### `sfb_expected_bytes` `comptime sfb_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, 
expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.sfb_smem_layout.size() * size_of[sfb_dtype]())` ### `SFB_NUM_COLS` `comptime SFB_NUM_COLS = (config * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // 32))` ### `SmemType` `comptime SmemType = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]` ### `stage_stride_cols` `comptime stage_stride_cols = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N` ### `TilePayload` `comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, 
sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM1, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, 
c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM1, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_pipeline_stages]` ### `TileWriterType` `comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, 
sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputN, config.num_output_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, 4]` ### `Tmem` `comptime Tmem = TmemAllocation[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ### `TmemDealloc` `comptime TmemDealloc = TmemDeallocBarrier[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, 
elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ### `TmemRegion` `comptime TmemRegion = BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, 
sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]` ### `WorkIterator` `comptime WorkIterator = GroupedWorkIterator1D1D[offsets_layout, expert_ids_layout, expert_scales_layout, static_N, config.block_tile_shape, config.cluster_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]` ## Methods ### `validate_config` `static validate_config()` Compile-time validation of kernel configuration. 
### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], a_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], num_active_experts: Int, K: UInt32)` Grouped 1D-1D block-scaled GEMM kernel entry point. Uses grid-constant TMAs with offset-based addressing for 1D-1D layout. ### `load_input_tiles` `static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], tiles: InputProducerStage[tiles_origin, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, 
elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_ctx: GroupedWorkContext1D1D, a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], iter_idx: UInt32, elect_one_cta: Bool)` Load A, B, SFA, SFB tiles using TMA. ### `mma` `static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, 
transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, 
expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)` Execute MMA operations. ### `epilogue` `static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], stage: OutputStage[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, 
config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], work_ctx: GroupedWorkContext1D1D)` Execute epilogue to store accumulated results with expert\_scale.
--- ## WarpRole1D1D
`@register_passable(trivial)` `struct WarpRole1D1D` Warp role for 1D-1D kernel with 3-warp specialization. Thread layout (192 threads total), matching the original kernel: * Warps 0-3 (threads 0-127): Epilogue (4 warps) * Warp 4 (threads 128-159): TMA Load * Warp 5 (threads 160-191): MMA This layout matches the original grouped\_matmul\_sm100\_1d1d.mojo kernel, which uses WarpRole\[has\_scheduler=False]. Keeping the epilogue warps at 0-3 is important because TMAStoreCoords uses `warp_id == 0` for election. There is no scheduler warp; work distribution uses linear grid traversal. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `EPILOGUE_WARP_START` `comptime EPILOGUE_WARP_START = 0` ### `LOAD_WARP_START` `comptime LOAD_WARP_START = 128` ### `MMA_WARP_START` `comptime MMA_WARP_START = 160` ### `NUM_EPILOGUE_THREADS` `comptime NUM_EPILOGUE_THREADS = 128` ### `NUM_LOAD_THREADS` `comptime NUM_LOAD_THREADS = 32` ### `NUM_MMA_THREADS` `comptime NUM_MMA_THREADS = 32` ### `TOTAL_THREADS` `comptime TOTAL_THREADS = 192` ## Methods ### `is_epilogue` `static is_epilogue() -> Bool` Returns True if current thread is in an epilogue warp (warps 0-3). **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_load` `static is_load() -> Bool` Returns True if current thread is in the TMA load warp (warp 4). 
**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` Returns True if current thread is in the MMA warp (warp 5). **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
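The thread-to-role mapping above is plain integer arithmetic over the documented constants. A Python sketch (the constant names mirror the `comptime` members listed here, but the helper functions are illustrative, not the kernel's API):

```python
# Sketch of the 3-warp role partitioning (192 threads, 32 threads per warp).
WARP_SIZE = 32
NUM_EPILOGUE_THREADS = 128   # warps 0-3
LOAD_WARP_START = 128        # warp 4
MMA_WARP_START = 160         # warp 5
TOTAL_THREADS = 192

def warp_id(tid: int) -> int:
    return tid // WARP_SIZE

def is_epilogue(tid: int) -> bool:
    return tid < NUM_EPILOGUE_THREADS

def is_load(tid: int) -> bool:
    return LOAD_WARP_START <= tid < MMA_WARP_START

def is_mma(tid: int) -> bool:
    return MMA_WARP_START <= tid < TOTAL_THREADS
```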
--- ## grouped_1d1d_matmul_kernel
Grouped 1D-1D block-scaled SM100 matmul kernel. This kernel implements grouped GEMM for Mixture of Experts (MoE) layers using the 1D-1D tensor layout with offset-based addressing. Key characteristics: * 3-warp specialization (Load, MMA, Epilogue) - no scheduler warp * Grid-constant TMA descriptors (no runtime tensormap updates) * Offset-based addressing via a\_offsets for contiguous token buffers * Per-expert output scaling via expert\_scales tensor Architecture: * TMA warp: Loads A, B, SFA, SFB tiles using grid-constant TMAs * MMA warp: Executes block-scaled matrix multiply * Epilogue warps: Stores results with expert\_scale applied This is a port of grouped\_matmul\_sm100\_1d1d.mojo to the structured kernels architecture. ## Structs * [​`Grouped1D1DMatmulKernel`](./Grouped1D1DMatmulKernel): Grouped 1D-1D block-scaled matmul kernel. * [​`WarpRole1D1D`](./WarpRole1D1D): Warp role for 1D-1D kernel with 3-warp specialization.
--- ## Grouped1D1DSmem
`struct Grouped1D1DSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]` SMEM struct for grouped 1D-1D block-scaled GEMM. Simplified version of GroupedBlockScaledSmem for offset-based addressing. Uses 3-warp specialization (Load, MMA, Epilogue) without a scheduler warp, so CLC pipeline storage is not needed. Layout in SMEM: 1. A tiles (input pipeline stages) 2. B tiles (input pipeline stages) 3. C tiles (output stages) 4. SFA tiles (scaling factors for A) 5. SFB tiles (scaling factors for B) 6. Input pipeline barriers 7. Output pipeline barriers (accum barriers) 8. TMEM deallocation state ## Fields * ​tiles (`Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles`): * ​pipelines (`Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Pipelines`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.a_smem_layout` ### `ATileArray` `comptime ATileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.ATileArray` ### `b_smem_layout` `comptime b_smem_layout = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.b_smem_layout` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTileArray` `comptime BTileArray = Grouped1D1DSmem[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.BTileArray` ### `c_smem_layout` `comptime c_smem_layout = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.c_smem_layout` ### `CTileArray` `comptime CTileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.CTileArray` ### `Layouts` `comptime Layouts = SmemLayouts[a_type, b_type, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)` ### `num_output_stages` `comptime num_output_stages = config.num_output_stages` ### `num_pipeline_stages` `comptime num_pipeline_stages = config.num_pipeline_stages` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `Pipelines` `comptime Pipelines = SmemPipelineBundleNoClc[Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages, 
Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages]]` ### `SF_BK` `comptime SF_BK = sf_bk[config]()` ### `SF_K_GROUP_SIZE` `comptime SF_K_GROUP_SIZE = sf_k_group_size[config]()` ### `SFA_DIM0` `comptime SFA_DIM0 = sfa_dim0[config]()` ### `SFA_DIM1` `comptime SFA_DIM1 = sfa_dim1[config]()` ### `sfa_smem_layout` `comptime sfa_smem_layout = tile_sf_layout_k_major[Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SFATileArray` `comptime SFATileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFATileArray` ### `SFB_DIM0` `comptime SFB_DIM0 = sfb_dim0[config]()` ### `SFB_DIM1` `comptime SFB_DIM1 = sfb_dim1[config]()` ### `sfb_smem_layout` `comptime sfb_smem_layout = tile_sf_layout_k_major[Grouped1D1DSmem[a_type, b_type, c_type, 
sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()` ### `SFBTileArray` `comptime SFBTileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFBTileArray` ### `Tiles` `comptime Tiles = BlockScaledTileStorage[a_type, b_type, c_type, sfa_dtype, sfb_dtype, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray` Get A tile array accessor (TileTensor-based). 
**Returns:** `Grouped1D1DSmem` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray` Get B tile array accessor (TileTensor-based). **Returns:** `Grouped1D1DSmem` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray` Get C tile array accessor (TileTensor-based). **Returns:** `Grouped1D1DSmem` ### `sfa_tiles` `sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray` Get SFA tile array accessor (TileTensor-based). **Returns:** `Grouped1D1DSmem` ### `sfb_tiles` `sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray` Get SFB tile array accessor (TileTensor-based). **Returns:** `Grouped1D1DSmem` ### `ab_pipeline_size` `static ab_pipeline_size() -> Int` Total size of A+B tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `sf_pipeline_size` `static sf_pipeline_size() -> Int` Total size of SFA+SFB tiles for all pipeline stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `c_output_size` `static c_output_size() -> Int` Size of C tiles for all output stages (in elements). **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Total tile storage size (A+B+SFA+SFB+C) in elements. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
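The size accessors above can be pictured with a rough element-count sketch. This assumes each input pipeline stage holds one BM×BK A tile plus one BN×BK B tile, and each output stage holds one OutputM×OutputN C tile; the function names and formulas are illustrative assumptions, not the struct's actual implementation:

```python
# Rough SMEM tile accounting in elements (not bytes). Illustrative only:
# assumes one A tile (BM x BK) and one B tile (BN x BK) per pipeline stage,
# and one C tile (OutputM x OutputN) per output stage.
def ab_pipeline_size(bm: int, bn: int, bk: int, num_pipeline_stages: int) -> int:
    return (bm * bk + bn * bk) * num_pipeline_stages

def c_output_size(output_m: int, output_n: int, num_output_stages: int) -> int:
    return output_m * output_n * num_output_stages
```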
--- ## grouped_1d1d_smem
Shared memory layout for grouped 1D-1D block-scaled SM100 matmul. This is a simplified SMEM structure for the 1D-1D kernel variant that uses offset-based addressing instead of pointer-per-group. Key differences from the standard GroupedBlockScaledSmem: 1. No tensormap descriptors - TMAs are grid-constant (not updated per-group) 2. No CLC pipeline storage - uses 3-warp specialization (no scheduler warp) 3. Simpler barrier structure optimized for the 1D-1D workload The 1D-1D layout uses: * A tensor: Contiguous (total\_tokens, K) with a\_offsets for per-group access * B tensor: Batched (num\_experts, N, K) weights * C tensor: Contiguous (total\_tokens, N) output ## Structs * [​`Grouped1D1DSmem`](./Grouped1D1DSmem): SMEM struct for grouped 1D-1D block-scaled GEMM.
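The offset-based addressing described above can be pictured with a small prefix-sum example in pure Python (names are illustrative):

```python
# "1D-1D" addressing: every active expert's tokens are packed back-to-back in
# one contiguous A buffer, and a_offsets is the prefix sum of per-expert token
# counts. Illustrative sketch; not the kernel's API.
def rows_for_group(a_offsets: list[int], group_idx: int) -> tuple[int, int]:
    """(start, end) row range in the contiguous A buffer for active expert group_idx."""
    return a_offsets[group_idx], a_offsets[group_idx + 1]

# Example: three active experts owning 4, 2, and 5 tokens respectively.
a_offsets = [0, 4, 6, 11]
```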
--- ## GroupedWorkContext1D1D
`struct GroupedWorkContext1D1D` Context for current work tile, used with context manager pattern. Provides access to work tile info and expert scale factor. ## Fields * ​info (`GroupedWorkInfo1D1D`): * ​expert\_scale (`Float32`): * ​m\_end (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `m` `m(self) -> UInt32` M coordinate in contiguous token space. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `n` `n(self) -> UInt32` N coordinate in output space. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `group_idx` `group_idx(self) -> UInt32` Index into active experts list. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `expert_id` `expert_id(self) -> Int32` Expert ID for B tensor indexing. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32) ### `is_valid` `is_valid(self) -> Bool` Whether this tile has valid work. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## GroupedWorkInfo1D1D
`@register_passable(trivial)` `struct GroupedWorkInfo1D1D` Work tile information for 1D-1D grouped matmul. Contains the coordinates and metadata for a single work tile: * m, n: Output tile coordinates (m is in contiguous token space) * group\_idx: Index into active experts (for a\_offsets indexing) * expert\_id: The actual expert ID for B tensor lookup * is\_valid\_tile: Whether this tile contains valid work * terminate: Whether the scheduler has no more work ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​group\_idx (`UInt32`): * ​expert\_id (`Int32`): * ​is\_valid\_tile (`Bool`): * ​terminate (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `is_valid` `is_valid(self) -> Bool` Returns True if this work tile has valid work to do. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_done` `is_done(self) -> Bool` Returns True if the scheduler has no more work. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## GroupedWorkIterator1D1D
`struct GroupedWorkIterator1D1D[offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, static_N: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index(1, 1, 1), cta_group: Int = 1, swizzle: Bool = False]` Work iterator for 1D-1D grouped block-scaled matmul. Iterates through work tiles using offset-based addressing: * a\_offsets: Prefix sum of token counts per active expert * expert\_ids: Mapping from active expert index to actual expert ID * expert\_scales: Per-expert output scaling factors Usage:

```mojo
var work_iter = GroupedWorkIterator1D1D[...](
    num_active_experts, a_offsets, expert_ids, expert_scales
)
while True:
    var ctx = work_iter.next()
    if ctx.info.is_done():
        break
    if ctx.info.is_valid():
        # Process tile at (ctx.m(), ctx.n()) for expert ctx.expert_id()
        # Apply scaling with ctx.expert_scale
        pass
```

## Fields * ​num\_active\_experts (`Int`): * ​group\_offsets (`LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin]`): * ​expert\_ids (`LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin]`): * ​expert\_scales (`LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin]`): * ​current\_iter (`Int32`): * ​current\_group\_idx (`UInt32`): * ​current\_dynamic\_dim\_cumsum (`UInt32`): * ​block\_idx\_start (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `cta_group_tile_shape` `comptime cta_group_tile_shape = Index((tile_shape.__getitem__[3, DType.int64, Int](0) * cta_group), (tile_shape.__getitem__[3, DType.int64, Int](1) * cta_group))` ### `div_dynamic_block` `comptime div_dynamic_block = FastDiv[DType.uint32](GroupedWorkIterator1D1D[offsets_layout, expert_ids_layout, expert_scales_layout, static_N, tile_shape, cluster, cta_group, swizzle].cta_group_tile_shape.__getitem__[2, DType.int64, Int](0))` ### `kNum1DBlocksPerGroup` 
`comptime kNum1DBlocksPerGroup = 16` ### `num_static_dim_blocks` `comptime num_static_dim_blocks = SIMD[DType.uint32, 1](ceildiv(static_N, tile_shape.__getitem__[3, DType.int64, Int](1)))` ## Methods ### `__init__` `__init__(out self, num_active_experts: Int, group_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin])` ### `next` `next(mut self) -> GroupedWorkContext1D1D` Fetch next work tile and return context with work info and scale. **Returns:** [`GroupedWorkContext1D1D`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/grouped_block_scaled_1d1d/grouped_1d1d_tile_scheduler/GroupedWorkContext1D1D) ### `current_expert_id` `current_expert_id(self) -> Int32` Get the expert ID for the current group. **Returns:** [`Int32`](/mojo/std/builtin/simd/#int32)
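The iterator's offset-based addressing can be sketched outside Mojo. A minimal Python sketch (hypothetical helper, not part of the MAX API) of how a prefix-sum `a_offsets` array and an `expert_ids` mapping determine the set of `(m_coord, n_coord, expert_id)` work tiles:

```python
# Sketch of offset-based work-tile enumeration for a grouped 1D-1D matmul.
# Assumptions (illustrative, not the real API): a_offsets is a prefix sum of
# token counts with a leading 0, and expert_ids maps each active-expert index
# to its actual expert ID.

def enumerate_work_tiles(a_offsets, expert_ids, tile_m, n, tile_n):
    """Yield (m_coord, n_coord, expert_id) for every output tile."""
    tiles = []
    num_n_tiles = -(-n // tile_n)  # ceildiv over the static N dimension
    for g in range(len(expert_ids)):  # one group per active expert
        start, end = a_offsets[g], a_offsets[g + 1]
        num_m_tiles = -(-(end - start) // tile_m)  # ceildiv over this group's tokens
        for mt in range(num_m_tiles):
            for nt in range(num_n_tiles):
                # m_coord is global: group base offset plus tile offset within group
                tiles.append((start + mt * tile_m, nt * tile_n, expert_ids[g]))
    return tiles

# Two active experts: expert 3 owns tokens [0, 96), expert 7 owns [96, 224)
tiles = enumerate_work_tiles(a_offsets=[0, 96, 224], expert_ids=[3, 7],
                             tile_m=64, n=256, tile_n=128)
```

The real iterator additionally handles cluster tiling and block swizzling, but the group-boundary arithmetic follows this shape.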
--- ## grouped_1d1d_tile_scheduler
Work scheduler for grouped 1D-1D block-scaled SM100 matmul. Provides work iteration using offset-based addressing for the 1D-1D tensor layout. This is a port of the TileScheduler from grouped\_matmul\_tile\_scheduler.mojo to the structured kernels architecture with context manager patterns. Key characteristics: * Uses a\_offsets tensor for group boundaries (prefix sum of token counts) * Each iteration returns (m\_coord, n\_coord, expert\_id, expert\_scale) * Supports block swizzling for L2 cache efficiency * 3-warp specialization (no scheduler warp) ## Structs * [​`GroupedWorkContext1D1D`](./GroupedWorkContext1D1D): Context for current work tile, used with context manager pattern. * [​`GroupedWorkInfo1D1D`](./GroupedWorkInfo1D1D): Work tile information for 1D-1D grouped matmul. * [​`GroupedWorkIterator1D1D`](./GroupedWorkIterator1D1D): Work iterator for 1D-1D grouped block-scaled matmul.
--- ## grouped_block_scaled_1d1d
Grouped block-scaled matmul with 1D-1D tensor layout for SM100. This module provides a structured kernel implementation for grouped GEMM operations in Mixture of Experts (MoE) layers, using contiguous token buffers with offset-based addressing (the "1D-1D" layout). Key characteristics: * A tensor: Contiguous (total\_tokens, K) with a\_offsets for per-group access * B tensor: Batched (num\_experts, N, K) weights * C tensor: Contiguous (total\_tokens, N) output * Per-expert output scaling via expert\_scales tensor This is a port of `max/kernels/src/linalg/grouped_matmul_sm100_1d1d.mojo` to the structured kernels architecture. See PORTING\_PLAN.md for implementation details. ## Modules * [​`grouped_1d1d_matmul`](./grouped_1d1d_matmul/): CPU entrypoint for grouped 1D-1D block-scaled SM100 matmul. * [​`grouped_1d1d_matmul_kernel`](./grouped_1d1d_matmul_kernel/): Grouped 1D-1D block-scaled SM100 matmul kernel. * [​`grouped_1d1d_smem`](./grouped_1d1d_smem/): Shared memory layout for grouped 1D-1D block-scaled SM100 matmul. * [​`grouped_1d1d_tile_scheduler`](./grouped_1d1d_tile_scheduler/): Work scheduler for grouped 1D-1D block-scaled SM100 matmul.
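The 1D-1D layout described above can be stated as a plain NumPy reference. This is a sketch of the layout semantics only (float32, no block scaling, `transpose_b`-style `(N, K)` weights); `grouped_matmul_1d1d` is a hypothetical name, not this module's API:

```python
import numpy as np

# Reference semantics of the 1D-1D grouped matmul layout:
#   A: contiguous (total_tokens, K), addressed per group via a_offsets
#   B: batched (num_experts, N, K) weights
#   C: contiguous (total_tokens, N) output, scaled per expert

def grouped_matmul_1d1d(a, b, a_offsets, expert_ids, expert_scales):
    total_tokens = a.shape[0]
    n = b.shape[1]
    c = np.zeros((total_tokens, n), dtype=a.dtype)
    for g, expert in enumerate(expert_ids):
        start, end = a_offsets[g], a_offsets[g + 1]
        # Weights are stored (N, K), so multiply by the transpose
        c[start:end] = expert_scales[g] * (a[start:end] @ b[expert].T)
    return c

rng = np.random.default_rng(0)
a = rng.standard_normal((6, 4)).astype(np.float32)     # 6 tokens, K = 4
b = rng.standard_normal((3, 5, 4)).astype(np.float32)  # 3 experts, N = 5
c = grouped_matmul_1d1d(a, b, a_offsets=[0, 2, 6],
                        expert_ids=[1, 2], expert_scales=[1.0, 0.5])
```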
--- ## sm100_structured
SM100 Structured Kernels - Blackwell matmul implementation. ## Packages * [​`block_scaled`](./block_scaled/): Block-scaled matmul kernel for SM100. * [​`blockwise_fp8`](./blockwise_fp8/): Blockwise FP8 matmul kernel for SM100. * [​`blockwise_fp8_1d2d`](./blockwise_fp8_1d2d/): Blockwise FP8 1D2D grouped matmul kernel for SM100. * [​`default`](./default/): Default SM100 matmul kernel - Standard FP8/BF16 warp-specialized implementation. * [​`grouped_block_scaled`](./grouped_block_scaled/): Grouped block-scaled matmul kernel for SM100. * [​`grouped_block_scaled_1d1d`](./grouped_block_scaled_1d1d/): Grouped block-scaled matmul with 1D-1D tensor layout for SM100. * [​`structured_kernels`](./structured_kernels/): Shared library components for SM100 structured kernels.
--- ## SmemBarriers
`struct SmemBarriers[num_group_pipeline_stages: Int, num_accum_pipeline_stages: Int, num_clc_pipeline_stages: Int]` Composable barrier storage for SM100 matmul SMEM structs. This struct consolidates all barrier-related storage and accessors, enabling code reuse across MatmulSmem, BlockScaledSmem, and BlockwiseFP8Smem through composition. Usage: Compose this struct into SMEM structs and delegate accessors:

```mojo
struct MySmem[...]:
    var barriers: SmemBarriers[num_group, num_accum, num_clc]

    fn input_barriers(ref[AddressSpace.SHARED] self):
        return self.barriers.input_barriers()
```

## Parameters * ​num\_group\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of K-group pipeline stages. * ​num\_accum\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. * ​num\_clc\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of CLC pipeline stages. ## Fields * ​input\_barriers\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].InputBarriers.Storage`): * ​accum\_barriers\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].AccumBarriers.Storage`): * ​clc\_full\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcBarriers.Storage`): * ​clc\_empty\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcBarriers.Storage`): * ​clc\_throttle\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcThrottleBarriers.Storage`): * ​clc\_response\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcResponse.Storage`): * ​tmem\_dealloc\_storage (`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].TmemDealloc.Storage`): * ​tmem\_addr\_storage 
(`SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].TmemAddr.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `AccumBarriers` `comptime AccumBarriers = SMemArray[SharedMemBarrier, (num_accum_pipeline_stages * 2)]` ### `ClcBarriers` `comptime ClcBarriers = SMemArray[SharedMemBarrier, num_clc_pipeline_stages]` ### `ClcResponse` `comptime ClcResponse = SMemArray[UInt128, num_clc_pipeline_stages]` ### `ClcThrottleBarriers` `comptime ClcThrottleBarriers = SMemArray[SharedMemBarrier, (num_clc_pipeline_stages * 2)]` ### `InputBarriers` `comptime InputBarriers = SMemArray[SharedMemBarrier, (num_group_pipeline_stages * 2)]` ### `TmemAddr` `comptime TmemAddr = SMemArray[UInt32, 1]` ### `TmemDealloc` `comptime TmemDealloc = SMemArray[SharedMemBarrier, 1]` ## Methods ### `input_barriers` `input_barriers(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].InputBarriers` Returns input tile pipeline barriers (2 per group stage). **Returns:** `SmemBarriers` ### `accum_barriers` `accum_barriers(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].AccumBarriers` Returns accumulator pipeline barriers (2 per accum stage). **Returns:** `SmemBarriers` ### `clc_full` `clc_full(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcBarriers` Returns CLC full barriers (1 per CLC stage). 
**Returns:** `SmemBarriers` ### `clc_empty` `clc_empty(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcBarriers` Returns CLC empty barriers (1 per CLC stage). **Returns:** `SmemBarriers` ### `clc_throttle` `clc_throttle(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcThrottleBarriers` Returns CLC throttle barriers (2 per CLC stage). **Returns:** `SmemBarriers` ### `clc_response` `clc_response(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].ClcResponse` Returns CLC response storage (1 UInt128 per CLC stage). **Returns:** `SmemBarriers` ### `tmem_dealloc` `tmem_dealloc(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].TmemDealloc` Returns TMEM deallocation barrier. **Returns:** `SmemBarriers` ### `tmem_addr` `tmem_addr(ref[AddressSpace._value._mlir_value] self) -> SmemBarriers[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages].TmemAddr` Returns TMEM address storage. **Returns:** `SmemBarriers`
--- ## TmemDeallocBarrier
`@register_passable(trivial)` `struct TmemDeallocBarrier[cta_group: Int]` TMEM deallocation synchronization barrier. Handles cluster-aware synchronization patterns for TMEM deallocation, supporting both single-CTA and multi-CTA (cta\_group=2) configurations. ## Fields * ​barrier (`TmemDeallocBarrier[cta_group].BarrierStorage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BarrierStorage` `comptime BarrierStorage = SMemArray[SharedMemBarrier, 1]` ## Methods ### `__init__` `__init__(barrier: SMemArray[SharedMemBarrier, 1]) -> Self` Initialize with shared memory barrier array. ### `signal_peer` `signal_peer(self)` Signal peer CTA in cluster (cta\_group=2 only). ### `signal_self` `signal_self(self)` Signal own arrival at barrier. ### `wait` `wait(self)` Wait for barrier completion. ### `complete_dealloc` `complete_dealloc[max_cols: Int = 512](self, tmem: TmemAllocation[cta_group, max_cols])` Complete TMEM deallocation sequence (MMA warp side). Releases the allocation lock, waits for epilogue completion, then deallocates the TMEM. ### `signal_complete` `signal_complete(self)` Signal TMEM consumption complete (Epilogue warp side). For cta\_group=2, signals peer CTA first, then signals self.
--- ## WarpGroupBarrier
`@register_passable(trivial)` `struct WarpGroupBarrier[num_threads: Int, barrier_id: Int = 0]` Named barrier for warp group synchronization. Wraps `named_barrier` and `named_barrier_arrive` with compile-time thread count and barrier ID for type-safe synchronization. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `arrive` `static arrive()` Signal arrival without blocking (non-blocking arrive). ### `wait` `static wait()` Block until all threads have arrived. ### `sync` `static sync()` Full barrier: arrive and wait for all threads.
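The `arrive`/`wait` split above is the classic split-phase barrier: a non-blocking arrival signal plus a separate blocking wait, with `sync` as the combination. As a purely illustrative CPU-side analog (Python threading primitives, not a model of GPU named barriers):

```python
import threading

# Split-phase barrier analog: arrive() is non-blocking, wait() blocks until
# all participants have arrived, sync() does both. Illustrates the pattern
# behind WarpGroupBarrier.arrive()/wait()/sync(); single-use, not reusable.

class SplitPhaseBarrier:
    def __init__(self, num_threads):
        self.num_threads = num_threads
        self.arrived = 0
        self.cond = threading.Condition()

    def arrive(self):  # signal arrival without blocking
        with self.cond:
            self.arrived += 1
            if self.arrived >= self.num_threads:
                self.cond.notify_all()

    def wait(self):  # block until all threads have arrived
        with self.cond:
            self.cond.wait_for(lambda: self.arrived >= self.num_threads)

    def sync(self):  # full barrier: arrive and wait
        self.arrive()
        self.wait()

barrier = SplitPhaseBarrier(4)
results = []

def worker(i):
    results.append(i)  # do some work
    barrier.sync()     # all appends are visible to every thread past here

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The split form lets a producer signal arrival early and keep working, deferring the blocking `wait` until the synchronized data is actually needed.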
--- ## barriers
Barrier abstractions for SM100 structured matmul kernels. This module provides type-safe wrappers around low-level barrier primitives, improving code readability and reducing error potential. ## Structs * [​`SmemBarriers`](./SmemBarriers): Composable barrier storage for SM100 matmul SMEM structs. * [​`TmemDeallocBarrier`](./TmemDeallocBarrier): TMEM deallocation synchronization barrier. * [​`WarpGroupBarrier`](./WarpGroupBarrier): Named barrier for warp group synchronization.
--- ## BlockScaledMatmulConfig
`@register_passable(trivial)` `struct BlockScaledMatmulConfig[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool = True]` Static configuration of GPU matmul. ## Fields * ​cta\_group (`Int`): * ​mma\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​AB\_swapped (`Bool`): * ​block\_swizzle\_size (`Int`): * ​raster\_order (`RasterOrder`): * ​block\_tile\_shape (`IndexList[3]`): * ​num\_split\_k (`Int`): * ​num\_pipeline\_stages (`Int`): * ​num\_clc\_pipeline\_stages (`Int`): * ​num\_accum\_pipeline\_stages (`Int`): * ​num\_output\_stages (`Int`): * ​output\_tile\_shape (`IndexList[2]`): * ​a\_swizzle (`TensorMapSwizzle`): * ​b\_swizzle (`TensorMapSwizzle`): * ​c\_swizzle (`TensorMapSwizzle`): * ​k\_group\_size (`Int`): * ​scaling\_kind (`UMMAKind`): * ​vec\_sf\_size (`Int`): * ​num\_sf\_k\_tiles (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`Hashable`](/mojo/std/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `sf_block_atom_size` `comptime sf_block_atom_size = ((SF_ATOM_M.__getitem__[Int, Int, 0]() * SF_ATOM_M.__getitem__[Int, Int, 1]()) * 4)` ## Methods ### 
`__init__` `__init__(*, scaling_kind: UMMAKind, cta_group: Int = 2, mma_shape: IndexList[3] = get_mma_shape[a_type, BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index(2, 1, 1), AB_swapped: Bool = False, num_split_k: Int = 1, block_swizzle_size: Int = 0, raster_order: RasterOrder = RasterOrder.AlongM, k_group_size: Int = 1, num_pipeline_stages: Optional[Int] = None, num_accum_pipeline_stages: Int = 2, num_clc_pipeline_stages: Int = 2) -> Self` ### `swap_AB_type` `swap_AB_type(self) -> BlockScaledMatmulConfig[b_type, a_type, c_type, sfb_dtype, sfa_dtype, transpose_b]` **Returns:** [`BlockScaledMatmulConfig`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/config/BlockScaledMatmulConfig) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to[W: Writer](self, mut writer: W)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String)
--- ## MatmulConfig
`@register_passable(trivial)` `struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True]` Static configuration of GPU matmul. ## Fields * ​cta\_group (`Int`): * ​mma\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​AB\_swapped (`Bool`): * ​block\_swizzle\_size (`Int`): * ​raster\_order (`RasterOrder`): * ​block\_tile\_shape (`IndexList[3]`): * ​num\_split\_k (`Int`): * ​num\_pipeline\_stages (`Int`): * ​num\_clc\_pipeline\_stages (`Int`): * ​num\_accum\_pipeline\_stages (`Int`): * ​num\_output\_stages (`Int`): * ​output\_tile\_shape (`IndexList[2]`): * ​a\_swizzle (`TensorMapSwizzle`): * ​b\_swizzle (`TensorMapSwizzle`): * ​c\_swizzle (`TensorMapSwizzle`): * ​k\_group\_size (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`Hashable`](/mojo/std/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ## Methods ### `__init__` `__init__(*, cta_group: Int = 2, mma_shape: IndexList[3] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index(2, 1, 1), AB_swapped: Bool = False, num_split_k: Int = 1, block_swizzle_size: Int = 0, raster_order: RasterOrder = 
RasterOrder.AlongM, k_group_size: Int = 1, num_pipeline_stages: Optional[Int] = None, num_accum_pipeline_stages: Int = 2, num_clc_pipeline_stages: Int = 2, extra_smem_per_stage: Int = 0) -> Self` ### `swap_AB_type` `swap_AB_type(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to[W: Writer](self, mut writer: W)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String)
--- ## build_configs
`build_configs[a_type: DType, b_type: DType, c_type: DType, N: Int, K: Int, transpose_b: Bool = True]() -> Set[MatmulConfig[a_type, b_type, c_type, transpose_b]]` **Returns:** `Set`
--- ## choose_config
`choose_config[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True](M: Int, N: Int, K: Int) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## config
SM100 matmul configuration types and utilities. This module provides configuration structs for SM100 (Blackwell) GPU matmul operations, including standard matmul and block-scaled matmul variants. ## Structs * [​`BlockScaledMatmulConfig`](./BlockScaledMatmulConfig): Static configuration of GPU matmul. * [​`MatmulConfig`](./MatmulConfig): Static configuration of GPU matmul. ## Functions * [​`build_configs`](./build_configs): * [​`choose_config`](./choose_config):
--- ## AccumBarrier
`@register_passable(trivial)` `struct AccumBarrier[cta_group: Int]` Pipeline barrier helper for single-CTA vs 2-CTA arrival patterns. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `arrive` `static arrive(pipeline: ProducerConsumerPipeline[num_stages], stage: UInt32)` Signal accumulator arrival on pipeline barrier.
--- ## AccumTile
`@register_passable(trivial)` `struct AccumTile[dtype: DType, size: Int]` Upper + lower TMEM fragments (16 rows each) for SM100 output. ## Fields * ​upper (`SIMD[dtype, size]`): * ​lower (`SIMD[dtype, size]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(upper: SIMD[dtype, size], lower: SIMD[dtype, size]) -> Self`
--- ## EpilogueApplier
`@register_passable(trivial)` `struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]` Apply element-wise epilogue lambda to register fragments. ## Fields * ​coords (`EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords`): * ​warp\_id (`UInt32`): * ​lane\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Coords` `comptime Coords = FragmentCoords[stageN, repeats]` ## Methods ### `__init__` `__init__(warp_id: UInt32, lane_id: UInt32) -> Self` ### `compute_staged_coords` `compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]` Compute global coords with warp and stage offsets (layout-dependent). **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `apply_to_fragment` `apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type](self, mut frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)` Apply epilogue lambda to fragment elements with global coords. 
### `apply_to_both_fragments` `apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type, is_lower_frag_required: Bool](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]` Apply epilogue to both fragments (main entry point). **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `add_residual_to_fragment` `add_residual_to_fragment[epilogue_dtype: DType, frag_size: Int, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut frag: SIMD[epilogue_dtype, frag_size], local_row: UInt32, local_col: UInt32, is_upper: Bool, src_ptr: UnsafePointer[Scalar[c_type], origin, address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype])` Add beta \* C to fragment elements by loading C from swizzled SMEM. Uses the same per-lane coordinate mapping as apply\_to\_fragment, but instead of applying a lambda, loads source C values from SMEM at the matching swizzled addresses and adds beta \* C to each element. **Args:** * ​frag ([`SIMD`](/mojo/std/builtin/simd/SIMD)): Fragment register values to modify in-place. * ​local\_row ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Tile-local row offset (warp offset within tile). * ​local\_col ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Tile-local column offset (stage offset within tile). * ​is\_upper ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether this is the upper (rows 0-15) or lower (16-31) fragment half. * ​src\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to source C SMEM tile (same TMA swizzle as output). * ​beta ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Residual scale factor. 
### `add_residual_to_both_fragments` `add_residual_to_both_fragments[epilogue_dtype: DType, frag_size: Int, is_lower_frag_required: Bool, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, src_ptr: UnsafePointer[Scalar[c_type], origin, address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype]) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]` Add beta \* C to both fragment halves from swizzled SMEM. Computes tile-local coordinates from stage and warp ID, then loads source C from SMEM and adds beta \* C to each fragment element. **Args:** * ​upper\_frag ([`SIMD`](/mojo/std/builtin/simd/SIMD)): Upper fragment (rows 0-15 within warp tile). * ​lower\_frag ([`SIMD`](/mojo/std/builtin/simd/SIMD)): Lower fragment (rows 16-31 within warp tile). * ​stage ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Output stage index (for column offset computation). * ​src\_ptr ([`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)): Pointer to source C SMEM tile. * ​beta ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Residual scale factor. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple): Updated (upper\_frag, lower\_frag) tuple.
--- ## EpilogueConfig
`@register_passable(trivial)` `struct EpilogueConfig[MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, transpose_c: Bool]` Computed epilogue parameters based on MMA and CTA configuration. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `bits` `comptime bits = 256` ### `cg1_num_stages` `comptime cg1_num_stages = (MMA_N // stageN)` ### `cg2_num_stages` `comptime cg2_num_stages = (MMA_N // stageN) if (MMA_M == 256) else ((MMA_N // stageN) // 2)` ### `data_paths` `comptime data_paths = 16` ### `fragment_size` `comptime fragment_size = 4` ### `is_lower_frag_required` `comptime is_lower_frag_required = (MMA_M == 64) if (cta_group == 1) else (cta_group != 1)` ### `num_stages` `comptime num_stages = (MMA_N // stageN) if (MMA_M == 256) else ((MMA_N // stageN) // 2) if (cta_group == 2) else EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c].cg1_num_stages`
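The stage-count members above encode a small selection rule between the `cta_group`-1 and `cta_group`-2 cases. A hedged Python restatement (hypothetical helper, for readability only; the kernel computes this at compile time):

```python
# Restatement of EpilogueConfig's stage-count formulas (cg1_num_stages,
# cg2_num_stages, num_stages above). Illustrative only, not the MAX API.

def num_epilogue_stages(MMA_M, MMA_N, stageN, cta_group):
    cg1 = MMA_N // stageN                       # cg1_num_stages
    cg2 = cg1 if MMA_M == 256 else cg1 // 2     # cg2_num_stages
    return cg2 if cta_group == 2 else cg1       # num_stages

# e.g. MMA_N=256, stageN=64 gives 4 stages for cta_group=1, and only
# 2 stages for cta_group=2 unless MMA_M == 256.
```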
--- ## FragmentCoords
`@register_passable(trivial)` `struct FragmentCoords[stageN: Int, repeats: Int]` Fragment element coordinates for tcgen05 16x256b matrix layout. ## Fields * ​top\_upper (`StaticTuple[UInt32, 2]`): * ​bottom\_upper (`StaticTuple[UInt32, 2]`): * ​top\_lower (`StaticTuple[UInt32, 2]`): * ​bottom\_lower (`StaticTuple[UInt32, 2]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `load_width` `comptime load_width = 2` ### `threads_per_row` `comptime threads_per_row = ((stageN // repeats) // 2)` ## Methods ### `__init__` `__init__(lane_id: UInt32) -> Self` Compute (row, col) for each fragment position from lane ID.
--- ## SMemEpilogueWriter
`@register_passable(trivial)` `struct SMemEpilogueWriter[c_type: DType, num_output_stages: Int, //, c_smem_dim0: Int, c_smem_dim1: Int, epilogue_dtype: DType, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool, num_stages: Int, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type]` SMEM-based epilogue: write accumulators and apply lambda in SMEM. ## Fields * ​warp\_id (`UInt32`): * ​c\_tiles (`SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArray`): * ​M (`UInt32`): * ​N (`UInt32`): * ​c\_row (`UInt32`): * ​c\_col (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `barrier_threads` `comptime barrier_threads = (num_output_warps * WARP_SIZE)` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)` ### `CTileArray` `comptime CTileArray = SMemTileArray[c_type, SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, 
compute_lambda_fn].c_smem_layout, num_output_stages, 128]` ### `data_paths` `comptime data_paths = 16` ### `OutputSyncBarrier` `comptime OutputSyncBarrier = WarpGroupBarrier[SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].barrier_threads]` ### `stage_contiguous_size` `comptime stage_contiguous_size = c_smem_dim1` ### `stageN` `comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1` ### `swizzle` `comptime swizzle = make_swizzle[c_type, c_swizzle]()` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ### `Tile` `comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]` ## Methods ### `__init__` `__init__(warp_id: UInt32, c_tiles: SMemTileArray[c_type, SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].c_smem_layout, num_output_stages, 128], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self` Initialize the SMEM epilogue writer. ### `write_tile` `write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])` Write accumulator tile to SMEM and apply epilogue lambda.
--- ## TMAStoreCoords
`@register_passable(trivial)` `struct TMAStoreCoords[BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, c_smem_shape0: Int, stage: Int, batched: Bool = False]` TMA store coordinates and warp election for SM100 epilogue. When batched=True, includes a batch coordinate for 3D TMA stores. ## Fields * ​coord\_m (`UInt`): * ​coord\_n (`UInt`): * ​coord\_b (`UInt`): * ​elect\_one\_warp (`Bool`): * ​c\_smem\_coord\_m (`UInt`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CG1_TMA_BM` `comptime CG1_TMA_BM = c_smem_shape0` ### `CG2_TMA_BM` `comptime CG2_TMA_BM = c_smem_shape0 if (MMA_M == 256)._mlir_value else BM` ### `stage_n_offset` `comptime stage_n_offset = (stage * stageN)` ### `TMA_BM` `comptime TMA_BM = c_smem_shape0 if (MMA_M == 256)._mlir_value else BM if (cta_group == 2)._mlir_value else TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, c_smem_shape0, stage, batched].CG1_TMA_BM` ## Methods ### `__init__` `__init__(c_coord: Tuple[UInt32, UInt32], warp_id: UInt32) -> Self` Compute TMA store coordinates from 2D tile coords and warp ID. `__init__(c_coord: Tuple[UInt32, UInt32, UInt32], warp_id: UInt32) -> Self` Compute TMA store coordinates from 3D tile coords and warp ID.
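The `TMA_BM` and `stage_n_offset` members are the only arithmetic spelled out in this struct's doc. A hedged Python restatement follows; the function names are hypothetical, and the derivation of the actual store coordinates from `c_coord` and `warp_id` is not modeled.

```python
# Sketch of TMAStoreCoords' comptime selection, modeled in Python.
def tma_bm(c_smem_shape0, BM, MMA_M, cta_group):
    # Mirrors: c_smem_shape0 if MMA_M == 256 else (BM if cta_group == 2 else CG1_TMA_BM),
    # where CG1_TMA_BM = c_smem_shape0.
    return c_smem_shape0 if MMA_M == 256 else (BM if cta_group == 2 else c_smem_shape0)

def stage_n_offset(stage, stageN):
    # stage_n_offset = stage * stageN: column offset of this output stage
    return stage * stageN
```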
--- ## TMAStoreExecutor
`@register_passable(trivial)` `struct TMAStoreExecutor[c_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, stage_contiguous_size: Int, cta_group: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool, batched: Bool = False]` Execute TMA store from SMEM to GMEM with proper tiling. Handles 3 paths: transpose+cta\_group2+MMA128, transpose+other, non-transpose. When batched=True, uses 3D coordinates (M, N, Batch) for TMA stores. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)` ### `c_smem_shape0` `comptime c_smem_shape0 = c_smem_dim0` ### `CG1_TMA_BM` `comptime CG1_TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_shape0` ### `CG2_TMA_BM` `comptime CG2_TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_shape0 if (MMA_M == 256)._mlir_value else BM` ### `num_c_smem_tiles` `comptime num_c_smem_tiles = ((128 // TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, 
cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].swizzle_width) // 1 if is_lower_frag_required else 2)` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ### `TMA_BM` `comptime TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_shape0 if (MMA_M == 256)._mlir_value else BM if (cta_group == 2)._mlir_value else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].CG1_TMA_BM` ## Methods ### `execute` `static execute[c_layout: Layout, c_desc_layout: Layout](c_smem_tile: LayoutTensor[c_type, TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], store_coords: TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_shape0, stage, batched], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], warp_id: UInt32, lane: UInt32)` Execute TMA store with elected warp and lane 0.
--- ## TMEMToSMemWriter
`@register_passable(trivial)` `struct TMEMToSMemWriter[c_type: DType, accum_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]` Write TMEM accumulators to SMEM via st.matrix (SM100-specific). ## Fields * ​warp\_id (`UInt32`): * ​lane\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)` ### `Config` `comptime Config = EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c]` ### `data_paths` `comptime data_paths = 16` ### `stage_contiguous_size` `comptime stage_contiguous_size = c_smem_dim1` ### `swizzle` `comptime swizzle = make_swizzle[c_type, c_swizzle]()` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ## Methods ### `__init__` `__init__(warp_id: UInt32, lane_id: UInt32) -> Self` ### `write_fragments` `write_fragments[repeat: Int](self, upper_frag: SIMD[c_type, (4 * repeat)], lower_frag: SIMD[c_type, (4 * repeat)], c_smem_tile: LayoutTensor[c_type, TMEMToSMemWriter[c_type, accum_type, c_smem_dim0, c_smem_dim1, BM, BN, MMA_M, MMA_N, stageN, cta_group, num_output_warps, c_swizzle, transpose_c].c_smem_layout, MutAnyOrigin, 
address_space=AddressSpace.SHARED, alignment=128])` Write pre-loaded fragments to SMEM (use after register-based epilogue).
--- ## epilogue_components
Low-level epilogue components for SM100 matrix multiplication. This module provides modular building blocks for the output pipeline: 1. **store\_fragment\_to\_smem**: Register to shared memory via st.matrix instructions 2. **TMEMToSMemWriter**: Write TMEM accumulators to shared memory 3. **TMAStoreExecutor**: Execute TMA stores with proper SMEM tiling 4. **EpilogueApplier**: Apply element-wise operations on fragments The SM100 epilogue pipeline flows as: TMEM (accumulators) → Registers → SMEM → GMEM (via TMA) ## Structs * [​`AccumBarrier`](./AccumBarrier): Pipeline barrier helper for single-CTA vs 2-CTA arrival patterns. * [​`AccumTile`](./AccumTile): Upper + lower TMEM fragments (16 rows each) for SM100 output. * [​`EpilogueApplier`](./EpilogueApplier): Apply element-wise epilogue lambda to register fragments. * [​`EpilogueConfig`](./EpilogueConfig): Computed epilogue parameters based on MMA and CTA configuration. * [​`FragmentCoords`](./FragmentCoords): Fragment element coordinates for tcgen05 16x256b matrix layout. * [​`SMemEpilogueWriter`](./SMemEpilogueWriter): SMEM-based epilogue: write accumulators and apply lambda in SMEM. * [​`TMAStoreCoords`](./TMAStoreCoords): TMA store coordinates and warp election for SM100 epilogue. * [​`TMAStoreExecutor`](./TMAStoreExecutor): Execute TMA store from SMEM to GMEM with proper tiling. * [​`TMEMToSMemWriter`](./TMEMToSMemWriter): Write TMEM accumulators to SMEM via st.matrix (SM100-specific). ## Functions * [​`shared_memory_epilogue`](./shared_memory_epilogue): Apply element-wise epilogue to non-transposed SMEM tile. * [​`shared_memory_epilogue_transpose`](./shared_memory_epilogue_transpose): Apply element-wise epilogue to transposed SMEM tile. * [​`store_fragment_to_smem`](./store_fragment_to_smem): Store fragment to SMEM via st.matrix instruction. * [​`tma_wait_pipelined`](./tma_wait_pipelined): Wait for TMA stores with pipelining.
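The four building blocks compose into the pipeline named above. The Python sketch below shows only the staging order (TMEM → registers → SMEM → GMEM, one `stageN`-wide stage at a time); the callbacks are placeholders standing in for the Mojo components, not real APIs.

```python
# Illustrative sketch of the SM100 epilogue flow:
# TMEM accumulators -> registers -> SMEM -> GMEM, stageN columns per stage.
def run_epilogue(accum_cols, stageN, load_tmem, apply_lambda, store_smem, tma_store):
    num_stages = accum_cols // stageN          # MMA_N split into stageN-wide stages
    for stage in range(num_stages):
        frag = load_tmem(stage)                # TMEM -> registers
        frag = apply_lambda(frag)              # element-wise epilogue (EpilogueApplier)
        store_smem(stage, frag)                # registers -> SMEM (st.matrix)
        tma_store(stage)                       # SMEM -> GMEM (TMA)
    return num_stages
```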
--- ## shared_memory_epilogue (epilogue_components)
`shared_memory_epilogue[MMA_M: Scalar[DType.uint], data_paths: Scalar[DType.uint], num_stages: Scalar[DType.uint], stage: Scalar[DType.uint], stageN: Scalar[DType.uint], c_type: DType, shared_n: Scalar[DType.uint], simd_size: Scalar[DType.uint], c_smem_upper_layout: Layout, c_smem_lower_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: Int](M: UInt32, N: UInt32, c_col: Scalar[DType.uint], c_row: Scalar[DType.uint], c_smem_warp_tile_upper: LayoutTensor[c_type, c_smem_upper_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_smem_warp_tile_lower: LayoutTensor[c_type, c_smem_lower_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Apply element-wise epilogue to non-transposed SMEM tile. Each warp processes upper (rows 0-15) and lower (rows 16-31) fragments. Uses distribute layout to map SIMD vectors to threads within each warp.
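As an illustration of the upper/lower split, the small Python helper below assigns each warp a 32-row slice, rows 0-15 of which form the upper fragment and rows 16-31 the lower (`data_paths = 16`). The global row numbering is an assumption made for the sketch.

```python
# Hypothetical model of the per-warp row split in shared_memory_epilogue.
def warp_fragment_rows(warp_id, data_paths=16):
    base = warp_id * 2 * data_paths                       # 32 rows per warp (assumed)
    upper = range(base, base + data_paths)                # rows 0-15 of the warp tile
    lower = range(base + data_paths, base + 2 * data_paths)  # rows 16-31
    return upper, lower

u, l = warp_fragment_rows(1)
# warp 1 covers global rows 32-47 (upper) and 48-63 (lower) under this numbering
```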
--- ## shared_memory_epilogue_transpose (epilogue_components)
`shared_memory_epilogue_transpose[stage: Scalar[DType.uint], stageN: Scalar[DType.uint], c_type: DType, c_smem_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: Int, warp_dim: Int, MMA_M: Int, BN: Int, cta_group: Int](M: UInt32, N: UInt32, c_col: Scalar[DType.uint], c_row: Scalar[DType.uint], c_smem: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_i: Scalar[DType.uint], warp_j: Scalar[DType.uint])` Apply element-wise epilogue to transposed SMEM tile. Supports warp\_dim=1 (stageN, warp\_i, U) or warp\_dim=2 (warp\_j, stageN, warp\_i, UL).
--- ## store_fragment_to_smem
`store_fragment_to_smem[swizzle: Swizzle, stageN: Int, transpose_c: Bool = False, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B](vec: SIMD[dtype, size], dst: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_offset: UInt32 = 0)` Store fragment to SMEM via st.matrix instruction.
--- ## tma_wait_pipelined
`tma_wait_pipelined[c_type: DType, c_layout: Layout, c_desc_layout: Layout, is_last_stage: Bool](c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout])` Wait for TMA stores with pipelining. For SM100 output pipeline: * Non-last stages: Keep 1 store in flight for pipelining * Last stage: Wait for all stores to complete
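The two-mode wait policy can be restated in Python. The integer counter below is purely illustrative; the real function waits on TMA store groups, not a number held in a variable.

```python
# Sketch of tma_wait_pipelined's policy: non-last stages drain outstanding
# TMA stores down to one in flight; the last stage drains them to zero.
def stores_to_wait_for(in_flight, is_last_stage):
    allowed = 0 if is_last_stage else 1   # keep 1 store in flight for pipelining
    return max(0, in_flight - allowed)

# With 3 stores outstanding: a non-last stage waits for 2, the last stage for all 3.
```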
--- ## structured_kernels
Shared library components for SM100 structured kernels. ## Modules * [​`barriers`](./barriers/): Barrier abstractions for SM100 structured matmul kernels. * [​`config`](./config/): SM100 matmul configuration types and utilities. * [​`epilogue_components`](./epilogue_components/): Low-level epilogue components for SM100 matrix multiplication. * [​`kernel_common`](./kernel_common/): Shared kernel components for SM100 warp-specialized matmul kernels. * [​`output_writer`](./output_writer/): TileWriter for SM100 matmul output pipeline. * [​`pipeline`](./pipeline/): Producer-consumer pipeline utilities for SM100 structured kernels. * [​`pipeline_storage`](./pipeline_storage/): Unified Pipeline Storage Framework for SM100 Structured Kernels. * [​`tile_loader`](./tile_loader/): TMA tile loader for SM100 matrix multiplication. * [​`tile_pipeline`](./tile_pipeline/): Tile pipeline for SM100 producer-consumer synchronization. * [​`tile_scheduler`](./tile_scheduler/): * [​`tile_scheduler_splitk`](./tile_scheduler_splitk/): * [​`tile_types`](./tile_types/): Native TileTensor types for SM100 structured kernels. * [​`tmem`](./tmem/): Tensor Memory (TMEM) abstractions for SM100 Blackwell GPUs. * [​`warp_context`](./warp_context/): RAII warp context managers for SM100 matmul kernel.
--- ## KernelContext
`struct KernelContext[num_clc_pipeline_stages: Int, cta_group: Int, CLUSTER_M: Int, CLUSTER_N: Int]` Shared kernel state: election vars, CTA coords, multicast masks, pipeline states. ## Fields * ​elect\_one\_warp (`Bool`): * ​elect\_one\_thread (`Bool`): * ​elect\_one\_cta (`Bool`): * ​is\_first\_cta\_in\_cluster (`Bool`): * ​warp\_id (`UInt32`): * ​rank\_m (`UInt`): * ​rank\_n (`UInt`): * ​peer\_cta\_coord (`Tuple[UInt, UInt, UInt]`): * ​a\_multicast\_mask (`UInt16`): * ​b\_multicast\_mask (`UInt16`): * ​mma\_complete\_mask (`Int`): * ​ptr\_tmem\_addr (`LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `TmemAddrArray` `comptime TmemAddrArray = SMemArray[UInt32, 1]` ## Methods ### `__init__` `__init__(out self, ptr_tmem_addr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED])` Initialize context from TMEM pointer; computes all derived state. `__init__(out self, tmem_addr: SMemArray[UInt32, 1])` Initialize context from typed TMEM address array.
--- ## WarpRole (kernel_common)
`@register_passable(trivial)` `struct WarpRole` Warp role identifiers for SM100 warp-specialized kernel. Warp assignment (8 warps total = 256 threads): * Epilogue: warp IDs 0-3 (4 warps, 128 threads) * Scheduler: warp ID 4 (1 warp, 32 threads) * MainLoad: warp ID 5 (1 warp, 32 threads) * Mma: warp ID 6 (1 warp, 32 threads) * EpilogueLoad: warp ID 7 (1 warp, 32 threads) - loads source C for residual Note: When epilogue load is not needed (no residual), warp 7 exits early. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Epilogue` `comptime Epilogue = WarpRole(3)` ### `EpilogueLoad` `comptime EpilogueLoad = WarpRole(7)` ### `MainLoad` `comptime MainLoad = WarpRole(5)` ### `Mma` `comptime Mma = WarpRole(6)` ### `Scheduler` `comptime Scheduler = WarpRole(4)` ## Methods ### `__eq__` `__eq__(self, other: Scalar[DType.uint]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ge__` `__ge__(self, other: Scalar[DType.uint]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_main_load` `static is_main_load() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` **Returns:** 
[`Bool`](/mojo/std/builtin/bool/Bool) ### `is_epilogue` `static is_epilogue() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_scheduler` `static is_scheduler() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_epilogue_load` `static is_epilogue_load() -> Bool` Check if current warp is the epilogue load warp (loads source C). **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
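The warp assignment table above maps directly to a lookup. A Python restatement (the helper itself is illustrative, not part of the Mojo API):

```python
# Model of the SM100 warp-role table: 8 warps = 256 threads.
def warp_role(warp_id):
    if warp_id <= 3:
        return "Epilogue"        # warps 0-3: 4 warps, 128 threads
    return {4: "Scheduler",      # warp 4
            5: "MainLoad",       # warp 5
            6: "Mma",            # warp 6
            7: "EpilogueLoad"}[warp_id]  # warp 7: loads source C for residual

roles = [warp_role(w) for w in range(8)]
```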
--- ## consumer_main_loop (kernel_common)
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, a_smem_layout: Layout, b_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, cluster_shape: IndexList[3] = IndexList[3, DType.int64](1, 1, 1, Tuple[]()), k_group_size: Int = 1](tmem_addr: Int, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)` DEPRECATED: Legacy MMA consumer loop for external callers. Use TilePipeline with StandardConsumerStage and BlackwellMatmulSM100Kernel.mma() for new code. This function is kept for backward compatibility.
--- ## kernel_common
Shared kernel components for SM100 warp-specialized matmul kernels. This module contains common components used by all SM100 matmul kernel variants: * WarpRole: Warp specialization roles (MMA, Load, Scheduler, Epilogue) * KernelContext: Common kernel state (election vars, CTA coords, masks) * consumer\_main\_loop: Legacy MMA consumer loop (deprecated but kept for compatibility) ## Structs * [​`KernelContext`](./KernelContext): Shared kernel state: election vars, CTA coords, multicast masks, pipeline states. * [​`WarpRole`](./WarpRole): Warp role identifiers for SM100 warp-specialized kernel. ## Functions * [​`consumer_main_loop`](./consumer_main_loop): DEPRECATED: Legacy MMA consumer loop for external callers.
--- ## TileWriter
`@register_passable(trivial)` `struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_layout: Layout, c_desc_layout: Layout, //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, num_accum_pipeline_stages: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, stage_stride_cols: Int, num_output_warps: Int, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, batched: Bool = False]` Output tile writer for SM100 matmul epilogue. Stores pointer to TMA descriptor. SMEM tiles passed per-call. Parameters are passed explicitly to work with both MatmulConfig and BlockScaledMatmulConfig. The stage\_stride\_cols parameter must match the value used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method. ## Fields * ​c\_tma\_op (`TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].TmaOpPtr`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_tile_layout` `comptime accum_tile_layout = 
Layout.row_major(TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].BM, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN)` ### `AccumTmemArray` `comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].accum_tile_layout, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].num_stages, cta_group=cta_group]` ### `bits` `comptime bits = 256` ### `BM` `comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)` ### `cg1_num_stages` `comptime cg1_num_stages = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, 
num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN)` ### `cg2_num_stages` `comptime cg2_num_stages = (mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_dim0 if transpose_c else c_smem_dim1) if (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_M == 256)._mlir_value else ((TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN) // 2)` ### `CTileArray` `comptime CTileArray = SMemTileArray[c_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].c_smem_layout, num_output_stages, 128]` ### `CTileArrayTT` `comptime CTileArrayTT = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]` ### `data_paths` `comptime data_paths = 16` ### `epilogue_dtype` `comptime epilogue_dtype = c_type if (a_type == DType.bfloat16)._mlir_value else DType.float32` ### `fragment_size` `comptime fragment_size = (128 // WARP_SIZE)` ### `is_lower_frag_required` `comptime is_lower_frag_required = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, 
num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].BM == 64) if (cta_group == 1)._mlir_value else (cta_group == 1).__bool__().__invert__()` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `N_dim` `comptime N_dim = 0 if transpose_c else 1` ### `num_stages` `comptime num_stages = (mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_dim0 if transpose_c else c_smem_dim1) if (mma_shape.__getitem__[3, DType.int64, Int](0) == 256)._mlir_value else ((TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN) // 2) if (cta_group == 2)._mlir_value else TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].cg1_num_stages` ### `rep` `comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN // 8)` ### `rep_frag_size` `comptime rep_frag_size = (TileWriter[a_type, accum_type, block_tile_shape, 
mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].rep * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].fragment_size)` ### `Stage` `comptime Stage = OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group]` ### `stage_contiguous_size` `comptime stage_contiguous_size = c_smem_dim1` ### `stageN` `comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1` ### `TmaOp` `comptime TmaOp = TMATensorTile[c_type, c_layout, c_desc_layout]` ### `TmaOpPtr` `comptime TmaOpPtr = Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].TmaOp, tma_origin]` ## Methods ### `__init__` `__init__(c_tma_op: Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].TmaOp, tma_origin]) -> Self` Initialize with pointer to TMA descriptor. ### `write` `write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)` Write accumulated results to global memory (2D coords). 
### `write_batched` `write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)` Write accumulated results to global memory (3D batched coords). **Args:** * ​c\_tiles ([`SMemTileArray2DRowMajor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2DRowMajor)): TileTensor-based SMEM tile array for C output. * ​stage ([`OutputStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/OutputStage)): OutputStage with pipeline, index, and TMEM handle. * ​tile\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (m\_tile, n\_tile, batch) coordinates. * ​shape ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (M, N) problem dimensions. * ​alpha ([`Float32`](/mojo/std/builtin/simd/#float32)): Tensor scale factor (scalar). ### `write_splitk` `write_splitk[reduction_layout: Layout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], reduction_tensor: LayoutTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)` Write with split-K reduction. Only last split writes to GMEM. 
### `write_absolute_with_bounds_check` `write_absolute_with_bounds_check[c_tensor_layout: Layout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin])` Write with absolute coordinates and bounds checking. For 1D-1D grouped kernels where M coordinate is absolute. ### `write_with_residual` `write_with_residual(self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], src_tile: SMemTileArray[c_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].c_smem_layout, num_output_stages, 128], src_stage_idx: UInt32, beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)` Write with residual: D = lambda(accum) + beta \* C. This method extends the standard write() to add a residual term loaded from source tensor C in shared memory. The epilogue load warp pre-fetches C tiles into src\_tile before this method is called. Pipeline: 1. Load accum from TMEM to registers 2. Apply epilogue lambda (if present) 3. Load C fragment from source SMEM 4. Compute D = accum + beta \* C 5. Write D to output SMEM and TMA store to GMEM **Args:** * ​out\_tiles ([`SMemTileArray2DRowMajor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2DRowMajor)): Output SMEM tile array (for D output). 
* ​stage ([`OutputStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/OutputStage)): OutputStage with pipeline, index, and TMEM handle. * ​src\_tile ([`SMemTileArray`](/mojo/kernels/linalg/structuring/SMemTileArray)): Source C SMEM tile array (LayoutTensor-based, from epilogue load warp). Constructed from smem.src\_tiles().ptr. * ​src\_stage\_idx ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Stage index into src\_tile (0 or 1 for double-buffer). * ​beta ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Residual scale factor. * ​tile\_coord ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (m\_tile, n\_tile) coordinates. * ​shape ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): (M, N) problem dimensions. * ​elect\_one\_warp ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether this warp is elected for coordination.
--- ## output_writer
TileWriter for SM100 matmul output pipeline. Writes accumulated results from TMEM → Registers → SMEM → GMEM (via TMA). Usage:

```
var writer = TileWriter[config=..., ...](Pointer(to=c_tma_op))
writer.write(smem.c_tiles(), stage, coord, shape, elect)
```

## Structs * [​`TileWriter`](./TileWriter): Output tile writer for SM100 matmul epilogue.
--- ## ConsumeContext
`struct ConsumeContext[pipeline_origin: MutOrigin, num_stages: Int]` Context manager for consuming one pipeline stage. * `__enter__`: Waits for producer to be ready, returns ref to stage * `__exit__`: Releases the stage (signals consumption + advances) ## Fields * ​pipeline (`Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = _all_trivial_del[_NoneType, ConsumerStage[pipeline_origin, num_stages]]()` ## Methods ### `__init__` `__init__(out self, pipeline: Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin])` ### `__enter__` `__enter__(mut self) -> ref[origin_of(*[0,0]._stage._value)] ConsumerStage[pipeline_origin, num_stages]` Wait for producer and return reference to stage. **Returns:** `ref` ### `__exit__` `__exit__(mut self)` Release the stage (signals consumption + advances).
--- ## ConsumerStage
`struct ConsumerStage[pipeline_origin: MutOrigin, num_stages: Int]` Unified handle for consuming from a pipeline stage. Works as both a linear type (direct use) and within context managers. Lifecycle: 1. Created via `pipeline.acquire_consumer()` or context manager 2. Use `index()` for consumption 3. Must call `release()` to signal and advance (compiler-enforced) Two exit paths: * `release()`: Signal consumption complete + advance (normal path) * `release_without_signal()`: Advance only (for explicit signaling) ## Parameters * ​pipeline\_origin ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of the pipeline reference. * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages. ## Fields * ​pipeline (`Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, pipeline: Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin], index: UInt32, mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` ### `__moveinit__` `__moveinit__(out self, deinit other: Self)` Move constructor for Optional support. ### `index` `index(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `mbar` `mbar(self) -> MbarPtr` Get the barrier for manual signaling. Use this for specialized signaling patterns like umma\_arrive\_leader\_cta. For standard usage, just call release(). **Returns:** `MbarPtr` ### `arrive` `arrive(self)` Manually arrive on the consumer barrier. Use for lane-guarded patterns:

```
if lane_id() < CLUSTER_SIZE:
    stage.arrive()
stage^.release_without_signal()
```

### `release` `release(deinit self)` Signal consumption complete and advance to next stage. This is the standard exit path. 
Equivalent to:

```
arrive()
consumer_step()
```

### `release_without_signal` `release_without_signal(deinit self)` Advance to next stage WITHOUT signaling. Use when you've already signaled via arrive() or specialized APIs.
--- ## ExplicitConsumeContext
`struct ExplicitConsumeContext[pipeline_origin: MutOrigin, num_stages: Int]` Context manager for consuming with EXPLICIT barrier arrive. Use this when you need lane-guarded or specialized barrier signaling. * `__enter__`: Waits for producer to be ready, returns ref to stage with mbar * `__exit__`: Only advances stage counter, does NOT arrive on barrier The caller is responsible for calling arrive via stage.arrive() or stage.mbar():

```
with pipeline.consume_explicit() as stage:
    # ... do work ...
    if lane_id() < CLUSTER_SIZE:
        stage.arrive()
    # __exit__ only calls consumer_step(), not arrive()
```

## Fields * ​pipeline (`Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = _all_trivial_del[_NoneType, ConsumerStage[pipeline_origin, num_stages]]()` ## Methods ### `__init__` `__init__(out self, pipeline: Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin])` ### `__enter__` `__enter__(mut self) -> ref[origin_of(*[0,0]._stage._value)] ConsumerStage[pipeline_origin, num_stages]` Wait for producer and return reference to stage with barrier access. **Returns:** `ref` ### `__exit__` `__exit__(mut self)` Advance to next stage WITHOUT signaling barrier.
--- ## ProduceContext
`struct ProduceContext[pipeline_origin: MutOrigin, num_stages: Int]` Context manager for producing one pipeline stage. * `__enter__`: Waits for consumer to be ready, returns ref to stage * `__exit__`: Releases the stage (advances producer) Note: The actual production signal (mma\_arrive) is kernel-specific and must be called by the user before exiting the context. ## Fields * ​pipeline (`Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = _all_trivial_del[_NoneType, ProducerStage[pipeline_origin, num_stages]]()` ## Methods ### `__init__` `__init__(out self, pipeline: Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin])` ### `__enter__` `__enter__(mut self) -> ref[origin_of(*[0,0]._stage._value)] ProducerStage[pipeline_origin, num_stages]` Wait for consumer and return reference to stage. **Returns:** `ref` ### `__exit__` `__exit__(mut self)` Release the stage (advances producer).
--- ## ProducerConsumerPipeline
`@register_passable(trivial)` `struct ProducerConsumerPipeline[num_stages: Int]` A producer-consumer pipeline using shared memory barriers to enforce synchronization (between producer and consumer warps). This struct is commonly used with warp specialization to pipeline operations between two warps/warpgroups with data dependencies. ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): The number of pipeline stages. ## Fields * ​full (`MbarPtr`): * ​empty (`MbarPtr`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` Initialize the producer-consumer pipeline with default phases. **Args:** * ​ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barriers. ### `wait_producer` `wait_producer(self)` Consumer waits for producer. ### `wait_consumer` `wait_consumer(self)` Producer waits for consumer. ### `try_wait_producer` `try_wait_producer(self) -> Bool` Non-blocking check if producer data is ready. Note: Use this with wait\_producer\_if\_needed() for the try-acquire pattern: ``` var ready = pipeline.try_wait_producer() # ... do other work ... 
pipeline.wait_producer_if_needed(ready) ``` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the producer has filled the current stage, False otherwise. ### `try_wait_consumer` `try_wait_consumer(self) -> Bool` Non-blocking check if consumer has freed the stage. Note: Use this with wait\_consumer\_if\_needed() for the try-acquire pattern. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the consumer has freed the current stage, False otherwise. ### `wait_producer_if_needed` `wait_producer_if_needed(self, already_ready: Bool)` Conditionally wait for producer if not already ready. **Args:** * ​already\_ready ([`Bool`](/mojo/std/builtin/bool/Bool)): Result from try\_wait\_producer(). If True, skips waiting. ### `wait_consumer_if_needed` `wait_consumer_if_needed(self, already_ready: Bool)` Conditionally wait for consumer if not already ready. **Args:** * ​already\_ready ([`Bool`](/mojo/std/builtin/bool/Bool)): Result from try\_wait\_consumer(). If True, skips waiting. ### `producer_mbar` `producer_mbar(self, stage: UInt32) -> MbarPtr` Get the producer barrier for a specific stage. **Args:** * ​stage ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The pipeline stage. **Returns:** `MbarPtr`: The shared memory barrier that the producer signals. ### `consumer_mbar` `consumer_mbar(self, stage: UInt32) -> MbarPtr` Get the consumer barrier for a specific stage. **Args:** * ​stage ([`UInt32`](/mojo/std/builtin/simd/#uint32)): The pipeline stage. **Returns:** `MbarPtr`: The shared memory barrier that the consumer signals. ### `producer_stage` `producer_stage(self) -> UInt32` Get the current producer stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32): The current stage index for the producer (0 to num\_stages-1). ### `consumer_stage` `consumer_stage(self) -> UInt32` Get the current consumer stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32): The current stage index for the consumer (0 to num\_stages-1). 
### `consumer_step` `consumer_step(mut self)` Advance the consumer to the next pipeline stage. Increments the consumer stage and wraps to 0 when reaching num\_stages, toggling the phase bit on wrap-around. The phase bit switches only at the end of the pipeline because all barriers are assumed to be at the same consumer/producer phase before they are checked; once a barrier is checked, execution moves on to the next barrier. ### `producer_step` `producer_step(mut self)` Advance the producer to the next pipeline stage. Increments the producer stage and wraps to 0 when reaching num\_stages, toggling the phase bit on wrap-around. ### `smem_bytes` `static smem_bytes() -> UInt32` Calculate the shared memory bytes required for pipeline barriers. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32): The total number of bytes needed for all pipeline barriers (2 \* num\_stages barriers). ### `init_mbars` `init_mbars(self, producer_arrive_count: Int32, consumer_arrive_count: Int32)` Initialize the smem barriers for the producer and consumer. This function must be called by a single thread, before the pipeline object is used. **Args:** * ​producer\_arrive\_count ([`Int32`](/mojo/std/builtin/simd/#int32)): The number of threads that will arrive at the barrier marking data as produced. * ​consumer\_arrive\_count ([`Int32`](/mojo/std/builtin/simd/#int32)): The number of threads that will arrive at the barrier marking data as consumed. ### `producer_signal_and_step` `producer_signal_and_step(mut self)` Wait for consumer, signal production, and advance stage. Combined operation for CLC throttling (Load warp): 1. Wait for consumer to finish with current stage 2. Signal that producer has new data 3. Advance to next stage ### `consumer_signal_and_step` `consumer_signal_and_step(mut self)` Wait for producer, signal consumption, and advance stage. Combined operation for CLC throttling (Scheduler warp): 1. Wait for producer to have data ready 2. Signal that consumer has consumed data 3. 
Advance to next stage ### `produce` `produce[origin: MutOrigin, //](ref[origin] self) -> ProduceContext[origin, num_stages]` Produce one pipeline stage with encapsulated barriers. Usage:

```
with pipeline.produce() as stage:
    # stage.index() gives current stage
    # stage.mbar() gives barrier for signaling
    # __exit__ calls producer_step()
```

**Returns:** `ProduceContext`: Context that waits for consumer on enter, advances on exit. ### `consume` `consume[origin: MutOrigin, //](ref[origin] self) -> ConsumeContext[origin, num_stages]` Consume one pipeline stage with encapsulated barriers. Usage:

```
with pipeline.consume() as stage:
    # stage.index() gives current stage
    # __exit__ signals consumer done and advances
```

**Returns:** `ConsumeContext`: Context that waits for producer on enter, signals+advances on exit. ### `consume_explicit` `consume_explicit[origin: MutOrigin, //](ref[origin] self) -> ExplicitConsumeContext[origin, num_stages]` Consume one pipeline stage with EXPLICIT barrier arrive. Use this for kernels requiring lane-guarded or specialized signaling. Usage:

```
with pipeline.consume_explicit() as stage:
    # ... do work ...
    if lane_id() < CLUSTER_SIZE:
        stage.arrive()  # Lane-guarded arrive
    # __exit__ only advances, does NOT arrive
```

For specialized signaling (e.g., umma\_arrive\_leader\_cta):

```
with pipeline.consume_explicit() as stage:
    if cta_group == 1:
        stage.arrive()
    else:
        umma_arrive_leader_cta(stage.mbar())
```

**Returns:** `ExplicitConsumeContext`: Context that waits for producer on enter, advances only on exit. ### `acquire_producer` `acquire_producer[origin: MutOrigin, //](ref[origin] self) -> ProducerStage[origin, num_stages]` Acquire a producer stage handle using linear types. Waits for the consumer to free the current stage, then returns a linear type handle that MUST be released (compiler-enforced). Usage:

```
var stage = pipeline.acquire_producer()
# ... produce data, signal via stage.mbar() ...
stage^.release()  # Advances to next stage
```

**Returns:** `ProducerStage`: A ProducerStage handle that must be released. ### `acquire_consumer` `acquire_consumer[origin: MutOrigin, //](ref[origin] self) -> ConsumerStage[origin, num_stages]` Acquire a consumer stage handle using linear types. Waits for the producer to fill the current stage, then returns a linear type handle that MUST be released (compiler-enforced). Usage:

```
var stage = pipeline.acquire_consumer()
# ... consume data ...
stage^.release()  # Signals complete and advances
```

For explicit signaling:

```
var stage = pipeline.acquire_consumer()
# ... consume data ...
if lane_id() < CLUSTER_SIZE:
    stage.arrive()
stage^.release_without_signal()
```

**Returns:** `ConsumerStage`: A ConsumerStage handle that must be released.
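As a host-side illustration (not the Mojo/GPU implementation), the stage-counter and phase-bit bookkeeping that `producer_step()` and `consumer_step()` describe can be sketched in Python. The `PipelineIndexModel` name is hypothetical, and the barriers themselves are omitted; only the wrap-to-0 and phase-toggle-on-wrap logic is modeled:

```python
class PipelineIndexModel:
    """Illustrative model of the N-stage ring-buffer index/phase logic."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.producer_stage = 0
        self.consumer_stage = 0
        self.producer_phase = 0  # toggles on each producer wrap-around
        self.consumer_phase = 0  # toggles on each consumer wrap-around

    def producer_step(self):
        # Advance producer; wrap to 0 and flip phase at end of pipeline.
        self.producer_stage += 1
        if self.producer_stage == self.num_stages:
            self.producer_stage = 0
            self.producer_phase ^= 1

    def consumer_step(self):
        # Advance consumer; wrap to 0 and flip phase at end of pipeline.
        self.consumer_stage += 1
        if self.consumer_stage == self.num_stages:
            self.consumer_stage = 0
            self.consumer_phase ^= 1
```

The phase bit is what lets a barrier wait distinguish "this stage was filled in the current pass through the ring" from "this is a stale signal from the previous pass."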
--- ## ProducerStage
`struct ProducerStage[pipeline_origin: MutOrigin, num_stages: Int]` Unified handle for producing to a pipeline stage. Works as both a linear type (direct use) and within context managers. Lifecycle: 1. Created via `pipeline.acquire_producer()` or context manager 2. Use `index()` and `mbar()` for production 3. Must call `release()` to advance stage (compiler-enforced) ## Parameters * ​pipeline\_origin ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of the pipeline reference. * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages. ## Fields * ​pipeline (`Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, pipeline: Pointer[ProducerConsumerPipeline[num_stages], pipeline_origin], index: UInt32, mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` ### `__moveinit__` `__moveinit__(out self, deinit other: Self)` Move constructor for Optional support. ### `index` `index(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `mbar` `mbar(self) -> MbarPtr` Get the barrier to signal when production is complete. Caller is responsible for signaling via mma\_arrive or similar. **Returns:** `MbarPtr` ### `release` `release(deinit self)` Advance producer to next stage. This is the only way to destroy this linear type. The compiler will error if you don't call this.
--- ## pipeline (3)
Producer-consumer pipeline utilities for SM100 structured kernels. This module provides pipeline synchronization primitives for warp-specialized GPU kernels, enabling efficient producer-consumer patterns between warps. Key abstractions: * ProducerConsumerPipeline: Low-level barrier management for N-stage pipelines * ProducerStage / ConsumerStage: Unified stage handles (linear types) ## Unified Stage Types ProducerStage and ConsumerStage are linear types (`@explicit_destroy`) that work in both contexts: 1. **Linear Type API** (flat, explicit):

```
var stage = pipeline.acquire_producer()
# ... use stage.index(), stage.mbar() ...
stage^.release()  # Compiler enforces this call
```

2. **Context Manager API** (scoped, automatic):

```
with pipeline.produce() as stage:
    # ... use stage.index(), stage.mbar() ...
# release() called automatically
```

The context managers store the stage internally and return a `ref` to it, allowing access to the full stage API while managing lifetime automatically. ## API Examples Producer side (e.g., MMA warp producing to epilogue):

```
# Context manager:
with pipeline.produce() as stage:
    mma_op.mma(a, b, tmem_offset)
    mma_op.commit(stage.mbar())
# __exit__ calls stage^.release() -> producer_step()

# Linear type:
var stage = pipeline.acquire_producer()
mma_op.mma(a, b, tmem_offset)
mma_op.commit(stage.mbar())
stage^.release()
```

Consumer side (e.g., epilogue consuming from MMA):

```
# Context manager:
with pipeline.consume() as stage:
    process(stage.index())
# __exit__ calls stage^.release() -> arrive + consumer_step()

# Linear type:
var stage = pipeline.acquire_consumer()
process(stage.index())
stage^.release()  # Signal + advance

# Explicit signaling:
var stage = pipeline.acquire_consumer()
if lane_id() < CLUSTER_SIZE:
    stage.arrive()
stage^.release_without_signal()  # Advance only
```

Direct API (for special cases):

```
pipeline.wait_producer() / wait_consumer()
pipeline.producer_step() / consumer_step()
pipeline.producer_mbar(stage) / consumer_mbar(stage)
```
## `comptime` values ### `MbarPtr` `comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`ConsumeContext`](./ConsumeContext): Context manager for consuming one pipeline stage. * [​`ConsumerStage`](./ConsumerStage): Unified handle for consuming from a pipeline stage. * [​`ExplicitConsumeContext`](./ExplicitConsumeContext): Context manager for consuming with EXPLICIT barrier arrive. * [​`ProduceContext`](./ProduceContext): Context manager for producing one pipeline stage. * [​`ProducerConsumerPipeline`](./ProducerConsumerPipeline): A producer-consumer pipeline using shared memory barriers to enforce synchronization (between producer and consumer warps). * [​`ProducerStage`](./ProducerStage): Unified handle for producing to a pipeline stage.
--- ## BarrierPair
`struct BarrierPair[num_stages: Int]` Storage for a producer-consumer barrier pair (full + empty). Each stage has two barriers: * full\[i]: Producer signals when stage i is filled * empty\[i]: Consumer signals when stage i is consumed ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages (ring buffer depth). ## Fields * ​storage (`BarrierPair[num_stages].Array.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `Array` `comptime Array = SMemArray[SharedMemBarrier, (num_stages * 2)]` ## Methods ### `barriers` `barriers(ref[AddressSpace._value._mlir_value] self) -> BarrierPair[num_stages].Array` Get barrier array accessor. **Returns:** `BarrierPair` ### `ptr` `ptr(ref[AddressSpace._value._mlir_value] self) -> MbarPtr` Get raw barrier pointer for initialization or custom usage. **Returns:** `MbarPtr` ### `create_pipeline` `create_pipeline(ref[AddressSpace._value._mlir_value] self) -> ProducerConsumerPipeline[num_stages]` Create a runtime pipeline from this barrier storage. **Returns:** [`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline)
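A minimal sketch of the storage arithmetic above: each stage owns a full/empty barrier pair, so the array holds `num_stages * 2` barriers (matching the `SMemArray[SharedMemBarrier, (num_stages * 2)]` alias). The 8-byte barrier size below is an assumption (typical of one 64-bit NVIDIA mbarrier), not something this module states, and the helper names are hypothetical:

```python
BARRIER_BYTES = 8  # assumed size of one SharedMemBarrier (64-bit mbarrier)


def barrier_count(num_stages: int) -> int:
    """One full[i] plus one empty[i] barrier per pipeline stage."""
    return num_stages * 2


def smem_bytes(num_stages: int) -> int:
    """Total shared-memory footprint of the barrier array."""
    return barrier_count(num_stages) * BARRIER_BYTES
```

For a 4-stage pipeline this gives 8 barriers; the `full` and `empty` pointers of `ProducerConsumerPipeline` each index one half of this array.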
--- ## BlockScaledTileStorage
`struct BlockScaledTileStorage[a_type: DType, b_type: DType, c_type: DType, sfa_type: DType, sfb_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, c_dim0: Int, c_dim1: Int, sfa_dim0: Int, sfa_dim1: Int, sfb_dim0: Int, sfb_dim1: Int, num_pipeline_stages: Int, num_output_stages: Int]` Storage for block-scaled matmul tiles (A, B, C, SFA, SFB). Single source of truth for block-scaled tile arrays and storage. All tiles use TileTensor natively. C tiles also store as TileTensor but provide a LayoutTensor accessor for epilogue\_components.mojo compatibility with .reshape\[] and .tile\[] methods. IMPORTANT: Field order preserves SMEM layout compatibility: a, b, c, sfa, sfb. ## Parameters * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A matrix tiles. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for B matrix tiles. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for C matrix tiles. * ​sfa\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A scale factor tiles. * ​sfb\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for B scale factor tiles. * ​a\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for A tiles. * ​a\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for A tiles. * ​b\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for B tiles. * ​b\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for B tiles. * ​c\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for C tiles (OutputM). * ​c\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for C tiles (OutputN). * ​sfa\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for SFA tiles. * ​sfa\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for SFA tiles. * ​sfb\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for SFB tiles. * ​sfb\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for SFB tiles. 
* ​num\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of input pipeline stages. * ​num\_output\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of output pipeline stages. ## Fields * ​a\_tiles\_storage (`BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].ATileArray.Storage`): * ​b\_tiles\_storage (`BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].BTileArray.Storage`): * ​c\_tiles\_storage (`BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].CTileArray.Storage`): * ​sfa\_tiles\_storage (`BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].SFATileArray.Storage`): * ​sfb\_tiles\_storage (`BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].SFBTileArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ### `c_tile_layout` `comptime c_tile_layout = Layout.row_major(c_dim0, c_dim1)` ### `CTileArray` `comptime CTileArray = SMemTileArray2DRowMajor[c_type, 
c_dim0, c_dim1, num_output_stages]` ### `CTileArrayLT` `comptime CTileArrayLT = SMemTileArray[c_type, BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].c_tile_layout, num_output_stages, 128]` ### `sfa_layout` `comptime sfa_layout = Layout[Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]], Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]]](Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]](Idx[32](), Idx[(sfa_dim0 // 32)]())), Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]](Coord[ComptimeInt[4], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[4], ComptimeInt[4]](Idx[4](), Idx[4]())), Idx[(sfa_dim1 // 16)]())))), Coord[Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]](Idx[16](), 
Idx[(sfa_dim1 * 32)]())), Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](Coord[ComptimeInt[1], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[4]](Idx[1](), Idx[4]())), Idx[512]())))))` ### `SFATileArray` `comptime SFATileArray = SMemTileArrayWithLayout[sfa_type, BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].sfa_layout, num_pipeline_stages]` ### `sfb_layout` `comptime sfb_layout = Layout[Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]], Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]]](Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]](Idx[32](), Idx[(sfb_dim0 // 32)]())), Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]](Coord[ComptimeInt[4], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[4], ComptimeInt[4]](Idx[4](), Idx[4]())), Idx[(sfb_dim1 // 16)]())))), Coord[Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], 
ComptimeInt[512]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]](Idx[16](), Idx[(sfb_dim1 * 32)]())), Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](Coord[ComptimeInt[1], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[4]](Idx[1](), Idx[4]())), Idx[512]())))))` ### `SFBTileArray` `comptime SFBTileArray = SMemTileArrayWithLayout[sfb_type, BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].sfb_layout, num_pipeline_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].ATileArray` Get A tile array accessor. **Returns:** `BlockScaledTileStorage` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].BTileArray` Get B tile array accessor. 
**Returns:** `BlockScaledTileStorage` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].CTileArrayLT` Get C tile array accessor (LayoutTensor-based for backward compat). Returns LayoutTensor view for compatibility with tile\_writer.mojo which uses .reshape\[] and .tile\[] methods. **Returns:** `BlockScaledTileStorage` ### `c_tiles_tt` `c_tiles_tt(ref[AddressSpace._value._mlir_value] self) -> BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].CTileArray` Get C tile array accessor (TileTensor-based). Returns native TileTensor for future TileTensor-native code paths. **Returns:** `BlockScaledTileStorage` ### `sfa_tiles` `sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].SFATileArray` Get SFA tile array accessor. **Returns:** `BlockScaledTileStorage` ### `sfb_tiles` `sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockScaledTileStorage[a_type, b_type, c_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages, num_output_stages].SFBTileArray` Get SFB tile array accessor. **Returns:** `BlockScaledTileStorage`
--- ## BlockwiseFP8TileStorage
`struct BlockwiseFP8TileStorage[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, c_dim0: Int, c_dim1: Int, a_scales_dim0: Int, a_scales_dim1: Int, num_pipeline_stages: Int, num_output_stages: Int]` Storage for blockwise FP8 matmul tiles (A, B, C, A-scales). Single source of truth for blockwise FP8 tile arrays and storage. B-scales are read directly from global memory during epilogue. All tiles use TileTensor natively. C tiles also store as TileTensor but provide a LayoutTensor accessor for epilogue\_components.mojo compatibility with .reshape\[] and .tile\[] methods. IMPORTANT: Field order preserves SMEM layout compatibility: a, b, c, a\_scales. ## Parameters * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A matrix tiles. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for B matrix tiles. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for C matrix tiles. * ​a\_scales\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A scale tiles. * ​a\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for A tiles. * ​a\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for A tiles. * ​b\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for B tiles. * ​b\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for B tiles. * ​c\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for C tiles (OutputM). * ​c\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for C tiles (OutputN). * ​a\_scales\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for A scale tiles. * ​a\_scales\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for A scale tiles. * ​num\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of input pipeline stages. * ​num\_output\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of output pipeline stages. 
## Fields * ​a\_tiles\_storage (`BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].ATileArray.Storage`): * ​b\_tiles\_storage (`BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].BTileArray.Storage`): * ​c\_tiles\_storage (`BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].CTileArray.Storage`): * ​a\_scales\_tiles\_storage (`BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].AScalesTileArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `AScalesTileArray` `comptime AScalesTileArray = SMemTileArray2DRowMajor[a_scales_type, a_scales_dim0, a_scales_dim1, num_pipeline_stages]` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ### `c_tile_layout` `comptime c_tile_layout = Layout.row_major(c_dim0, c_dim1)` ### `CTileArray` `comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_dim0, c_dim1, num_output_stages]` ### `CTileArrayLT` `comptime CTileArrayLT = SMemTileArray[c_type, BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].c_tile_layout, num_output_stages, 128]` ## Methods ### 
`a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].ATileArray` Get A tile array accessor. **Returns:** `BlockwiseFP8TileStorage` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].BTileArray` Get B tile array accessor. **Returns:** `BlockwiseFP8TileStorage` ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].CTileArrayLT` Get C tile array accessor (LayoutTensor-based for backward compat). Returns LayoutTensor view for compatibility with tile\_writer.mojo which uses .reshape\[] and .tile\[] methods. **Returns:** `BlockwiseFP8TileStorage` ### `c_tiles_tt` `c_tiles_tt(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].CTileArray` Get C tile array accessor (TileTensor-based). Returns native TileTensor for future TileTensor-native code paths. **Returns:** `BlockwiseFP8TileStorage` ### `a_scales_tiles` `a_scales_tiles(ref[AddressSpace._value._mlir_value] self) -> BlockwiseFP8TileStorage[a_type, b_type, c_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, c_dim0, c_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages, num_output_stages].AScalesTileArray` Get A-scales tile array accessor. **Returns:** `BlockwiseFP8TileStorage`
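The accessor pattern above can be used by embedding the storage in an SMEM struct, as done elsewhere in this module. A minimal sketch; the struct name, dtypes, and tile dimensions below are placeholders, not values from the source:

```mojo
# Illustrative only: struct name, dtypes, and tile dimensions are assumptions.
struct MyBlockwiseFP8Smem:
    var tiles: BlockwiseFP8TileStorage[
        DType.float8_e4m3fn,  # a_type
        DType.float8_e4m3fn,  # b_type
        DType.bfloat16,       # c_type
        DType.float32,        # a_scales_type
        128, 128,             # a_dim0, a_dim1
        128, 128,             # b_dim0, b_dim1
        128, 128,             # c_dim0, c_dim1
        1, 128,               # a_scales_dim0, a_scales_dim1
        4,                    # num_pipeline_stages
        2,                    # num_output_stages
    ]

    fn accessors(ref[SHARED] self):
        a = self.tiles.a_tiles()              # TileTensor-based A tiles
        b = self.tiles.b_tiles()              # TileTensor-based B tiles
        c_lt = self.tiles.c_tiles()           # LayoutTensor view (epilogue compat)
        c_tt = self.tiles.c_tiles_tt()        # native TileTensor C tiles
        scales = self.tiles.a_scales_tiles()  # A-scales tiles
```

Note the field order (a, b, c, a_scales) is fixed by the struct itself to preserve SMEM layout compatibility; the embedding struct only chooses parameters.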
--- ## ClcPipelineStorage
`struct ClcPipelineStorage[num_stages: Int]` Storage for CLC (Cluster Launch Control) scheduler pipeline. CLC has a different barrier pattern: * full/empty: Standard producer-consumer for work items * throttle: Rate limiting barriers (2 per stage) * response: CLC response storage (UInt128 per stage) ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of CLC pipeline stages. ## Fields * ​full\_storage (`ClcPipelineStorage[num_stages].BarrierArray.Storage`): * ​empty\_storage (`ClcPipelineStorage[num_stages].BarrierArray.Storage`): * ​throttle\_storage (`ClcPipelineStorage[num_stages].ThrottleArray.Storage`): * ​response\_storage (`ClcPipelineStorage[num_stages].ResponseArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, num_stages]` ### `ResponseArray` `comptime ResponseArray = SMemArray[UInt128, num_stages]` ### `ThrottleArray` `comptime ThrottleArray = SMemArray[SharedMemBarrier, (num_stages * 2)]` ## Methods ### `full` `full(ref[AddressSpace._value._mlir_value] self) -> ClcPipelineStorage[num_stages].BarrierArray` **Returns:** `ClcPipelineStorage` ### `empty` `empty(ref[AddressSpace._value._mlir_value] self) -> ClcPipelineStorage[num_stages].BarrierArray` **Returns:** `ClcPipelineStorage` ### `throttle` `throttle(ref[AddressSpace._value._mlir_value] self) -> ClcPipelineStorage[num_stages].ThrottleArray` **Returns:** `ClcPipelineStorage` ### `response` `response(ref[AddressSpace._value._mlir_value] self) -> ClcPipelineStorage[num_stages].ResponseArray` **Returns:** `ClcPipelineStorage`
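The four barrier/response groups are reached through the accessors documented above. A minimal sketch, assuming a hypothetical embedding struct and an illustrative stage count:

```mojo
struct MySchedulerSmem:
    var clc: ClcPipelineStorage[2]  # 2 CLC stages, chosen for illustration

    fn clc_arrays(ref[SHARED] self):
        full = self.clc.full()          # producer-consumer "full" barriers
        empty = self.clc.empty()        # producer-consumer "empty" barriers
        throttle = self.clc.throttle()  # 2 barriers per stage for rate limiting
        response = self.clc.response()  # one UInt128 CLC response per stage
```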
--- ## EpiLoadPipelineStorage
`struct EpiLoadPipelineStorage[num_stages: Int]` Storage for epilogue load pipeline (source C loading). For EpilogueLoad warp → Epilogue warps synchronization. The epilogue load warp loads source tensor C into SMEM, and the epilogue warps consume it for residual operations. Producer: EpilogueLoad warp (1 warp, 32 threads) Consumer: Epilogue warps (4 warps, 128 threads) ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of epilogue load pipeline stages (typically 2). ## Fields * ​barriers (`BarrierPair[num_stages]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, (num_stages * 2)]` ## Methods ### `create_pipeline` `create_pipeline(ref[AddressSpace._value._mlir_value] self) -> ProducerConsumerPipeline[num_stages]` Create runtime pipeline from this storage. **Returns:** [`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline) ### `barrier_ptr` `barrier_ptr(ref[AddressSpace._value._mlir_value] self) -> MbarPtr` Escape hatch: Get raw barrier pointer. **Returns:** `MbarPtr`
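Because the storage owns its barriers, the runtime pipeline is derived directly from it. A sketch of the documented producer/consumer split; the struct name is an assumption:

```mojo
struct MyEpilogueSmem:
    var epi_load: EpiLoadPipelineStorage[2]  # typically 2 stages

    fn run(ref[SHARED] self):
        # Both the EpilogueLoad warp (producer, 32 threads) and the Epilogue
        # warps (consumers, 128 threads) build the same pipeline view over
        # the shared barriers.
        pipeline = self.epi_load.create_pipeline()
        # Custom initialization paths can fall back to the raw pointer:
        raw = self.epi_load.barrier_ptr()
```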
--- ## InputPipelineStorage
`struct InputPipelineStorage[num_stages: Int, Payload: TilePayload]` Unified storage for input tile pipeline (barriers + payload). Bundles barrier storage with tile payload storage, ensuring they're always consistent. The pipeline can only be created from matching storage. Example: ``` struct MySmem[...]: var input: InputPipelineStorage[ 4, # 4 stages StandardTilePayload[float16, float16, a_layout, b_layout], ] fn get_pipeline(ref[SHARED] self): return self.input.create_pipeline() ``` ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages. * ​Payload ([`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload)): Tile payload type (defines what's in each stage). ## Fields * ​barriers (`BarrierPair[num_stages]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, (num_stages * 2)]` ## Methods ### `create_pipeline` `create_pipeline(ref[AddressSpace._value._mlir_value] self) -> ProducerConsumerPipeline[num_stages]` Create runtime pipeline from this storage. **Returns:** [`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline) ### `barrier_ptr` `barrier_ptr(ref[AddressSpace._value._mlir_value] self) -> MbarPtr` Escape hatch: Get raw barrier pointer for custom initialization. **Returns:** `MbarPtr`
--- ## LoadOrderBarrierStorage
`struct LoadOrderBarrierStorage` Storage for load order barrier (mainloop → epilogue load coordination). This single barrier coordinates the mainloop load warp with the epilogue load warp, ensuring the epilogue load doesn't start before the mainloop has issued its prologue TMA operations. Protocol: 1. Mainloop load warp issues prologue loads 2. Mainloop load warp calls arrive() on this barrier 3. Epilogue load warp waits on this barrier before starting This prevents TMA resource contention between mainloop and epilogue loads. ## Fields * ​barrier\_storage (`LoadOrderBarrierStorage.BarrierArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, 1]` ## Methods ### `barrier` `barrier(ref[AddressSpace._value._mlir_value] self) -> LoadOrderBarrierStorage.BarrierArray` Get the load order barrier. **Returns:** `LoadOrderBarrierStorage.BarrierArray` ### `ptr` `ptr(ref[AddressSpace._value._mlir_value] self) -> MbarPtr` Get raw barrier pointer for initialization. **Returns:** `MbarPtr`
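The three-step protocol above reduces to a single arrive/wait pair. A hedged sketch; the warp-role flag and the exact barrier arrive/wait calls are assumptions, shown as comments:

```mojo
struct MyKernelSmem:
    var load_order: LoadOrderBarrierStorage

    fn coordinate(ref[SHARED] self, is_mainloop_load_warp: Bool):
        if is_mainloop_load_warp:
            # Steps 1-2: after issuing prologue TMA loads, signal the barrier,
            # e.g. by calling arrive() on the single barrier in
            # self.load_order.barrier().
            pass
        else:
            # Step 3: the epilogue load warp blocks on the same barrier until
            # the mainloop's prologue TMA operations have been issued.
            pass
        # self.load_order.ptr() exposes the raw pointer for initialization.
```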
--- ## OutputPipelineStorage
`struct OutputPipelineStorage[num_stages: Int]` Unified storage for output/accumulator pipeline. For MMA → Epilogue synchronization. TMEM stages are allocated dynamically, so this only stores barriers. ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. ## Fields * ​barriers (`BarrierPair[num_stages]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, (num_stages * 2)]` ## Methods ### `create_pipeline` `create_pipeline(ref[AddressSpace._value._mlir_value] self) -> ProducerConsumerPipeline[num_stages]` Create runtime pipeline from this storage. **Returns:** [`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline) ### `barrier_ptr` `barrier_ptr(ref[AddressSpace._value._mlir_value] self) -> MbarPtr` Escape hatch: Get raw barrier pointer. **Returns:** `MbarPtr`
--- ## OutputTileStorage
`struct OutputTileStorage[c_type: DType, c_dim0: Int, c_dim1: Int, num_output_stages: Int]` Storage for output tiles (C matrix). Single source of truth for output tile array and storage. Separate from input tiles since output has different stage count. All tiles use TileTensor natively. C tiles also store as TileTensor but provide a LayoutTensor accessor for epilogue\_components.mojo compatibility with .reshape\[] and .tile\[] methods. ## Parameters * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for C matrix tiles. * ​c\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for C tiles (OutputM). * ​c\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for C tiles (OutputN). * ​num\_output\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of output pipeline stages. ## Fields * ​c\_tiles\_storage (`OutputTileStorage[c_type, c_dim0, c_dim1, num_output_stages].CTileArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `c_tile_layout` `comptime c_tile_layout = Layout.row_major(c_dim0, c_dim1)` ### `CTileArray` `comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_dim0, c_dim1, num_output_stages]` ### `CTileArrayLT` `comptime CTileArrayLT = SMemTileArray[c_type, OutputTileStorage[c_type, c_dim0, c_dim1, num_output_stages].c_tile_layout, num_output_stages, 128]` ## Methods ### `c_tiles` `c_tiles(ref[AddressSpace._value._mlir_value] self) -> OutputTileStorage[c_type, c_dim0, c_dim1, num_output_stages].CTileArrayLT` Get C tile array accessor (LayoutTensor-based for backward compat). Returns LayoutTensor view for compatibility with tile\_writer.mojo which uses .reshape\[] and .tile\[] methods. 
**Returns:** `CTileArrayLT` ### `c_tiles_tt` `c_tiles_tt(ref[AddressSpace._value._mlir_value] self) -> OutputTileStorage[c_type, c_dim0, c_dim1, num_output_stages].CTileArray` Get C tile array accessor (TileTensor-based). Returns native TileTensor for future TileTensor-native code paths. **Returns:** `CTileArray`

--- ## RawBarrierStorage
`struct RawBarrierStorage[count: Int]` Escape hatch: Raw barrier storage for custom patterns. Use this when the standard pipeline storage doesn't fit your needs. You're responsible for initialization and synchronization semantics. Example: ``` # Custom barrier layout for specialized synchronization struct MyCustomSmem: var custom_barriers: RawBarrierStorage[8] fn init_custom(ref[SHARED] self): ptr = self.custom_barriers.ptr() # Custom initialization... ``` ## Parameters * ​count ([`Int`](/mojo/std/builtin/int/Int)): Total number of barriers to allocate. ## Fields * ​storage (`RawBarrierStorage[count].Array.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `Array` `comptime Array = SMemArray[SharedMemBarrier, count]` ## Methods ### `barriers` `barriers(ref[AddressSpace._value._mlir_value] self) -> RawBarrierStorage[count].Array` **Returns:** `RawBarrierStorage` ### `ptr` `ptr(ref[AddressSpace._value._mlir_value] self) -> MbarPtr` Get raw pointer for custom usage. **Returns:** `MbarPtr`
--- ## SmemLayouts
`struct SmemLayouts[a_type: DType, b_type: DType, BM: Int, BN: Int, BK: Int, OutputM: Int, OutputN: Int, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool]` Common SMEM layout definitions for matmul-family kernels. Centralizes the A/B/C tile layout computation including the transpose-conditional B layout logic, eliminating \~10 lines of duplicated layout definitions from each SMEM struct. ## Parameters * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A matrix tiles. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for B matrix tiles. * ​BM ([`Int`](/mojo/std/builtin/int/Int)): Block tile M dimension. * ​BN ([`Int`](/mojo/std/builtin/int/Int)): Block tile N dimension. * ​BK ([`Int`](/mojo/std/builtin/int/Int)): Block tile K dimension. * ​OutputM ([`Int`](/mojo/std/builtin/int/Int)): Output tile M dimension. * ​OutputN ([`Int`](/mojo/std/builtin/int/Int)): Output tile N dimension. * ​a\_swizzle ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzle mode for A tiles. * ​b\_swizzle ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzle mode for B tiles. * ​transpose\_b ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether B is transposed (K-major). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BM, BK, a_swizzle]()` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, BN, BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BN, BK, b_swizzle]()` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(OutputM, OutputN)`
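Instantiating the bundle yields all three SMEM layouts from one declaration. A sketch; tile sizes, dtypes, and swizzle modes below are illustrative, not values from the source:

```mojo
comptime Layouts = SmemLayouts[
    DType.bfloat16, DType.bfloat16,  # a_type, b_type
    128, 256, 64,                    # BM, BN, BK
    128, 256,                        # OutputM, OutputN
    TensorMapSwizzle.SWIZZLE_128B,   # a_swizzle
    TensorMapSwizzle.SWIZZLE_128B,   # b_swizzle
    True,                            # transpose_b: B is K-major
]
# Layouts.a_smem_layout -> tile_layout_k_major for a 128x64 A tile
# Layouts.b_smem_layout -> K-major here because transpose_b is True
# Layouts.c_smem_layout -> Layout.row_major(128, 256)
```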
--- ## SmemPipelineBundle
`struct SmemPipelineBundle[num_group_pipeline_stages: Int, num_accum_pipeline_stages: Int, num_clc_pipeline_stages: Int, Payload: TilePayload]` Composed pipeline storage with unified barrier accessors. Bundles InputPipelineStorage, OutputPipelineStorage, ClcPipelineStorage, and TmemDeallocStorage into a single composed struct, eliminating \~60 lines of duplicated pipeline declarations, barrier type aliases, and barrier accessor methods from each SMEM struct. ## Parameters * ​num\_group\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of grouped pipeline stages for input. * ​num\_accum\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. * ​num\_clc\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of CLC scheduler pipeline stages. * ​Payload ([`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload)): Tile payload type (e.g. StandardTilePayload, BlockScaledTilePayload). 
## Fields * ​input\_pipeline (`SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].InputPipeline`): * ​output\_pipeline (`SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].OutputPipeline`): * ​clc\_pipeline (`SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcPipeline`): * ​tmem\_dealloc\_pipeline (`SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].TmemDeallocPipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `AccumBarriers` `comptime AccumBarriers = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].OutputPipeline.BarrierArray` ### `ClcBarriers` `comptime ClcBarriers = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcPipeline.BarrierArray` ### `ClcPipeline` `comptime ClcPipeline = ClcPipelineStorage[num_clc_pipeline_stages]` ### `ClcResponse` `comptime ClcResponse = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcPipeline.ResponseArray` ### `ClcThrottleBarriers` `comptime ClcThrottleBarriers = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcPipeline.ThrottleArray` ### `InputBarriers` `comptime InputBarriers = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].InputPipeline.BarrierArray` ### `InputPipeline` `comptime InputPipeline = InputPipelineStorage[num_group_pipeline_stages, Payload]` ### `OutputPipeline` `comptime OutputPipeline = 
OutputPipelineStorage[num_accum_pipeline_stages]` ### `TmemAddr` `comptime TmemAddr = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].TmemDeallocPipeline.AddrArray` ### `TmemDealloc` `comptime TmemDealloc = SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].TmemDeallocPipeline.BarrierArray` ### `TmemDeallocPipeline` `comptime TmemDeallocPipeline = TmemDeallocStorage` ## Methods ### `input_barriers` `input_barriers(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].InputBarriers` Returns input tile pipeline barriers. **Returns:** `SmemPipelineBundle` ### `accum_barriers` `accum_barriers(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].AccumBarriers` Returns accumulator pipeline barriers. **Returns:** `SmemPipelineBundle` ### `clc_full` `clc_full(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcBarriers` Returns CLC full barriers. **Returns:** `SmemPipelineBundle` ### `clc_empty` `clc_empty(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcBarriers` Returns CLC empty barriers. **Returns:** `SmemPipelineBundle` ### `clc_throttle` `clc_throttle(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcThrottleBarriers` Returns CLC throttle barriers. 
**Returns:** `SmemPipelineBundle` ### `clc_response` `clc_response(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].ClcResponse` Returns CLC response storage. **Returns:** `SmemPipelineBundle` ### `tmem_dealloc` `tmem_dealloc(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].TmemDealloc` Returns TMEM deallocation barrier. **Returns:** `SmemPipelineBundle` ### `tmem_addr` `tmem_addr(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundle[num_group_pipeline_stages, num_accum_pipeline_stages, num_clc_pipeline_stages, Payload].TmemAddr` Returns TMEM address storage. **Returns:** `SmemPipelineBundle`
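Composing the bundle collapses the per-pipeline declarations into one field. A hypothetical SMEM struct, with stage counts assumed and the payload parameters borrowed from the `InputPipelineStorage` example (`a_layout`/`b_layout` assumed defined in scope):

```mojo
struct MyMatmulSmem:
    var pipelines: SmemPipelineBundle[
        4,  # num_group_pipeline_stages
        2,  # num_accum_pipeline_stages
        2,  # num_clc_pipeline_stages
        StandardTilePayload[float16, float16, a_layout, b_layout],
    ]

    fn sync_points(ref[SHARED] self):
        input = self.pipelines.input_barriers()   # input tile barriers
        accum = self.pipelines.accum_barriers()   # MMA -> epilogue barriers
        clc_full = self.pipelines.clc_full()      # CLC scheduler barriers
        tmem = self.pipelines.tmem_dealloc()      # TMEM lifecycle barrier
```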
--- ## SmemPipelineBundleNoClc
`struct SmemPipelineBundleNoClc[num_group_pipeline_stages: Int, num_accum_pipeline_stages: Int, Payload: TilePayload]` Composed pipeline storage without CLC scheduler. Used by kernels with 3-warp specialization (Load, MMA, Epilogue) that don't use a scheduler warp (e.g. Grouped1D1DSmem). ## Parameters * ​num\_group\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of grouped pipeline stages for input. * ​num\_accum\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. * ​Payload ([`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload)): Tile payload type. ## Fields * ​input\_pipeline (`SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].InputPipeline`): * ​output\_pipeline (`SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].OutputPipeline`): * ​tmem\_dealloc\_pipeline (`SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].TmemDeallocPipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `AccumBarriers` `comptime AccumBarriers = SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].OutputPipeline.BarrierArray` ### `InputBarriers` `comptime InputBarriers = SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].InputPipeline.BarrierArray` ### `InputPipeline` `comptime InputPipeline = InputPipelineStorage[num_group_pipeline_stages, Payload]` ### `OutputPipeline` `comptime OutputPipeline = OutputPipelineStorage[num_accum_pipeline_stages]` ### `TmemAddr` `comptime TmemAddr = SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].TmemDeallocPipeline.AddrArray` ### `TmemDealloc` 
`comptime TmemDealloc = SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].TmemDeallocPipeline.BarrierArray` ### `TmemDeallocPipeline` `comptime TmemDeallocPipeline = TmemDeallocStorage` ## Methods ### `input_barriers` `input_barriers(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].InputBarriers` Returns input tile pipeline barriers. **Returns:** `SmemPipelineBundleNoClc` ### `accum_barriers` `accum_barriers(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].AccumBarriers` Returns accumulator pipeline barriers. **Returns:** `SmemPipelineBundleNoClc` ### `tmem_dealloc` `tmem_dealloc(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].TmemDealloc` Returns TMEM deallocation barrier. **Returns:** `SmemPipelineBundleNoClc` ### `tmem_addr` `tmem_addr(ref[AddressSpace._value._mlir_value] self) -> SmemPipelineBundleNoClc[num_group_pipeline_stages, num_accum_pipeline_stages, Payload].TmemAddr` Returns TMEM address storage. **Returns:** `SmemPipelineBundleNoClc`
--- ## SourceTileStorage
`struct SourceTileStorage[src_type: DType, src_dim0: Int, src_dim1: Int, num_epi_load_stages: Int]` Storage for source tensor C tiles (residual/skip connection input). Used by the epilogue load warp to pre-fetch source tensor C via TMA, enabling overlap with MMA computation for residual operations like D = Conv(A,B) + beta*C or D = MatMul(A,B) + beta*C. ## Parameters * ​src\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for source tiles (same as output type). * ​src\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for source tiles (OutputM). * ​src\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for source tiles (OutputN). * ​num\_epi\_load\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of epilogue load pipeline stages. ## Fields * ​src\_tiles\_storage (`SourceTileStorage[src_type, src_dim0, src_dim1, num_epi_load_stages].SrcTileArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `src_tile_layout` `comptime src_tile_layout = Layout.row_major(src_dim0, src_dim1)` ### `SrcTileArray` `comptime SrcTileArray = SMemTileArray2DRowMajor[src_type, src_dim0, src_dim1, num_epi_load_stages]` ### `SrcTileArrayLT` `comptime SrcTileArrayLT = SMemTileArray[src_type, SourceTileStorage[src_type, src_dim0, src_dim1, num_epi_load_stages].src_tile_layout, num_epi_load_stages, 128]` ## Methods ### `src_tiles` `src_tiles(ref[AddressSpace._value._mlir_value] self) -> SourceTileStorage[src_type, src_dim0, src_dim1, num_epi_load_stages].SrcTileArrayLT` Get source tile array accessor (LayoutTensor-based). Returns LayoutTensor view for compatibility with tile\_writer.mojo. 
**Returns:** `SrcTileArrayLT` ### `src_tiles_tt` `src_tiles_tt(ref[AddressSpace._value._mlir_value] self) -> SourceTileStorage[src_type, src_dim0, src_dim1, num_epi_load_stages].SrcTileArray` Get source tile array accessor (TileTensor-based). Returns native TileTensor for future TileTensor-native code paths. **Returns:** `SrcTileArray`
--- ## StandardTileStorage
`struct StandardTileStorage[a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int]` Storage for standard matmul tiles (A and B). This is the single source of truth for tile array types and storage. SMEM structs embed this rather than defining tile arrays separately. All tiles use TileTensor natively. Convert to LayoutTensor at TMA/MMA boundaries using {ptr} syntax or explicit construction. ## Parameters * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A matrix tiles. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for B matrix tiles. * ​a\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for A tiles. * ​a\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for A tiles. * ​b\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for B tiles. * ​b\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for B tiles. * ​num\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages (determines array depth). ## Fields * ​a\_tiles\_storage (`StandardTileStorage[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].ATileArray.Storage`): * ​b\_tiles\_storage (`StandardTileStorage[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].BTileArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> StandardTileStorage[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].ATileArray` Get A tile array accessor (TileTensor-based). 
**Returns:** `ATileArray` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> StandardTileStorage[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].BTileArray` Get B tile array accessor (TileTensor-based). **Returns:** `BTileArray`
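Per the TileTensor-native design, an SMEM struct embeds this storage rather than declaring tile arrays itself. A minimal sketch; the struct name, dtypes, and dimensions are assumptions:

```mojo
struct MyStandardSmem:
    var tiles: StandardTileStorage[
        DType.bfloat16, DType.bfloat16,  # a_type, b_type
        128, 64,                         # a_dim0, a_dim1 (BM x BK)
        256, 64,                         # b_dim0, b_dim1 (BN x BK)
        4,                               # num_pipeline_stages (array depth)
    ]

    fn load_stage(ref[SHARED] self):
        a = self.tiles.a_tiles()  # TileTensor array, depth = 4 stages
        b = self.tiles.b_tiles()
        # Convert to LayoutTensor only at TMA/MMA boundaries, using the
        # {ptr} syntax or explicit construction as the struct docs describe.
```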
--- ## TmemDeallocStorage
`struct TmemDeallocStorage` Storage for TMEM deallocation synchronization. Single barrier + address storage for TMEM lifecycle management. ## Fields * ​barrier\_storage (`TmemDeallocStorage.BarrierArray.Storage`): * ​addr\_storage (`TmemDeallocStorage.AddrArray.Storage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `AddrArray` `comptime AddrArray = SMemArray[UInt32, 1]` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, 1]` ## Methods ### `barrier` `barrier(ref[AddressSpace._value._mlir_value] self) -> TmemDeallocStorage.BarrierArray` **Returns:** `TmemDeallocStorage.BarrierArray` ### `addr` `addr(ref[AddressSpace._value._mlir_value] self) -> TmemDeallocStorage.AddrArray` **Returns:** `TmemDeallocStorage.AddrArray`
--- ## pipeline_storage
Unified Pipeline Storage Framework for SM100 Structured Kernels. This module provides a single-source-of-truth framework for pipeline storage, where stage count determines barrier count, and tile storage type determines the SMEM layout for input tiles. All tile storage uses TileTensor natively. Conversion to LayoutTensor only happens at external API boundaries (TMA, MMA) using the {ptr} syntax or explicit LayoutTensor construction. ## Design Principles 1. **Single Source of Truth**: Stage count parameterizes barrier count 2. **Single Source of Truth**: Tile storage types define array types once 3. **TileTensor Native**: All SMEM tiles use TileTensor 4. **Composable**: SMEM structs compose storage objects 5. **Extensible**: Easy to add new storage types 6. **Escape Hatch**: Raw storage access when framework doesn't fit ## Architecture ``` ┌─────────────────────────────────────────────────────────────────────┐ │ Tile Storage (defines tile arrays and storage) │ │ │ │ StandardTileStorage[a_type, b_type, a_dim0, a_dim1, b_dim0, ...] │ │ ├── ATileArray = SMemTileArray2D[...] # TileTensor-based │ │ ├── BTileArray = SMemTileArray2D[...] # TileTensor-based │ │ ├── var a_tiles_storage │ │ ├── var b_tiles_storage │ │ └── fn a_tiles(), b_tiles() # Returns TileTensor │ │ │ │ BlockScaledTileStorage[..., sfa_type, sfb_type, dims, ...] │ │ BlockwiseFP8TileStorage[..., a_scales_type, dims, ...] │ │ OutputTileStorage[c_type, c_layout, num_stages] │ ├─────────────────────────────────────────────────────────────────────┤ │ Pipeline Storage (defines barriers) │ │ │ │ InputPipelineStorage[num_stages, Payload] │ │ └── var barriers: BarrierPair[num_stages] │ │ │ │ OutputPipelineStorage[num_stages] │ │ ClcPipelineStorage[num_stages] │ │ TmemDeallocStorage │ ├─────────────────────────────────────────────────────────────────────┤ │ SMEM composes both: │ │ │ │ struct MySmem: │ │ var tiles: StandardTileStorage[...] # Tile storage │ │ var output_tiles: OutputTileStorage[...] 
# Output tiles │ │ var input_pipeline: InputPipelineStorage[...] # Barriers │ │ var output_pipeline: OutputPipelineStorage[...] │ │ var clc_pipeline: ClcPipelineStorage[...] │ └─────────────────────────────────────────────────────────────────────┘ ``` ## Example Usage ``` struct MyKernelSmem[config: MyConfig]: # Tile storage (single source of truth for tile types) comptime Tiles = StandardTileStorage[ config.a_type, config.b_type, config.BM, config.BK, # A tile dimensions config.BN, config.BK, # B tile dimensions config.num_pipeline_stages, ] var tiles: Self.Tiles # Output tile storage (separate stage count) comptime OutputTiles = OutputTileStorage[ config.c_type, config.c_layout, config.num_output_stages ] var output_tiles: Self.OutputTiles # Pipeline storage (barriers) var input_pipeline: InputPipelineStorage[...] var output_pipeline: OutputPipelineStorage[...] # Accessors delegate to composed storage fn a_tiles(ref[SHARED] self) -> Self.Tiles.ATileArray: return self.tiles.a_tiles() # Returns TileTensor fn c_tiles(ref[SHARED] self) -> Self.OutputTiles.CTileArray: return self.output_tiles.c_tiles() # Returns LayoutTensor ``` ## Extensibility To add a new tile storage type: 1. Create a new struct with comptime type aliases and storage fields 2. Add accessors that construct tile arrays from storage 3. Use in SMEM via composition ## Escape Hatch When the framework doesn't fit: 1. Use raw SMemArray for custom tile layouts 2. Use RawBarrierStorage for non-standard barrier patterns 3. Add custom storage fields to SMEM struct ## `comptime` values ### `MbarPtr` `comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ## Structs * [​`BarrierPair`](./BarrierPair): Storage for a producer-consumer barrier pair (full + empty). * [​`BlockScaledTileStorage`](./BlockScaledTileStorage): Storage for block-scaled matmul tiles (A, B, C, SFA, SFB). 
* [​`BlockwiseFP8TileStorage`](./BlockwiseFP8TileStorage): Storage for blockwise FP8 matmul tiles (A, B, C, A-scales). * [​`ClcPipelineStorage`](./ClcPipelineStorage): Storage for CLC (Cluster Launch Control) scheduler pipeline. * [​`EpiLoadPipelineStorage`](./EpiLoadPipelineStorage): Storage for epilogue load pipeline (source C loading). * [​`InputPipelineStorage`](./InputPipelineStorage): Unified storage for input tile pipeline (barriers + payload). * [​`LoadOrderBarrierStorage`](./LoadOrderBarrierStorage): Storage for load order barrier (mainloop → epilogue load coordination). * [​`OutputPipelineStorage`](./OutputPipelineStorage): Unified storage for output/accumulator pipeline. * [​`OutputTileStorage`](./OutputTileStorage): Storage for output tiles (C matrix). * [​`RawBarrierStorage`](./RawBarrierStorage): Escape hatch: Raw barrier storage for custom patterns. * [​`SmemLayouts`](./SmemLayouts): Common SMEM layout definitions for matmul-family kernels. * [​`SmemPipelineBundle`](./SmemPipelineBundle): Composed pipeline storage with unified barrier accessors. * [​`SmemPipelineBundleNoClc`](./SmemPipelineBundleNoClc): Composed pipeline storage without CLC scheduler. * [​`SourceTileStorage`](./SourceTileStorage): Storage for source tensor C tiles (residual/skip connection input). * [​`StandardTileStorage`](./StandardTileStorage): Storage for standard matmul tiles (A and B). * [​`TmemDeallocStorage`](./TmemDeallocStorage): Storage for TMEM deallocation synchronization.
--- ## ScalesTileLoader
`@register_passable(trivial)` `struct ScalesTileLoader[tma_origin: ImmutOrigin, dtype: DType, gmem_layout: Layout, desc_layout: Layout, /, *, cta_group: Int]` TMA-based scales tile loader for blockwise FP8. Unlike TileLoaderTMA, this loader: * Uses async\_copy (no multicast) since scales aren't distributed across CTAs * Uses (row\_coord, k\_coord) coordinate order matching scales tensor layout ## Parameters * ​tma\_origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): Origin of the TMA descriptor pointer. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type (typically float8 for scales). * ​gmem\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Global memory tensor layout. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): TMA descriptor layout (tile dimensions). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2 for SM100 2-SM MMA). ## Fields * ​tma\_op (`ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOpPtr`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TmaOp` `comptime TmaOp = TMATensorTile[dtype, gmem_layout, desc_layout]` ### `TmaOpPtr` `comptime TmaOpPtr = Pointer[ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]` ## Methods ### `__init__` `__init__(tma_op: 
Pointer[ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]) -> Self` Initialize the scales tile loader. **Args:** * ​tma\_op ([`Pointer`](/mojo/std/memory/pointer/Pointer)): Pointer to TMA descriptor (grid constant). ### `load` `load[tile_layout: Layout, /, alignment: Int = 128](self, dest: LayoutTensor[dtype, tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, row_coord: Int, k_coord: Int)` Load a scales tile using TMA hardware acceleration. Issues an async copy from global memory to shared memory. Unlike TileLoaderTMA, this uses (row\_coord, k\_coord) order matching the scales tensor layout. **Args:** * ​dest ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination SMEM tile. * ​barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for TMA completion signaling. * ​row\_coord ([`Int`](/mojo/std/builtin/int/Int)): Row coordinate (M for A-scales) in global memory. * ​k\_coord ([`Int`](/mojo/std/builtin/int/Int)): K dimension coordinate in global memory. `load[dim0: Int, dim1: Int, /, alignment: Int = 128](self, dest: TileTensor[dtype, Layout[ComptimeInt[dim0], ComptimeInt[dim1], ComptimeInt[dim1], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, row_coord: Int, k_coord: Int)` Load a TileTensor scales tile using TMA hardware acceleration. This overload accepts TileTensor-based tiles and converts them to LayoutTensor internally for the TMA operation. Zero-cost conversion via pointer reinterpretation. **Args:** * ​dest ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Destination SMEM TileTensor tile. * ​barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for TMA completion signaling. 
* ​row\_coord ([`Int`](/mojo/std/builtin/int/Int)): Row coordinate (M for A-scales) in global memory. * ​k\_coord ([`Int`](/mojo/std/builtin/int/Int)): K dimension coordinate in global memory.
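Taken together, the constructor and `load` overloads above admit a usage pattern like the following hedged sketch (`AScalesLoaderType`, `a_scales_tma_op`, `scales_tile`, `barrier`, and the coordinate values are illustrative names, not part of this API):

```mojo
# Hypothetical sketch: loading A-scales tiles for blockwise FP8.
# AScalesLoaderType is assumed to be an instantiation of ScalesTileLoader.
var scales_loader = AScalesLoaderType(Pointer(to=a_scales_tma_op))

# Note the (row_coord, k_coord) argument order, matching the scales
# tensor layout (TileLoaderTMA takes (k_coord, row_coord) instead).
scales_loader.load(scales_tile, barrier, m_coord, k_coord)
```

The coordinate-order difference from `TileLoaderTMA` is the main thing to keep straight when writing the producer loop.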
--- ## TileLoaderTMA
`@register_passable(trivial)` `struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, gmem_layout: Layout, desc_layout: Layout, /, *, cta_group: Int]` TMA-based tile loader for SM100. Wraps a TMA descriptor and multicast mask for efficient tile loading. The load method issues async\_multicast\_load with proper CTA group handling. ## Parameters * ​tma\_origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): Origin of the TMA descriptor pointer. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element data type. * ​gmem\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Global memory tensor layout. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): TMA descriptor layout (tile dimensions). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2 for SM100 2-SM MMA). ## Fields * ​tma\_op (`TileLoaderTMA[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOpPtr`): * ​multicast\_mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TmaOp` `comptime TmaOp = TMATensorTile[dtype, gmem_layout, desc_layout]` ### `TmaOpPtr` `comptime TmaOpPtr = Pointer[TileLoaderTMA[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]` ## Methods ### `__init__` `__init__(tma_op: Pointer[TileLoaderTMA[tma_origin, dtype, gmem_layout, desc_layout, 
cta_group=cta_group].TmaOp, tma_origin], multicast_mask: UInt16) -> Self` Initialize the TMA tile loader. **Args:** * ​tma\_op ([`Pointer`](/mojo/std/memory/pointer/Pointer)): Pointer to TMA descriptor (grid constant). * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for cluster distribution. ### `load` `load[tile_layout: Layout, /, alignment: Int = 128](self, dest: LayoutTensor[dtype, tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], row_coord: Scalar[DType.uint])` Load a tile using TMA hardware acceleration. Issues an async multicast load from global memory to shared memory. Coordinates are in element units (not tile units). **Args:** * ​dest ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination SMEM tile (already sliced for peer CTA if needed). * ​barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for TMA completion signaling. * ​k\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K dimension coordinate in global memory (elements). * ​row\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Row coordinate (M for A, N for B) in global memory (elements). `load[dim0: Int, dim1: Int, /, alignment: Int = 128](self, dest: TileTensor[dtype, Layout[ComptimeInt[dim0], ComptimeInt[dim1], ComptimeInt[dim1], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], row_coord: Scalar[DType.uint])` Load a TileTensor tile using TMA hardware acceleration. This overload accepts TileTensor-based tiles and passes them directly to the TMA TileTensor overload (no LayoutTensor conversion needed). **Args:** * ​dest ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Destination SMEM TileTensor tile. 
* ​barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for TMA completion signaling. * ​k\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K dimension coordinate in global memory (elements). * ​row\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Row coordinate (M for A, N for B) in global memory (elements). `load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], row_coord: Scalar[DType.uint])` Load a TileTensor tile with variadic shape/stride types using TMA. This overload accepts TileTensor tiles with swizzled layouts (created via internal\_k\_major) and passes them to the TMA operation. **Args:** * ​dest ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Destination SMEM TileTensor tile with swizzled layout. * ​barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): Memory barrier for TMA completion signaling. * ​k\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K dimension coordinate in global memory (elements). * ​row\_coord ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Row coordinate (M for A, N for B) in global memory (elements).
--- ## tile_loader
TMA tile loader for SM100 matrix multiplication. Provides a wrapper around TMA async\_multicast\_load operations, following the SM90 TileLoaderTMA pattern. Orchestration logic (k-group iteration, expect\_bytes, barrier management) is handled by the kernel, not the loader.

Usage:

```
# In kernel - create separate A and B loaders
var a_loader = ATileLoaderType(Pointer(to=a_tma_op), ctx.a_multicast_mask)
var b_loader = BTileLoaderType(Pointer(to=b_tma_op), ctx.b_multicast_mask)

# Load tiles using the loaders (LayoutTensor or TileTensor)
a_loader.load(a_tile, barrier, k_coord, m_coord)
b_loader.load(b_tile, barrier, k_coord, n_coord)

# TileTensor tiles are automatically converted to LayoutTensor for TMA ops
```

## Structs * [​`ScalesTileLoader`](./ScalesTileLoader): TMA-based scales tile loader for blockwise FP8. * [​`TileLoaderTMA`](./TileLoaderTMA): TMA-based tile loader for SM100.
--- ## BlockScaledTilePayload
`@register_passable(trivial)` `struct BlockScaledTilePayload[a_type: DType, b_type: DType, sfa_type: DType, sfb_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, sfa_dim0: Int, sfa_dim1: Int, sfb_dim0: Int, sfb_dim1: Int, num_pipeline_stages: Int]` Tile payload for block-scaled matmul (A, B, SFA, SFB tiles). ## Fields * ​a\_tiles (`BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATileArray`): * ​b\_tiles (`BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTileArray`): * ​sfa\_tiles (`BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATileArray`): * ​sfb\_tiles (`BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTileArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, 
sfb_dim1, num_pipeline_stages].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTile` `comptime BTile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ### `sfa_layout` `comptime sfa_layout = Layout[Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]], Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]]](Coord[ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[32], ComptimeInt[(sfa_dim0 // 32)]](Idx[32](), Idx[(sfa_dim0 // 32)]())), Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfa_dim1 // 16)]](Coord[ComptimeInt[4], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[4], ComptimeInt[4]](Idx[4](), Idx[4]())), Idx[(sfa_dim1 // 16)]())))), Coord[Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[ComptimeInt[16], ComptimeInt[(sfa_dim1 * 
32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[16], ComptimeInt[(sfa_dim1 * 32)]](Idx[16](), Idx[(sfa_dim1 * 32)]())), Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](Coord[ComptimeInt[1], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[4]](Idx[1](), Idx[4]())), Idx[512]())))))` ### `SFATile` `comptime SFATile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATileArray.Tile` ### `SFATileArray` `comptime SFATileArray = SMemTileArrayWithLayout[sfa_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfa_layout, num_pipeline_stages]` ### `sfb_layout` `comptime sfb_layout = Layout[Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]], Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]]](Coord[ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[32], ComptimeInt[(sfb_dim0 // 32)]](Idx[32](), Idx[(sfb_dim0 // 32)]())), Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(sfb_dim1 // 16)]](Coord[ComptimeInt[4], 
ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[4], ComptimeInt[4]](Idx[4](), Idx[4]())), Idx[(sfb_dim1 // 16)]())))), Coord[Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[16], ComptimeInt[(sfb_dim1 * 32)]](Idx[16](), Idx[(sfb_dim1 * 32)]())), Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](Coord[ComptimeInt[1], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[4]](Idx[1](), Idx[4]())), Idx[512]())))))` ### `SFBTile` `comptime SFBTile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTileArray.Tile` ### `SFBTileArray` `comptime SFBTileArray = SMemTileArrayWithLayout[sfb_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfb_layout, num_pipeline_stages]` ## Methods ### `__init__` `__init__(a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages], sfa_tiles: SMemTileArrayWithLayout[sfa_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfa_layout, num_pipeline_stages], sfb_tiles: SMemTileArrayWithLayout[sfb_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, 
sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfb_layout, num_pipeline_stages]) -> Self` ### `get_tile` `get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATile, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTile, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATile, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTile]` Get A, B, SFA, SFB tiles at the specified stage and k-group index. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `get_a_tile` `get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATile` Get A tile at the specified stage and k-group index. **Returns:** `BlockScaledTilePayload` ### `get_b_tile` `get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTile` Get B tile at the specified stage and k-group index. **Returns:** `BlockScaledTilePayload` ### `get_sfa_tile` `get_sfa_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATile` Get SFA tile at the specified stage and k-group index. 
**Returns:** `BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATile` ### `get_sfb_tile` `get_sfb_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTile` Get SFB tile at the specified stage and k-group index. **Returns:** `BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTile`
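The accessors above are typically called from the mainloop's consumer side. A hedged sketch (the `payload`, `stage`, and MMA step are assumed context, not defined by this API):

```mojo
# Hypothetical sketch: fetching all four tiles for one (stage, k) step.
for k_idx in range(k_group_size):
    a_tile, b_tile, sfa_tile, sfb_tile = payload.get_tile[k_group_size](
        stage, k_idx
    )
    # ... issue MMA on (a_tile, b_tile), scaled by (sfa_tile, sfb_tile) ...
```

`get_tile` returns the same tiles as calling the four single-tile accessors (`get_a_tile`, `get_b_tile`, `get_sfa_tile`, `get_sfb_tile`) with identical arguments.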
--- ## BlockwiseFP8TilePayload
`@register_passable(trivial)` `struct BlockwiseFP8TilePayload[a_type: DType, b_type: DType, a_scales_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, a_scales_dim0: Int, a_scales_dim1: Int, num_pipeline_stages: Int]` Tile payload for blockwise FP8 matmul (A, B, A-scales tiles). Unlike BlockScaledTilePayload, this only stores A-scales in SMEM. B-scales are read directly from global memory during the epilogue phase. ## Fields * ​a\_tiles (`BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATileArray`): * ​b\_tiles (`BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTileArray`): * ​a\_scales\_tiles (`BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTileArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AScalesTile` `comptime AScalesTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTileArray.Tile` ### `AScalesTileArray` `comptime AScalesTileArray = 
SMemTileArray2D[a_scales_type, a_scales_dim0, a_scales_dim1, num_pipeline_stages]` ### `ATile` `comptime ATile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTile` `comptime BTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ## Methods ### `__init__` `__init__(a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages], a_scales_tiles: SMemTileArray2D[a_scales_type, a_scales_dim0, a_scales_dim1, num_pipeline_stages]) -> Self` ### `get_tile` `get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATile, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTile, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTile]` Get A, B, A-scales tiles at the specified stage and k-group index. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `get_a_tile` `get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATile` Get A tile at the specified stage and k-group index. 
**Returns:** `BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATile` ### `get_b_tile` `get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTile` Get B tile at the specified stage and k-group index. **Returns:** `BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTile` ### `get_a_scales_tile` `get_a_scales_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTile` Get A-scales tile at the specified stage and k-group index. **Returns:** `BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTile`
--- ## EpilogueKContext
`@register_passable(trivial)` `struct EpilogueKContext[origin: MutOrigin, input_origin: MutOrigin, num_output_stages: Int, stage_stride_cols: Int, cta_group: Int, num_input_stages: Int]` Per-K context manager for epilogue warp in blockwise FP8. Bundles output pipeline (MMA→Epilogue sync) and input pipeline (A-scales) into a single context manager for clean per-K iteration handling.

Example usage:

```
for k_iter in range(num_iters):
    with epi_ctx.per_k_stage(input_pipeline) as epi_stage:
        accum.promote(epi_stage, ...)
    # __exit__ signals BOTH pipelines
```

`__enter__`: Waits for MMA to complete this K iteration, returns `EpilogueKStage`. `__exit__`: Signals both output consumer barrier AND input consumer\_step.

## Fields * ​output\_pipeline\_ptr (`Pointer[EpilogueKContext[origin, input_origin, num_output_stages, stage_stride_cols, cta_group, num_input_stages].OutputPipelineType, origin]`): * ​input\_pipeline\_ptr (`Pointer[EpilogueKContext[origin, input_origin, num_output_stages, stage_stride_cols, cta_group, num_input_stages].InputPipelineType, input_origin]`): * ​output\_stage (`EpilogueKContext[origin, input_origin, num_output_stages, stage_stride_cols, cta_group, num_input_stages].OutputStageType`): * ​input\_stage\_index (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CombinedStageType` `comptime CombinedStageType =
EpilogueKStage[num_output_stages, stage_stride_cols, cta_group, num_input_stages]` ### `InputPipelineType` `comptime InputPipelineType = ProducerConsumerPipeline[num_input_stages]` ### `OutputPipelineType` `comptime OutputPipelineType = OutputTilePipeline[num_output_stages, stage_stride_cols, cta_group]` ### `OutputStageType` `comptime OutputStageType = OutputStage[num_output_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(output_pipeline_ptr: Pointer[EpilogueKContext[origin, input_origin, num_output_stages, stage_stride_cols, cta_group, num_input_stages].OutputPipelineType, origin], input_pipeline_ptr: Pointer[EpilogueKContext[origin, input_origin, num_output_stages, stage_stride_cols, cta_group, num_input_stages].InputPipelineType, input_origin]) -> Self` ### `__enter__` `__enter__(mut self) -> EpilogueKContext[origin, input_origin, num_output_stages, stage_stride_cols, cta_group, num_input_stages].CombinedStageType` **Returns:** `EpilogueKContext` ### `__exit__` `__exit__(mut self)`
--- ## EpilogueKStage
`@register_passable(trivial)` `struct EpilogueKStage[num_output_stages: Int, stage_stride_cols: Int, cta_group: Int, num_input_stages: Int]` Per-K stage for epilogue warp in blockwise FP8. Returned from `EpilogueKContext.__enter__()`. Bundles: * output\_stage: TMEM access (offset for reading MMA results) * input\_stage\_index: Current A-scales stage * input\_pipeline: For signaling A-scales consumption ## Fields * ​output\_stage (`EpilogueKStage[num_output_stages, stage_stride_cols, cta_group, num_input_stages].OutputStageType`): * ​input\_stage\_index (`UInt32`): * ​input\_pipeline (`EpilogueKStage[num_output_stages, stage_stride_cols, cta_group, num_input_stages].InputPipelineType`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `InputPipelineType` `comptime InputPipelineType = ProducerConsumerPipeline[num_input_stages]` ### `OutputStageType` `comptime OutputStageType = OutputStage[num_output_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(output_stage: OutputStage[num_output_stages, stage_stride_cols, cta_group], input_stage_index: UInt32, input_pipeline: ProducerConsumerPipeline[num_input_stages]) -> Self` ### `arrive_input` `arrive_input(self)` Arrive on the input pipeline's consumer barrier. Use with lane-guarded patterns:

```
if lane_id() < cluster_size:
    epi_stage.arrive_input()
```
--- ## EpilogueStage
`struct EpilogueStage[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Unified linear type handle for epilogue stage in output pipeline. Works as both a linear type (direct use) and within context managers. Lifecycle: 1. Created via `output_pipeline.acquire_epilogue_linear()` - waits for MMA 2. Use `tmem()`, `tmem_offset()` for reading MMA results 3. Must call `release()` to advance (compiler-enforced) ## Parameters * ​origin ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of the pipeline reference. * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages. * ​stage\_stride\_cols ([`Int`](/mojo/std/builtin/int/Int)): TMEM column stride between stages. * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2). ## Fields * ​pipeline\_ptr (`Pointer[EpilogueStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(out self, pipeline_ptr: Pointer[EpilogueStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin], stage: OutputStage[num_stages, stage_stride_cols, cta_group])` ### `tmem` `tmem(self) -> EpilogueStage[origin, num_stages, stage_stride_cols, cta_group].Stage.Tmem` Get the TMEM stage handle. **Returns:** `EpilogueStage` ### `tmem_offset` `tmem_offset(self) -> Int` Get the TMEM offset for reading MMA results. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `index` `index(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `release` `release(deinit self)` Free stage for MMA reuse and advance to next stage. 
This is the only way to destroy this linear type.
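Putting the lifecycle above together, a flat (non-context-manager) acquisition might look like the following sketch; `output_pipeline` and `copy_results_from_tmem` are hypothetical names, not part of the API:

```
# Hedged sketch of the EpilogueStage linear-type lifecycle.
# `output_pipeline` is an OutputTilePipeline; `copy_results_from_tmem`
# is a hypothetical kernel helper.
var epi = output_pipeline.acquire_epilogue_linear()  # waits for MMA
copy_results_from_tmem(epi.tmem(), epi.tmem_offset())
epi^.release()  # frees the stage for MMA reuse (compiler-enforced)
```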
--- ## InputConsumer
`@register_passable(trivial)` `struct InputConsumer[origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int]` Consumer view for MMA warp. Use acquire() to get stages. ## Fields * ​pipeline\_ptr (`Pointer[InputConsumer[origin, Payload, num_group_stages, k_group_size].PipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `PipelineType` `comptime PipelineType = InputTilePipeline[Payload, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `acquire` `acquire(mut self) -> InputConsumerStage[origin, Payload, num_group_stages, k_group_size]` Acquire next stage, waiting for tiles to be ready. Returns a context manager for processing tiles. **Returns:** [`InputConsumerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputConsumerStage) ### `try_acquire` `try_acquire(mut self) -> Bool` Non-blocking check if next consumer stage has data. Use with acquire\_if\_needed() for the try-acquire pattern:

```
var ready = consumer.try_acquire()
# ... do other work ...
with consumer.acquire_if_needed(ready) as tiles:
    process_tiles()
```

**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the stage has data, False if waiting is needed.
### `acquire_if_needed` `acquire_if_needed(mut self, already_ready: Bool) -> InputConsumerStage[origin, Payload, num_group_stages, k_group_size]` Acquire stage, only waiting if not already ready. **Args:** * ​already\_ready ([`Bool`](/mojo/std/builtin/bool/Bool)): Result from try\_acquire(). Skips wait if True. **Returns:** [`InputConsumerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputConsumerStage): The consumer stage for processing tiles.
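As a rough sketch of how these consumer methods compose in an MMA warp's main loop (`pipeline`, `num_k_iters`, and `process_tiles` are illustrative placeholders, not part of the API):

```
# Hedged sketch only; `pipeline`, `num_k_iters`, and `process_tiles`
# are hypothetical stand-ins for kernel-specific code.
var consumer = pipeline.consumer()
for _ in range(num_k_iters):
    # Blocks until the TMA Load warp has filled the next stage.
    with consumer.acquire() as tiles:
        process_tiles(tiles.payload(), tiles.stage())
    # __exit__ signals consumption and advances to the next stage.
```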
--- ## InputConsumerStage
`@register_passable(trivial)` `struct InputConsumerStage[origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int]` Handle for consumer tile access - works as a context manager or linear-style. Two usage patterns: 1. Context manager (scoped):

```
with consumer.acquire() as tiles:
    process_tiles(tiles.payload(), tiles.stage())
# release() called automatically by __exit__
```

2. Linear-style (flat):

```
var tiles = consumer.acquire()
process_tiles(tiles.payload(), tiles.stage())
tiles.release()  # Manual release
```

Lifecycle: 1. Created via `consumer.acquire()` - waits for producer 2. Use `payload()`, `stage()` for tile access 3. Call `release()` or let `__exit__` signal and advance ## Parameters * ​origin ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of the pipeline reference. * ​Payload ([`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload)): The tile payload type. * ​num\_group\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of synchronization stages. * ​k\_group\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles per synchronization stage. 
## Fields * ​pipeline\_ptr (`Pointer[InputConsumerStage[origin, Payload, num_group_stages, k_group_size].PipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `PipelineType` `comptime PipelineType = InputTilePipeline[Payload, num_group_stages, k_group_size]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[InputConsumerStage[origin, Payload, num_group_stages, k_group_size].PipelineType, origin], stage: UInt32, mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `payload` `payload(self) -> Payload` Get the tile payload for direct access. **Returns:** `Payload` ### `stage` `stage(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `mbar` `mbar(self) -> MbarPtr` Get the barrier pointer. **Returns:** `MbarPtr` ### `release` `release(mut self)` Signal consumption and advance to next stage (linear-style API). Use this for flat code structure instead of context manager. Equivalent to what **exit** does.
--- ## InputProducer
`@register_passable(trivial)` `struct InputProducer[origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int]` Producer view for TMA Load warp. Use acquire() to get stages. ## Fields * ​pipeline\_ptr (`Pointer[InputProducer[origin, Payload, num_group_stages, k_group_size].PipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `PipelineType` `comptime PipelineType = InputTilePipeline[Payload, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `drain` `drain(mut self)` Drain pipeline to prevent CTA exit while peer is still working. ### `acquire` `acquire(mut self) -> InputProducerStage[origin, Payload, num_group_stages, k_group_size]` Acquire next stage, waiting for slot availability. Returns a context manager for loading tiles. **Returns:** [`InputProducerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputProducerStage) ### `try_acquire` `try_acquire(mut self) -> Bool` Non-blocking check if next producer stage is available. Use with acquire\_if\_needed() for the try-acquire pattern:

```
var ready = producer.try_acquire()
# ... do other work ...
with producer.acquire_if_needed(ready) as tiles:
    load_tiles()
```

**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if the stage is ready, False if waiting is needed. ### `acquire_if_needed` `acquire_if_needed(mut self, already_ready: Bool) -> InputProducerStage[origin, Payload, num_group_stages, k_group_size]` Acquire stage, only waiting if not already ready. **Args:** * ​already\_ready ([`Bool`](/mojo/std/builtin/bool/Bool)): Result from try\_acquire(). Skips wait if True. **Returns:** [`InputProducerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputProducerStage): The producer stage for loading tiles.
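A matching producer-side loop for the TMA Load warp might look like the following sketch (`pipeline`, `num_k_iters`, `expected_tma_bytes`, and `load_ab_tiles` are hypothetical placeholders):

```
# Hedged sketch; all lowercase names below are illustrative placeholders.
var producer = pipeline.producer()
for _ in range(num_k_iters):
    with producer.acquire() as tiles:
        tiles.expect_bytes(expected_tma_bytes)  # arm the barrier for TMA
        load_ab_tiles(tiles.payload(), tiles.stage(), tiles.barrier())
producer.drain()  # keep the CTA alive until the peer is done
```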
--- ## InputProducerStage
`@register_passable(trivial)` `struct InputProducerStage[origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int]` Handle for producer tile access - works as a context manager or linear-style. Two usage patterns: 1. Context manager (scoped):

```
with producer.acquire() as tiles:
    load_tiles(tiles.payload(), tiles.stage(), tiles.barrier())
# release() called automatically by __exit__
```

2. Linear-style (flat):

```
var tiles = producer.acquire()
load_tiles(tiles.payload(), tiles.stage(), tiles.barrier())
tiles.release()  # Manual release
```

Lifecycle: 1. Created via `producer.acquire()` - waits for consumer 2. Use `payload()`, `stage()`, `barrier()` for TMA operations 3. Call `release()` or let `__exit__` advance producer stage ## Parameters * ​origin ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of the pipeline reference. * ​Payload ([`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload)): The tile payload type. * ​num\_group\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of synchronization stages. * ​k\_group\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles per synchronization stage. 
## Fields * ​pipeline\_ptr (`Pointer[InputProducerStage[origin, Payload, num_group_stages, k_group_size].PipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `PipelineType` `comptime PipelineType = InputTilePipeline[Payload, num_group_stages, k_group_size]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[InputProducerStage[origin, Payload, num_group_stages, k_group_size].PipelineType, origin], stage: UInt32, barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `payload` `payload(self) -> Payload` Get the tile payload for direct access. **Returns:** `Payload` ### `stage` `stage(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `expect_bytes` `expect_bytes(self, num_bytes: Int)` Set expected bytes on the barrier for TMA loads. ### `barrier` `barrier(self) -> MbarPtr` Get the barrier pointer for TMA multicast loads. **Returns:** `MbarPtr` ### `release` `release(mut self)` Advance producer to next stage (linear-style API). Use this for flat code structure instead of context manager. Equivalent to what **exit** does.
--- ## InputTilePipeline
`@register_passable(trivial)` `struct InputTilePipeline[Payload: TilePayload, num_group_stages: Int, k_group_size: Int]` Tile pipeline with configurable payload type. Separates synchronization from tile storage. The Payload parameter (e.g., StandardTilePayload or BlockScaledTilePayload) holds tile arrays. ## Fields * ​pipeline (`InputTilePipeline[Payload, num_group_stages, k_group_size].Pipeline`): * ​payload (`Payload`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = Payload.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = Payload.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = Payload.__moveinit__is_trivial` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, (num_group_stages * 2)]` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_group_stages]` ## Methods ### `__init__` `__init__(barriers: SMemArray[SharedMemBarrier, (num_group_stages * 2)], payload: Payload) -> Self` Initialize from typed barrier array and payload. ### `init_barriers` `static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize pipeline barriers. Called once by elect\_one thread. ### `acquire_producer` `acquire_producer(mut self) -> Tuple[UInt32, MbarPtr]` Wait for slot availability and return (stage, barrier). 
**Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `release_producer` `release_producer(mut self)` Signal completion and advance producer stage. ### `acquire_consumer` `acquire_consumer(mut self) -> Tuple[UInt32, MbarPtr]` Wait for data availability and return (stage, barrier). **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `release_consumer` `release_consumer(mut self)` Signal completion and advance consumer stage. ### `try_acquire_producer` `try_acquire_producer(self) -> Bool` Non-blocking check if next producer stage is available. Example (TMA Load warp):

```
var ready = pipeline.try_acquire_producer()
# ... do other work while potentially waiting ...
pipeline.wait_producer_if_needed(ready)
var stage = pipeline.producer_stage()
# ... load tiles ...
```

**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if consumer has freed the stage, False otherwise. ### `try_acquire_consumer` `try_acquire_consumer(self) -> Bool` Non-blocking check if next consumer stage has data. Example (MMA warp):

```
var ready = pipeline.try_acquire_consumer()
# ... do other work while potentially waiting ...
pipeline.wait_consumer_if_needed(ready)
var stage = pipeline.consumer_stage()
# ... process tiles ...
```

**Returns:** [`Bool`](/mojo/std/builtin/bool/Bool): True if producer has filled the stage, False otherwise. ### `wait_producer_if_needed` `wait_producer_if_needed(self, already_ready: Bool)` Conditionally wait for producer stage if not already ready. **Args:** * ​already\_ready ([`Bool`](/mojo/std/builtin/bool/Bool)): Result from try\_acquire\_producer(). ### `wait_consumer_if_needed` `wait_consumer_if_needed(self, already_ready: Bool)` Conditionally wait for consumer to free stage if not already ready. **Args:** * ​already\_ready ([`Bool`](/mojo/std/builtin/bool/Bool)): Result from try\_acquire\_consumer(). 
### `producer_stage` `producer_stage(self) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `consumer_stage` `consumer_stage(self) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `producer_mbar` `producer_mbar(self, stage: UInt32) -> MbarPtr` **Returns:** `MbarPtr` ### `consumer_mbar` `consumer_mbar(self, stage: UInt32) -> MbarPtr` **Returns:** `MbarPtr` ### `producer` `producer[mut_origin: MutOrigin](ref[mut_origin] self) -> InputProducer[mut_origin, Payload, num_group_stages, k_group_size]` Get producer view for TMA Load warp. **Returns:** `InputProducer` ### `consumer` `consumer[mut_origin: MutOrigin](ref[mut_origin] self) -> InputConsumer[mut_origin, Payload, num_group_stages, k_group_size]` Get consumer view for MMA warp. **Returns:** `InputConsumer` ### `acquire_producer_linear` `acquire_producer_linear[mut_origin: MutOrigin](ref[mut_origin] self) -> InputProducerStage[mut_origin, Payload, num_group_stages, k_group_size]` Acquire a producer stage handle using linear types. Waits for the consumer to free the current stage, then returns a linear type handle that MUST be released (compiler-enforced). Usage:

```
var tiles = pipeline.acquire_producer_linear()
load_tiles(tiles.payload(), tiles.stage(), tiles.barrier())
tiles^.release()  # Advances to next stage
```

**Returns:** [`InputProducerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputProducerStage): An InputProducerStage handle that must be released. ### `acquire_consumer_linear` `acquire_consumer_linear[mut_origin: MutOrigin](ref[mut_origin] self) -> InputConsumerStage[mut_origin, Payload, num_group_stages, k_group_size]` Acquire a consumer stage handle using linear types. Waits for the producer to fill the current stage, then returns a linear type handle that MUST be released (compiler-enforced). 
Usage:

```
var tiles = pipeline.acquire_consumer_linear()
process_tiles(tiles.payload(), tiles.stage())
tiles^.release()  # Signals complete and advances
```

**Returns:** [`InputConsumerStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputConsumerStage): An InputConsumerStage handle that must be released. ### `drain_producer` `drain_producer(mut self)` Drain pipeline to prevent CTA exit while peer is still working. Call this after all producer iterations are complete. This is the linear type equivalent of InputProducer.drain().
--- ## MmaKStage
`@register_passable(trivial)` `struct MmaKStage[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Per-K stage context for MMA warp in blockwise FP8. `__enter__`: Acquires the stage, waiting for the epilogue to release the previous stage. `__exit__`: Signals `mma_arrive` to notify the epilogue and advances the producer stage. ## Fields * ​pipeline\_ptr (`Pointer[MmaKStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): * ​stage (`MmaKStage[origin, num_stages, stage_stride_cols, cta_group].Stage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[MmaKStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> MmaKStage[origin, num_stages, stage_stride_cols, cta_group].Stage` **Returns:** `MmaKStage` ### `__exit__` `__exit__(mut self)`
--- ## MmaStage
`struct MmaStage[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Unified linear type handle for MMA stage in output pipeline. Works as both a linear type (direct use) and within context managers. Lifecycle: 1. Created via `output_pipeline.acquire_mma_linear()` - waits for epilogue 2. Use `tmem()`, `tmem_offset()`, `mbar()` for MMA operations 3. Must call `release()` to signal mma\_arrive and advance (compiler-enforced) ## Parameters * ​origin ([`MutOrigin`](/mojo/std/builtin/type_aliases/#mutorigin)): Origin of the pipeline reference. * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of pipeline stages. * ​stage\_stride\_cols ([`Int`](/mojo/std/builtin/int/Int)): TMEM column stride between stages. * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2). ## Fields * ​pipeline\_ptr (`Pointer[MmaStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(out self, pipeline_ptr: Pointer[MmaStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin], stage: OutputStage[num_stages, stage_stride_cols, cta_group])` ### `tmem` `tmem(self) -> MmaStage[origin, num_stages, stage_stride_cols, cta_group].Stage.Tmem` Get the TMEM stage handle. **Returns:** `MmaStage` ### `tmem_offset` `tmem_offset(self) -> Int` Get the TMEM offset for MMA accumulator. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `index` `index(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `mbar` `mbar(self) -> MbarPtr` Get the producer barrier for MMA commit. 
**Returns:** `MbarPtr` ### `release` `release(deinit self)` Signal MMA completion and advance to next stage. This is the only way to destroy this linear type. Internally calls mma\_arrive (1-SM) or mma\_arrive\_multicast (2-SM).
--- ## OutputConsumer
`@register_passable(trivial)` `struct OutputConsumer[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Consumer view for epilogue warp (output pipeline). ## Fields * ​pipeline\_ptr (`Pointer[OutputConsumer[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[OutputConsumer[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> OutputConsumer[origin, num_stages, stage_stride_cols, cta_group].Stage` **Returns:** `OutputConsumer` ### `__exit__` `__exit__(mut self)`
--- ## OutputKPipeline
`@register_passable(trivial)` `struct OutputKPipeline[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Per-K-iteration view of OutputTilePipeline. Unlike standard producer()/consumer(), which signal once per tile (after all K iterations), this view signals after each K iteration. Use for kernels with per-K accumulation patterns (e.g., blockwise FP8). Example (MMA warp):

```
for i in range(num_iters):
    with mma_ctx.output_pipeline.per_k().produce() as stage:
        mma(stage.tmem, ...)
        # __exit__ signals mma_arrive for this K iteration
```

Example (Epilogue warp):

```
for k_iter in range(num_iters):
    with epi_ctx.output_pipeline.per_k().consume() as stage:
        promote(stage.tmem, ...)
        # __exit__ signals consumer_step for this K iteration
```

## Fields * ​pipeline\_ptr (`Pointer[OutputKPipeline[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[OutputKPipeline[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]) -> Self` ### `produce` `produce(self) -> MmaKStage[origin, num_stages, stage_stride_cols, cta_group]` Get MMA stage context manager for one K iteration. 
**Returns:** `MmaKStage`: Context manager that acquires stage on enter and signals mma\_arrive on exit. ### `consume` `consume(self) -> PerKConsumerStage[origin, num_stages, stage_stride_cols, cta_group]` Get consumer context manager for one K iteration. **Returns:** `PerKConsumerStage`: Context manager that waits for MMA on enter and signals consumer\_step on exit.
--- ## OutputProducer
`@register_passable(trivial)` `struct OutputProducer[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Producer view for MMA warp (output pipeline). ## Fields * ​pipeline\_ptr (`Pointer[OutputProducer[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): * ​stage (`OutputProducer[origin, num_stages, stage_stride_cols, cta_group].Stage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[OutputProducer[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> OutputProducer[origin, num_stages, stage_stride_cols, cta_group].Stage` **Returns:** `OutputProducer` ### `__exit__` `__exit__(mut self)`
--- ## OutputStage
`@register_passable(trivial)` `struct OutputStage[num_stages: Int, stage_stride: Int, cta_group: Int]` Acquired output stage with TMEM handle and pipeline reference. ## Fields * ​index (`UInt32`): * ​tmem (`OutputStage[num_stages, stage_stride, cta_group].Tmem`): * ​pipeline (`OutputStage[num_stages, stage_stride, cta_group].Pipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_stages]` ### `Tmem` `comptime Tmem = TmemStage[num_stages, stage_stride, cta_group]` ## Methods ### `__init__` `__init__(index: UInt32, tmem: TmemStage[num_stages, stage_stride, cta_group], pipeline: ProducerConsumerPipeline[num_stages]) -> Self` ### `from_raw` `static from_raw(pipeline: ProducerConsumerPipeline[num_stages], stage_index: UInt32, tmem_offset: UInt32) -> Self` Create OutputStage from raw pipeline, stage index, and TMEM offset. Useful when not using OutputTilePipeline's consumer() context manager. **Args:** * ​pipeline ([`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline)): The ProducerConsumerPipeline for barrier signaling. * ​stage\_index ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Current pipeline stage index. * ​tmem\_offset ([`UInt32`](/mojo/std/builtin/simd/#uint32)): Pre-computed TMEM offset for this stage. 
**Returns:** `Self`: OutputStage with the given parameters.
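For illustration, a minimal sketch of the raw-construction path (`pipeline`, `stage_idx`, and `tmem_offset` are assumed to be computed by the surrounding kernel; names are illustrative):

```
# Sketch: build an OutputStage directly, without going through
# OutputTilePipeline's consumer() context manager.
var stage = OutputStage[num_stages, stage_stride, cta_group].from_raw(
    pipeline, stage_idx, tmem_offset
)
# stage.tmem and stage.pipeline are then available for MMA/epilogue use.
```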
--- ## OutputTilePipeline
`@register_passable(trivial)` `struct OutputTilePipeline[num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Pipeline for MMA→Epilogue TMEM stage synchronization. ## Fields * ​pipeline (`OutputTilePipeline[num_stages, stage_stride_cols, cta_group].Pipeline`): * ​tmem (`OutputTilePipeline[num_stages, stage_stride_cols, cta_group].Tmem`): * ​mma\_complete\_mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, (num_stages * 2)]` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_stages]` ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `Tmem` `comptime Tmem = TmemAllocation[cta_group]` ## Methods ### `__init__` `__init__(barriers: SMemArray[SharedMemBarrier, (num_stages * 2)], tmem: TmemAllocation[cta_group], mma_complete_mask: UInt16) -> Self` Initialize from barrier array, TMEM allocation, and multicast mask. ### `init_barriers` `static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize pipeline barriers. Called once by elect\_one thread. 
### `acquire_for_mma` `acquire_for_mma(self) -> OutputTilePipeline[num_stages, stage_stride_cols, cta_group].Stage` Acquire stage for MMA, waiting for epilogue to finish. **Returns:** `OutputStage` ### `release_from_mma` `release_from_mma(mut self, stage: OutputStage[num_stages, stage_stride_cols, cta_group])` Signal MMA completion using mma\_arrive (1-SM) or multicast (2-SM). ### `acquire_for_epilogue` `acquire_for_epilogue(self) -> OutputTilePipeline[num_stages, stage_stride_cols, cta_group].Stage` Acquire stage for epilogue, waiting for MMA to complete. **Returns:** `OutputStage` ### `release_from_epilogue` `release_from_epilogue(mut self)` Signal epilogue completion, freeing stage for MMA reuse. ### `producer` `producer[origin: MutOrigin, //](ref[origin] self) -> OutputProducer[origin, num_stages, stage_stride_cols, cta_group]` Get producer view for MMA warp. **Returns:** `OutputProducer` ### `consumer` `consumer[origin: MutOrigin, //](ref[origin] self) -> OutputConsumer[origin, num_stages, stage_stride_cols, cta_group]` Get consumer view for epilogue warp. **Returns:** `OutputConsumer` ### `acquire_mma_linear` `acquire_mma_linear[origin: MutOrigin, //](ref[origin] self) -> MmaStage[origin, num_stages, stage_stride_cols, cta_group]` Acquire a stage for MMA using linear types. Waits for the epilogue to free the current stage, then returns a linear type handle that MUST be released (compiler-enforced). Usage:

```
var stage = output_pipeline.acquire_mma_linear()
mma_op.mma(a_tile, b_tile, stage.tmem_offset())
mma_op.commit(stage.mbar())
stage^.release()  # Signals mma_arrive and advances
```

**Returns:** `MmaStage`: An MmaStage handle that must be released. ### `acquire_epilogue_linear` `acquire_epilogue_linear[origin: MutOrigin, //](ref[origin] self) -> EpilogueStage[origin, num_stages, stage_stride_cols, cta_group]` Acquire a stage for epilogue using linear types.
Waits for MMA to complete the current stage, then returns a linear type handle that MUST be released (compiler-enforced). Usage:

```
var stage = output_pipeline.acquire_epilogue_linear()
process_tmem(stage.tmem())
stage^.release()  # Advances to next stage
```

**Returns:** `EpilogueStage`: An EpilogueStage handle that must be released. ### `get_pipeline` `get_pipeline(self) -> OutputTilePipeline[num_stages, stage_stride_cols, cta_group].Pipeline` Get underlying pipeline (used during barrier initialization). **Returns:** `ProducerConsumerPipeline` ### `per_k` `per_k[origin: MutOrigin, //](ref[origin] self) -> OutputKPipeline[origin, num_stages, stage_stride_cols, cta_group]` Get per-K-iteration view for kernels with per-K signaling. Unlike producer()/consumer(), which signal once per tile (after all K iterations), this view signals after each K iteration. Use it for kernels with per-K accumulation patterns (e.g., blockwise FP8). **Returns:** `OutputKPipeline`: OutputKPipeline view that provides produce()/consume() context managers for per-K-iteration barrier signaling. ### `per_k_epilogue` `per_k_epilogue[output_origin: MutOrigin, input_origin: MutOrigin, num_input_stages: Int](ref[output_origin] self, ref[input_origin] input_pipeline: ProducerConsumerPipeline[num_input_stages]) -> EpilogueKContext[output_origin, input_origin, num_stages, stage_stride_cols, cta_group, num_input_stages]` Get combined per-K epilogue context for blockwise FP8. Bundles the output pipeline (MMA→Epilogue sync) and the input pipeline (A-scales consumption) into a single context manager. Example:

```
for k_iter in range(num_iters):
    with output_pipeline.per_k_epilogue(input_pipeline) as stage:
        accum.promote(stage, ...)
# Both pipelines signaled automatically
```

**Args:** * ​input\_pipeline ([`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline)): The input pipeline for A-scales consumption. 
**Returns:** `EpilogueKContext`: EpilogueKContext context manager that handles both pipelines.
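As a sketch of the explicit acquire/release protocol above (the producer()/consumer() context managers wrap this same barrier signaling; `output_pipeline` and the elided MMA/TMEM reads are assumed to be provided by the surrounding kernel):

```
# MMA warp: wait until the epilogue has freed a TMEM stage, issue MMA
# into it, then signal completion.
var mma_stage = output_pipeline.acquire_for_mma()
# ... issue MMA writing into mma_stage.tmem ...
output_pipeline.release_from_mma(mma_stage)

# Epilogue warp: wait until MMA has completed the stage, read it out,
# then free the stage for MMA reuse.
var epi_stage = output_pipeline.acquire_for_epilogue()
# ... read accumulators from epi_stage.tmem ...
output_pipeline.release_from_epilogue()
```

The context-manager and linear-type APIs are generally preferable, since they make the release impossible to forget.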
--- ## PerKConsumerStage
`@register_passable(trivial)` `struct PerKConsumerStage[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Context manager for per-K epilogue consumption. `__enter__`: Acquires stage, waits for MMA to complete this K iteration. `__exit__`: Signals consumer barrier to release stage for MMA reuse. IMPORTANT: Unlike standard per-tile consumption, per-K consumption must signal the consumer barrier explicitly. The MMA warp waits on this barrier before each K iteration, so we must signal after each K iteration. ## Fields * ​pipeline\_ptr (`Pointer[PerKConsumerStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]`): * ​stage (`PerKConsumerStage[origin, num_stages, stage_stride_cols, cta_group].Stage`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Stage` `comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]` ### `TilePipelineType` `comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[PerKConsumerStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> PerKConsumerStage[origin, num_stages, stage_stride_cols, cta_group].Stage` **Returns:** `OutputStage` ### `__exit__` `__exit__(mut self)`
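A sketch of the per-K consumption loop this context manager enables, via the per_k() view on OutputTilePipeline (`process_tmem` is an illustrative placeholder for the epilogue's TMEM read):

```
# Each `with` block waits for MMA to finish one K iteration, then
# signals the consumer barrier so the MMA warp can reuse the stage.
var k_view = output_pipeline.per_k()
for k_iter in range(num_iters):
    with k_view.consume() as stage:
        process_tmem(stage.tmem)  # read this K iteration's partial result
```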
--- ## StandardConsumerStage
`@register_passable(trivial)` `struct StandardConsumerStage[origin: MutOrigin, a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Context manager for consumer tile access with encapsulated stage indexing. ## Fields * ​pipeline\_ptr (`Pointer[StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.ATile` ### `ATileArray` `comptime ATileArray = StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.ATileArray` ### `BTile` `comptime BTile = StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.BTile` ### `BTileArray` `comptime BTileArray = StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.BTileArray` ### `TilePipelineType` `comptime TilePipelineType = 
TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin], stage: UInt32, mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `get_tile` `get_tile(self, k_idx: Int) -> Tuple[StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].ATile, StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].BTile]` Get A and B tiles at the specified k-group index. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `get_a_tile` `get_a_tile(self, k_idx: Int) -> StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].ATile` Get A tile at the specified k-group index. **Returns:** `StandardConsumerStage` ### `get_b_tile` `get_b_tile(self, k_idx: Int) -> StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].BTile` Get B tile at the specified k-group index. **Returns:** `StandardConsumerStage` ### `mbar` `mbar(self) -> MbarPtr` Get the barrier pointer for MMA commit. **Returns:** `MbarPtr` ### `stage` `stage(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32)
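A sketch of how an MMA warp uses this stage handle, as returned by StandardTileConsumer.acquire() (the MMA issue itself is elided):

```
with consumer.acquire() as tiles:  # blocks until TMA has filled the stage
    for k_idx in range(k_group_size):
        var a_tile = tiles.get_a_tile(k_idx)
        var b_tile = tiles.get_b_tile(k_idx)
        # ... issue MMA on (a_tile, b_tile), committing to tiles.mbar() ...
# exit: signals the consumer barrier and advances to the next stage
```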
--- ## StandardProducerStage
`@register_passable(trivial)` `struct StandardProducerStage[origin: MutOrigin, a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Context manager for producer tile access with encapsulated stage indexing. ## Fields * ​pipeline\_ptr (`Pointer[StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.ATile` ### `ATileArray` `comptime ATileArray = StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.ATileArray` ### `BTile` `comptime BTile = StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.BTile` ### `BTileArray` `comptime BTileArray = StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.BTileArray` ### `TilePipelineType` `comptime TilePipelineType = 
TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__init__` `__init__(pipeline_ptr: Pointer[StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin], stage: UInt32, barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `get_tile` `get_tile(self, k_idx: Int) -> Tuple[StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].ATile, StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].BTile]` Get A and B tiles at the specified k-group index. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `get_a_tile` `get_a_tile(self, k_idx: Int) -> StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].ATile` Get A tile at the specified k-group index. **Returns:** `StandardProducerStage` ### `get_b_tile` `get_b_tile(self, k_idx: Int) -> StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].BTile` Get B tile at the specified k-group index. **Returns:** `StandardProducerStage` ### `expect_bytes` `expect_bytes(self, num_bytes: Int)` Set expected bytes on the barrier for TMA loads. ### `barrier` `barrier(self) -> MbarPtr` Get the barrier pointer for TMA multicast loads. **Returns:** `MbarPtr` ### `stage` `stage(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32)
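A sketch of the producer side, as returned by StandardTileProducer.acquire() (`bytes_per_stage` is an illustrative name for the total TMA payload per stage; the TMA copies themselves are elided):

```
with producer.acquire() as tiles:        # blocks until MMA frees the slot
    tiles.expect_bytes(bytes_per_stage)  # arm the barrier for TMA arrival
    for k_idx in range(k_group_size):
        var a_tile = tiles.get_a_tile(k_idx)
        var b_tile = tiles.get_b_tile(k_idx)
        # ... issue TMA copies into a_tile / b_tile on tiles.barrier() ...
# exit: signals the producer barrier and advances to the next stage
```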
--- ## StandardTileConsumer
`@register_passable(trivial)` `struct StandardTileConsumer[origin: MutOrigin, a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Consumer view for MMA warp (standard tile pipeline). ## Fields * ​pipeline\_ptr (`Pointer[StandardTileConsumer[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TilePipelineType` `comptime TilePipelineType = TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `acquire` `acquire(mut self) -> StandardConsumerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` Acquire next stage, waiting for tiles to be ready. **Returns:** `StandardConsumerStage`
--- ## StandardTilePayload
`@register_passable(trivial)` `struct StandardTilePayload[a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int]` Tile payload for standard matmul (A and B tiles). Uses explicit dimensions for tile arrays. The tiles are stored as TileTensor with row\_major layout and converted to LayoutTensor with swizzled layouts at TMA/MMA boundaries. ## Fields * ​a\_tiles (`StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].ATileArray`): * ​b\_tiles (`StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].BTileArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTile` `comptime BTile = StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ## Methods ### `__init__` `__init__(a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, 
num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]) -> Self` ### `get_tile` `get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].ATile, StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].BTile]` Get A and B tiles at the specified stage and k-group index. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `get_a_tile` `get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].ATile` Get A tile at the specified stage and k-group index. **Returns:** `StandardTilePayload` ### `get_b_tile` `get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> StandardTilePayload[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages].BTile` Get B tile at the specified stage and k-group index. **Returns:** `StandardTilePayload`
--- ## StandardTileProducer
`@register_passable(trivial)` `struct StandardTileProducer[origin: MutOrigin, a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Producer view for TMA Load warp (standard tile pipeline). ## Fields * ​pipeline\_ptr (`Pointer[StandardTileProducer[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TilePipelineType` `comptime TilePipelineType = TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `drain` `drain(mut self)` Drain pipeline to prevent CTA exit while peer is still working. ### `acquire` `acquire(mut self) -> StandardProducerStage[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` Acquire next stage, waiting for slot availability. **Returns:** `StandardProducerStage`
--- ## TilePayload
Trait for tile payload types. Must extend `TrivialRegisterPassable`. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## TilePipeline
`@register_passable(trivial)` `struct TilePipeline[a_type: DType, b_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Staged tile storage with producer-consumer synchronization for SM100. Manages a fixed set of pipeline stages (not a FIFO queue) where: * Producer (TMA Load) fills tiles into the current stage * Consumer (MMA) reads tiles from the current stage * Barriers coordinate access between producer and consumer Template Parameters: a\_type: Data type for A matrix tiles. b\_type: Data type for B matrix tiles. a\_dim0: First dimension for A tiles. a\_dim1: Second dimension for A tiles. b\_dim0: First dimension for B tiles. b\_dim1: Second dimension for B tiles. num\_pipeline\_stages: Total number of tile stages (stages \* k\_group\_size). num\_group\_stages: Number of synchronization stages. k\_group\_size: Number of tiles per synchronization stage. ## Fields * ​pipeline (`TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].Pipeline`): * ​a\_tiles (`TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].ATileArray`): * ​b\_tiles (`TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].BTileArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### 
`__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BarrierArray` `comptime BarrierArray = SMemArray[SharedMemBarrier, (num_group_stages * 2)]` ### `BTile` `comptime BTile = TilePipeline[a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_group_stages]` ## Methods ### `__init__` `__init__(barriers: SMemArray[SharedMemBarrier, (num_group_stages * 2)], a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]) -> Self` Initialize from typed barrier array and tile arrays. ### `init_barriers` `static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize pipeline barriers. Called once by elect\_one thread. ### `producer` `producer[origin: MutOrigin](ref[origin] self) -> StandardTileProducer[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` Get producer view for TMA Load warp. **Returns:** `StandardTileProducer` ### `consumer` `consumer[origin: MutOrigin](ref[origin] self) -> StandardTileConsumer[origin, a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1, num_pipeline_stages, num_group_stages, k_group_size]` Get consumer view for MMA warp. 
**Returns:** `StandardTileConsumer` ### `producer_stage` `producer_stage(self) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `consumer_stage` `consumer_stage(self) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `producer_mbar` `producer_mbar(self, stage: UInt32) -> MbarPtr` **Returns:** `MbarPtr` ### `consumer_mbar` `consumer_mbar(self, stage: UInt32) -> MbarPtr` **Returns:** `MbarPtr`
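Putting the pieces together, a sketch of pipeline setup (shared-memory barrier storage, arrive counts, and tile arrays are assumed to come from the kernel's SMEM layout; `elect_one` and the parameter bindings are illustrative):

```
# One elected thread initializes the barriers before any use.
if elect_one:
    TilePipeline[
        a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1,
        num_pipeline_stages, num_group_stages, k_group_size,
    ].init_barriers(barrier_storage, producer_arv_count, consumer_arv_count)
barrier()  # make the initialized barriers visible to all threads

var tile_pipeline = TilePipeline[
    a_type, b_type, a_dim0, a_dim1, b_dim0, b_dim1,
    num_pipeline_stages, num_group_stages, k_group_size,
](barriers, a_tiles, b_tiles)
# TMA warp:  with tile_pipeline.producer() as producer: ...
# MMA warp:  with tile_pipeline.consumer() as consumer: ...
```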
--- ## tile_pipeline
Tile pipeline for SM100 producer-consumer synchronization. Provides staged tile storage with producer-consumer barrier synchronization for TMA-MMA pipeline coordination. All barrier operations are encapsulated in context managers for safety and clarity. All tiles use TileTensor natively. Convert to LayoutTensor at TMA/MMA boundaries using {ptr} syntax or explicit LayoutTensor construction. ## Key Abstractions * InputTilePipeline\[Payload]: Generic pipeline with payload abstraction * TilePipeline: Standard pipeline with explicit A/B tile types * OutputTilePipeline: TMEM accumulator stages for MMA→Epilogue pipeline ## Naming Conventions * \*Pipeline: Multi-stage buffer (InputTilePipeline, OutputTilePipeline) * \*Producer/\*Consumer: Role handles (InputProducer, OutputConsumer) * acquire(): Context manager to get one pipeline stage ## Context Manager Semantics Each `with` block handles barrier synchronization automatically:

```
with producer.acquire() as tiles:  # BLOCKS until consumer releases stage
    load_tiles(tiles)  # safe to write
# EXIT: signals producer barrier, advances

with consumer.acquire() as tiles:  # BLOCKS until producer fills stage
    use_tiles(tiles)  # safe to read
# EXIT: signals consumer barrier, advances
```

## Example: TMA Load Warp (Producer)

```
with input_pipeline.producer() as producer:  # producer role for this warp
    while work_iter.has_work():
        with work_iter.next() as current:
            for i in range(num_iters):
                with producer.acquire() as tiles:  # waits for consumer
                    tma_load(tiles.a_tile(), tiles.b_tile())
    producer.drain()  # wait for all stages consumed before CTA exits
```

## Example: MMA Warp (Consumer + Output Producer)

```
with mma_ctx:  # TMEM lifecycle
    while work_iter.has_work():
        with work_iter.wait_and_advance():  # blocks on CLC response
            with output_pipeline.producer() as output_stage:  # waits for epilogue
                with input_pipeline.consumer() as consumer:
                    for i in range(num_iters):
                        with consumer.acquire() as input_tiles:  # waits for TMA
                            mma(output_stage.tmem, input_tiles)
```

## Example: Epilogue Warp (Output Consumer)

```
with epi_ctx:  # signals TMEM dealloc on exit
    while work_iter.has_work():
        with work_iter.next() as current:
            with output_pipeline.consumer() as output_stage:  # waits for MMA
                write_output(output_stage)
```

## `comptime` values ### `MbarPtr` `comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ## Structs * [​`BlockScaledTilePayload`](./BlockScaledTilePayload): Tile payload for block-scaled matmul (A, B, SFA, SFB tiles). * [​`BlockwiseFP8TilePayload`](./BlockwiseFP8TilePayload): Tile payload for blockwise FP8 matmul (A, B, A-scales tiles). * [​`EpilogueKContext`](./EpilogueKContext): Per-K context manager for epilogue warp in blockwise FP8. * [​`EpilogueKStage`](./EpilogueKStage): Per-K stage for epilogue warp in blockwise FP8. * [​`EpilogueStage`](./EpilogueStage): Unified linear type handle for epilogue stage in output pipeline. * [​`InputConsumer`](./InputConsumer): Consumer view for MMA warp. Use acquire() to get stages. * [​`InputConsumerStage`](./InputConsumerStage): Handle for consumer tile access - works as context manager or linear-style. * [​`InputProducer`](./InputProducer): Producer view for TMA Load warp. Use acquire() to get stages. * [​`InputProducerStage`](./InputProducerStage): Handle for producer tile access - works as context manager or linear-style. * [​`InputTilePipeline`](./InputTilePipeline): Tile pipeline with configurable payload type. * [​`MmaKStage`](./MmaKStage): Per-K stage context for MMA warp in blockwise FP8. * [​`MmaStage`](./MmaStage): Unified linear type handle for MMA stage in output pipeline. * [​`OutputConsumer`](./OutputConsumer): Consumer view for epilogue warp (output pipeline). * [​`OutputKPipeline`](./OutputKPipeline): Per-K-iteration view of OutputTilePipeline. * [​`OutputProducer`](./OutputProducer): Producer view for MMA warp (output pipeline). 
* [​`OutputStage`](./OutputStage): Acquired output stage with TMEM handle and pipeline reference. * [​`OutputTilePipeline`](./OutputTilePipeline): Pipeline for MMA→Epilogue TMEM stage synchronization. * [​`PerKConsumerStage`](./PerKConsumerStage): Context manager for per-K epilogue consumption. * [​`StandardConsumerStage`](./StandardConsumerStage): Context manager for consumer tile access with encapsulated stage indexing. * [​`StandardProducerStage`](./StandardProducerStage): Context manager for producer tile access with encapsulated stage indexing. * [​`StandardTileConsumer`](./StandardTileConsumer): Consumer view for MMA warp (standard tile pipeline). * [​`StandardTilePayload`](./StandardTilePayload): Tile payload for standard matmul (A and B tiles). * [​`StandardTileProducer`](./StandardTileProducer): Producer view for TMA Load warp (standard tile pipeline). * [​`TilePipeline`](./TilePipeline): Staged tile storage with producer-consumer synchronization for SM100. ## Traits * [​`TilePayload`](./TilePayload): Trait for tile payload types. Must extend TrivialRegisterPassable.
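The acquire/release semantics of the tile pipeline can be modeled in plain Python (the API itself is Mojo). In this sketch, each of the `num_stages` stages is guarded by an "empty" and a "full" semaphore standing in for the producer/consumer mbarriers; `producer_acquire()` and `consumer_acquire()` mirror the context-manager semantics above (block on entry, signal and advance on exit). All names here (`StagedPipeline`, `producer_acquire`, `consumer_acquire`) are illustrative, not part of the API.

```python
import threading
from contextlib import contextmanager

class StagedPipeline:
    """Toy Python model of a num_stages-deep tile pipeline."""
    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.tiles = [None] * num_stages  # stands in for staged smem tiles
        self.empty = [threading.Semaphore(1) for _ in range(num_stages)]
        self.full = [threading.Semaphore(0) for _ in range(num_stages)]
        self._p = 0  # producer stage index
        self._c = 0  # consumer stage index

    @contextmanager
    def producer_acquire(self):
        i = self._p
        self.empty[i].acquire()   # BLOCKS until consumer releases this stage
        yield i                   # safe to write self.tiles[i]
        self.full[i].release()    # EXIT: signal consumer barrier, advance
        self._p = (i + 1) % self.num_stages

    @contextmanager
    def consumer_acquire(self):
        i = self._c
        self.full[i].acquire()    # BLOCKS until producer fills this stage
        yield i                   # safe to read self.tiles[i]
        self.empty[i].release()   # EXIT: signal producer barrier, advance
        self._c = (i + 1) % self.num_stages

pipe = StagedPipeline(num_stages=2)
out = []

def producer():
    for t in range(4):
        with pipe.producer_acquire() as i:
            pipe.tiles[i] = t              # stands in for tma_load into the stage

def consumer():
    for _ in range(4):
        with pipe.consumer_acquire() as i:
            out.append(pipe.tiles[i])      # stands in for mma consuming the stage

pt = threading.Thread(target=producer)
ct = threading.Thread(target=consumer)
pt.start(); ct.start(); pt.join(); ct.join()
print(out)  # tiles arrive in order: [0, 1, 2, 3]
```

Even with only two stages in flight, the per-stage semaphore pairing keeps the producer at most `num_stages` tiles ahead and delivers tiles to the consumer in order.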
--- ## AdvanceAfterWorkContext
`@register_passable(trivial)` `struct AdvanceAfterWorkContext[work_origin: MutOrigin, state_origin: MutOrigin, num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). * `__enter__`: Returns current work\_info for use in the block * `__exit__`: Fetches next work, assigns to work\_info, steps state ## Fields * ​scheduler (`AdvanceAfterWorkContext[work_origin, state_origin, num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType`): * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​consumer\_state\_ptr (`Pointer[PipelineState[num_stages], state_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size], work_info_ptr: Pointer[WorkInfo, work_origin], consumer_state_ptr: Pointer[PipelineState[num_stages], state_origin]) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
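The "do work THEN advance" shape can be sketched in plain Python (the API itself is Mojo): `__enter__` hands back the current work item unchanged, and only `__exit__` fetches the next item and steps the pipeline state, so the body always sees a stable work item. `AdvanceAfterWork`, `ToyIterator`, and `fetch_next_work` are illustrative names, not the real types.

```python
class AdvanceAfterWork:
    """Toy Python model of AdvanceAfterWorkContext."""
    def __init__(self, iterator):
        self.it = iterator
    def __enter__(self):
        return self.it.work_info                       # current work, unchanged
    def __exit__(self, *exc):
        self.it.work_info = self.it.fetch_next_work()  # advance AFTER work
        self.it.state += 1                             # step consumer pipeline state
        return False

class ToyIterator:
    def __init__(self, tiles):
        self._pending = list(tiles)
        self.work_info = self._pending.pop(0) if self._pending else None
        self.state = 0
    def fetch_next_work(self):
        return self._pending.pop(0) if self._pending else None
    def has_work(self):
        return self.work_info is not None
    def next(self):
        return AdvanceAfterWork(self)

processed = []
it = ToyIterator([(0, 0), (0, 1), (1, 0)])
while it.has_work():
    with it.next() as current:   # current is valid for the whole body
        processed.append(current)
print(processed, it.state)       # [(0, 0), (0, 1), (1, 0)] 3
```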
--- ## SchedulerWorkIterator
`@register_passable(trivial)` `struct SchedulerWorkIterator[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int]` Work iterator for Scheduler warp - owns work\_info and both pipeline states. The Scheduler warp uniquely needs to: 1. Consume work responses (like other warps) via next() 2. Signal throttle and produce new work requests via signal\_and\_advance() 3. Drain pending requests at exit via drain() Usage: var sched\_iter = scheduler.scheduler\_iterator() while sched\_iter.has\_work(): with sched\_iter.next(): sched\_iter.signal\_and\_advance() sched\_iter.drain() ## Fields * ​scheduler (`SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​producer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ### `ThrottlePipeline` `comptime ThrottlePipeline = SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, 
block_swizzle_size].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size], work_info: WorkInfo) -> Self` Create scheduler iterator. Throttle pipeline from scheduler. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> AdvanceAfterWorkContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Get next work item. **Returns:** `AdvanceAfterWorkContext` ### `signal_and_advance` `signal_and_advance(mut self)` Signal CLC throttle consumer and advance to next work request. Combines two operations that always happen together in Scheduler warp: 1. Signal throttle consumer (tells Load warp we've consumed a response) 2. Issue next CLC work request (producer side) ### `drain` `drain(mut self)` Drain all pending CLC requests before kernel exit. Must be called after the work loop completes to ensure all CLC pipeline stages are properly synchronized before exit.
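The Scheduler warp's call protocol (consume a response in `next()`, signal and issue the next request via `signal_and_advance()`, then `drain()` after the loop) can be sketched as runnable Python. This only models the call ordering and the producer/consumer bookkeeping; counters stand in for the two `PipelineState`s, and `ToySchedulerIterator` is an illustrative name, not the real type.

```python
class ToySchedulerIterator:
    """Toy Python model of the SchedulerWorkIterator protocol."""
    def __init__(self, num_work_items):
        self.remaining = num_work_items
        self.requests_issued = 0      # producer_state analogue
        self.responses_consumed = 0   # consumer_state analogue
        self.drained = False

    def has_work(self):
        return self.remaining > 0

    def next(self):
        it = self
        class _Ctx:  # models the advance-after-work context for this warp
            def __enter__(self):
                it.responses_consumed += 1   # stands in for the CLC barrier wait
            def __exit__(self, *exc):
                it.remaining -= 1            # advance to the next work item
                return False
        return _Ctx()

    def signal_and_advance(self):
        # signal throttle consumer, then issue the next CLC work request
        self.requests_issued += 1

    def drain(self):
        # stands in for synchronizing all pending CLC pipeline stages at exit
        self.drained = True

sched_iter = ToySchedulerIterator(num_work_items=3)
while sched_iter.has_work():
    with sched_iter.next():
        sched_iter.signal_and_advance()
sched_iter.drain()
print(sched_iter.requests_issued, sched_iter.responses_consumed)  # 3 3
```

The point of the shape is that requests and responses stay balanced across the loop, and `drain()` runs exactly once, after the final response is consumed.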
--- ## TileScheduler (3)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8]` ## Fields * ​cluster\_dim (`StaticTuple[Int32, 3]`): * ​log\_cluster\_dim\_m (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_n (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_k (`FastDiv[DType.uint32]`): * ​clc\_response (`LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]`): * ​full\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​empty\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​throttle\_pipeline (`TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ClcBarrierArray` `comptime ClcBarrierArray = SMemArray[SharedMemBarrier, num_stages]` ### `ClcResponseArray` `comptime ClcResponseArray = SMemArray[UInt128, num_stages]` ### `cluster_size` `comptime cluster_size = ((cluster_shape.__getitem__[3, DType.uint32, Int](0) * cluster_shape.__getitem__[3, DType.uint32, Int](1)) * cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_k` `comptime log_cluster_k = FastDiv[DType.uint32](cluster_shape.__getitem__[3, 
DType.uint32, Int](2))` ### `log_cluster_m` `comptime log_cluster_m = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](0))` ### `log_cluster_n` `comptime log_cluster_n = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](1))` ### `ThrottleBarrierArray` `comptime ThrottleBarrierArray = SMemArray[SharedMemBarrier, (num_stages * 2)]` ### `ThrottlePipeline` `comptime ThrottlePipeline = ProducerConsumerPipeline[num_stages]` ## Methods ### `__init__` `__init__(cluster_dim: StaticTuple[Int32, 3], clc_response: SMemArray[UInt128, num_stages], clc_full: SMemArray[SharedMemBarrier, num_stages], clc_empty: SMemArray[SharedMemBarrier, num_stages], clc_throttle: SMemArray[SharedMemBarrier, (num_stages * 2)]) -> Self` Initialize from typed barrier arrays. ### `init_throttle_barriers` `static init_throttle_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize throttle pipeline barriers. Called once by elect\_one thread. 
### `work_info_from_clc_response` `static work_info_from_clc_response(result: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `work_info_from_cluster` `static work_info_from_cluster(work_info: WorkInfo, cluster_dim: StaticTuple[Int32, 3], log_cluster_dim_m: FastDiv[DType.uint32], log_cluster_dim_n: FastDiv[DType.uint32]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_after_work` `advance_after_work[work_origin: MutOrigin, state_origin: MutOrigin, //](self, ref[work_origin] work_info: WorkInfo, ref[state_origin] consumer_state: PipelineState[num_stages]) -> AdvanceAfterWorkContext[work_origin, state_origin, num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). Usage: with scheduler.advance\_after\_work(work\_info, state) as current: do\_work(current) syncwarp() \# After: work\_info updated, state stepped **Returns:** `AdvanceAfterWorkContext` ### `wait_and_advance_work` `wait_and_advance_work[work_origin: MutOrigin, //](self, ref[work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> WaitAndAdvanceContext[work_origin]` Wait for next work from CLC and advance. Encapsulates the CLC barrier wait (called on scheduler directly). 
Usage: with scheduler.wait\_and\_advance\_work(work\_info, state) as current: do\_mma(current) \# After: work\_info updated to next value **Returns:** `WaitAndAdvanceContext` ### `work_iterator` `work_iterator(self) -> WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Create a per-warp work iterator with internally managed state. Each warp should create its own work iterator. The iterator owns work\_info, pipeline state, and throttle internally. Usage: var work\_iter = scheduler.work\_iterator() while work\_iter.has\_work(): with work\_iter.next() as current: work\_iter.throttle\_signal(ctx.is\_first\_cta\_in\_cluster) do\_work(current) **Returns:** `WorkIterator` ### `scheduler_iterator` `scheduler_iterator(self) -> SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Create iterator for Scheduler warp (owns work\_info and both pipeline states). The Scheduler warp uniquely needs to both consume work responses and produce new work requests. This iterator owns everything internally. Usage: var sched\_iter = scheduler.scheduler\_iterator() while sched\_iter.has\_work(): with sched\_iter.next(): sched\_iter.signal\_and\_advance() sched\_iter.drain() **Returns:** `SchedulerWorkIterator` ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState)
--- ## WaitAndAdvanceContext
`@register_passable(trivial)` `struct WaitAndAdvanceContext[work_origin: MutOrigin]` Context for waiting on CLC barrier and advancing work iterator. Encapsulates the CLC response barrier synchronization: * Construction: Waits for CLC response, fetches next work * `__enter__`: Returns current work\_info for processing * `__exit__`: Assigns fetched work as current

Usage:

```
with work_iter.wait_and_advance() as current:
    # current is the work item to process NOW
    process(current)
# After exit, work_iter.work_info is the NEXT work item
```

## Fields * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​next\_work (`WorkInfo`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(work_info_ptr: Pointer[WorkInfo, work_origin], next_work: WorkInfo) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## WorkInfo (3)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `coord` `coord(self) -> Tuple[UInt, UInt]` Get (m, n) tile coordinates as a tuple. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## WorkIterator
`@register_passable(trivial)` `struct WorkIterator[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int]` Per-warp work iterator that owns work\_info and pipeline state. Each warp creates its own WorkIterator which internally manages both the current work item and the CLC pipeline consumer state. Throttle pipeline is obtained from the scheduler. Usage: var work\_iter = scheduler.work\_iterator() while work\_iter.has\_work(): with work\_iter.next() as current: work\_iter.throttle\_signal(ctx.is\_first\_cta\_in\_cluster) do\_work(current) ## Fields * ​scheduler (`WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ### `ThrottlePipeline` `comptime ThrottlePipeline = WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, cluster_shape, rasterize_order, 
block_swizzle_size], work_info: WorkInfo) -> Self` Create work iterator with initial work\_info. Throttle from scheduler. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> AdvanceAfterWorkContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Get next work item (advance AFTER work pattern). **Returns:** `AdvanceAfterWorkContext` ### `wait_and_advance` `wait_and_advance[state_origin: MutOrigin, //](ref[state_origin] self) -> WaitAndAdvanceContext[origin_of(state_origin._mlir_origin.work_info)]` Wait for next work from CLC and advance iterator. Encapsulates the CLC barrier wait: * `__enter__`: Waits for CLC response, returns current work * `__exit__`: Assigns fetched work as current

Usage:

```
with work_iter.wait_and_advance() as current:
    # Process current work item
    ...
# After exit, work_iter points to next work
```

**Returns:** `WaitAndAdvanceContext` ### `throttle_signal` `throttle_signal(mut self, is_first_cta_in_cluster: Bool)` Signal CLC throttle if this is the first CTA in cluster. The Load warp acts as producer for CLC throttle, signaling that it has started processing a new work item. This prevents the scheduler from getting too far ahead. **Args:** * ​is\_first\_cta\_in\_cluster ([`Bool`](/mojo/std/builtin/bool/Bool)): Only first CTA signals to avoid duplicates.
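The throttle described under `throttle_signal` (keeping the scheduler from running too far ahead of the Load warp) is, conceptually, a bounded-depth handshake. The Python sketch below models it with a counting semaphore of capacity `depth`: the scheduler must take a slot before issuing each request, and the Load warp returns a slot when it starts a work item, bounding the scheduler's lead. `depth`, `issued`, `started`, and `max_lead` are illustrative names; the real mechanism uses paired shared-memory barriers, not semaphores.

```python
import threading

depth = 2                            # scheduler may run at most 2 requests ahead
throttle = threading.Semaphore(depth)
issued, started, max_lead = [], [], [0]

def scheduler(n):
    for i in range(n):
        throttle.acquire()           # blocks once `depth` requests are in flight
        issued.append(i)             # stands in for issuing a CLC work request
        max_lead[0] = max(max_lead[0], len(issued) - len(started))

def load_warp(n):
    for i in range(n):
        while len(issued) <= i:      # busy-wait until the request exists (toy only)
            pass
        started.append(i)            # stands in for starting the work item
        throttle.release()           # throttle_signal: let the scheduler advance

t1 = threading.Thread(target=scheduler, args=(6,))
t2 = threading.Thread(target=load_warp, args=(6,))
t1.start(); t2.start(); t1.join(); t2.join()
print(len(issued), max_lead[0] <= depth)  # 6 True
```

The semaphore invariant (acquires minus releases never exceeds `depth`) is exactly the property the CLC throttle provides: bounded prefetch without stalling steady-state progress.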
--- ## tile_scheduler (Tile_scheduler)
## Structs * [​`AdvanceAfterWorkContext`](./AdvanceAfterWorkContext): Context for warps that do work THEN advance (Load/Scheduler/Epilogue). * [​`SchedulerWorkIterator`](./SchedulerWorkIterator): Work iterator for Scheduler warp - owns work\_info and both pipeline states. * [​`TileScheduler`](./TileScheduler): * [​`WaitAndAdvanceContext`](./WaitAndAdvanceContext): Context for waiting on CLC barrier and advancing work iterator. * [​`WorkInfo`](./WorkInfo): * [​`WorkIterator`](./WorkIterator): Per-warp work iterator that owns work\_info and pipeline state.
--- ## AdvanceAfterWorkContextSplitK
`@register_passable(trivial)` `struct AdvanceAfterWorkContextSplitK[work_origin: MutOrigin, state_origin: MutOrigin, num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). ## Fields * ​scheduler (`AdvanceAfterWorkContextSplitK[work_origin, state_origin, num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType`): * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​consumer\_state\_ptr (`Pointer[PipelineState[num_stages], state_origin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info_ptr: Pointer[WorkInfo, work_origin], consumer_state_ptr: Pointer[PipelineState[num_stages], state_origin]) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## SchedulerWorkIteratorSplitK
`@register_passable(trivial)` `struct SchedulerWorkIteratorSplitK[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]` Work iterator for Scheduler warp (split-K) - owns work\_info and both states. Throttle pipeline is obtained from the scheduler. ## Fields * ​scheduler (`SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​producer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` ### `ThrottlePipeline` `comptime ThrottlePipeline = SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, 
rasterize_order, block_swizzle_size, num_split_k], work_info: WorkInfo) -> Self` Create scheduler iterator. Throttle pipeline from scheduler. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> AdvanceAfterWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Get next work item. **Returns:** `AdvanceAfterWorkContextSplitK` ### `signal_and_advance` `signal_and_advance(mut self)` Signal CLC throttle consumer and advance to next work request. ### `drain` `drain(mut self)` Drain all pending CLC requests before kernel exit.
--- ## TileScheduler (Tile_scheduler_splitk)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8, num_split_k: Int = 1]` ## Fields * ​locks\_ptr (`LegacyUnsafePointer[Int32]`): * ​scheduler (`TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler`): * ​total\_k\_tiles (`UInt32`): * ​k\_tiles\_per\_split (`UInt32`): * ​throttle\_pipeline (`TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BK` `comptime BK = reduction_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = reduction_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `ClcBarrierArray` `comptime ClcBarrierArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ClcBarrierArray` ### `ClcResponseArray` `comptime ClcResponseArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ClcResponseArray` ### `MMA_N` `comptime MMA_N 
= reduction_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `ROW_SIZE` `comptime ROW_SIZE = reduction_tile_shape.__getitem__[3, DType.int64, Int](1) if (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].BM == 128)._mlir_value else (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].MMA_N // 2)` ### `ThrottleBarrierArray` `comptime ThrottleBarrierArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottleBarrierArray` ### `ThrottlePipeline` `comptime ThrottlePipeline = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottlePipeline` ### `UnderlyingScheduler` `comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ## Methods ### `__init__` `__init__(cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], clc_response: SMemArray[UInt128, num_stages], clc_full: SMemArray[SharedMemBarrier, num_stages], clc_empty: SMemArray[SharedMemBarrier, num_stages], clc_throttle: SMemArray[SharedMemBarrier, (num_stages * 2)], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self` Initialize from typed barrier arrays. ### `init_throttle_barriers` `static init_throttle_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize throttle pipeline barriers. Called once by elect\_one thread. 
### `convert_to_splitk_work_info` `convert_to_splitk_work_info(self, work_info: WorkInfo) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_after_work` `advance_after_work[work_origin: MutOrigin, state_origin: MutOrigin, //](self, ref[work_origin] work_info: WorkInfo, ref[state_origin] consumer_state: PipelineState[num_stages]) -> AdvanceAfterWorkContextSplitK[work_origin, state_origin, num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). Usage: with scheduler.advance\_after\_work(work\_info, state) as current: do\_work(current) syncwarp() \# After: work\_info updated, state stepped **Returns:** `AdvanceAfterWorkContextSplitK` ### `wait_and_advance_work` `wait_and_advance_work[work_origin: MutOrigin, //](self, ref[work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> WaitAndAdvanceContextSplitK[work_origin]` Wait for next work from CLC and advance (Split-K). Encapsulates the CLC barrier wait (called on scheduler directly). 
Usage:

```
with scheduler.wait_and_advance_work(work_info, state) as current:
    do_mma(current)
# After: work_info updated to next value
```

**Returns:** `WaitAndAdvanceContextSplitK` ### `work_iterator` `work_iterator(self) -> WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Create a per-warp work iterator that owns work\_info internally. Throttle pipeline is obtained from the scheduler. **Returns:** `WorkIteratorSplitK` ### `scheduler_iterator` `scheduler_iterator(self) -> SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Create iterator for Scheduler warp (owns work\_info and both states). Throttle pipeline is obtained from the scheduler. **Returns:** `SchedulerWorkIteratorSplitK` ### `is_last_split` `is_last_split(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `output_tile_index` `output_tile_index(self, work_info: WorkInfo) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `store_to_workspace` `store_to_workspace[accum_type: DType, workspace_layout: Layout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem: TmemAddress, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], epilogue_thread_idx: Scalar[DType.uint], reduction_tile_idx: UInt32)` ### `reduction` `reduction[accum_type: DType, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], tmem: TmemAddress, epilogue_thread_idx: Scalar[DType.uint], work_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `wait_eq` `static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)` ### `wait_lt` `static wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)` ### `arrive_set` `static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)`
--- ## WaitAndAdvanceContextSplitK
`@register_passable(trivial)` `struct WaitAndAdvanceContextSplitK[work_origin: MutOrigin]` Context for waiting on CLC barrier and advancing work iterator (Split-K). Encapsulates the CLC response barrier synchronization: * Construction: Waits for CLC response, fetches next work * `__enter__`: Returns current work\_info for processing * `__exit__`: Assigns fetched work as current ## Fields * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​next\_work (`WorkInfo`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(work_info_ptr: Pointer[WorkInfo, work_origin], next_work: WorkInfo) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## WorkInfo (Tile_scheduler_splitk)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​num\_k\_tiles (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `INVALID_WORK_INFO` `comptime INVALID_WORK_INFO = WorkInfo(0, 0, 0, 0, False)` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_final_split` `is_final_split(self, k_tiles_per_output_tile: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
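As an illustrative sketch only (the surrounding kernel context and the k-tiles-per-output-tile value of 8 are assumptions, not part of this reference), a consumer can test a `WorkInfo` before processing it:

```
# Hypothetical sketch: fields follow the INVALID_WORK_INFO order
# (m, n, k_start, num_k_tiles, is_valid_tile).
var work = WorkInfo(0, 0, 0, 4, True)
if work.is_valid():
    # 8 k-tiles per output tile is an assumed, problem-dependent value.
    if work.is_final_split(8):
        pass  # Last split for this output tile: safe to run the epilogue.
```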
--- ## WorkIteratorSplitK
`@register_passable(trivial)` `struct WorkIteratorSplitK[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]` Per-warp work iterator for split-K that owns work\_info and pipeline state. Throttle pipeline is obtained from the scheduler. ## Fields * ​scheduler (`WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` ### `ThrottlePipeline` `comptime ThrottlePipeline = WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info: WorkInfo) -> Self` Create 
work iterator. Throttle pipeline from scheduler. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref[state_origin] self) -> AdvanceAfterWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Get next work item (advance AFTER work pattern). **Returns:** `AdvanceAfterWorkContextSplitK` ### `wait_and_advance` `wait_and_advance[state_origin: MutOrigin, //](ref[state_origin] self) -> WaitAndAdvanceContextSplitK[origin_of(state_origin._mlir_origin.work_info)]` Wait for next work from CLC and advance iterator (Split-K). Encapsulates the CLC barrier wait: * `__enter__`: Waits for CLC response, returns current work * `__exit__`: Assigns fetched work as current **Returns:** `WaitAndAdvanceContextSplitK` ### `throttle_signal` `throttle_signal(mut self, is_first_cta_in_cluster: Bool)` Signal CLC throttle if this is the first CTA in cluster. **Args:** * ​is\_first\_cta\_in\_cluster ([`Bool`](/mojo/std/builtin/bool/Bool)): Only first CTA signals to avoid duplicates.
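The iterator methods above compose into a per-warp work loop. The following is a hedged sketch under assumed names (`scheduler` and `do_work` are placeholders, not API from this module):

```
# Hypothetical per-warp loop using the advance-AFTER-work pattern.
var it = scheduler.work_iterator()
while it.has_work():
    with it.next() as current:
        do_work(current)  # Process the current output tile / split.
    # On __exit__, work_info is updated and consumer_state is stepped.
```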
--- ## get_num_tiles
`get_num_tiles(problem_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## get_required_locks_buffer_size_bytes
`get_required_locks_buffer_size_bytes[accum_type: DType](problem_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## tile_scheduler_splitk
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`AdvanceAfterWorkContextSplitK`](./AdvanceAfterWorkContextSplitK): Context for warps that do work THEN advance (Load/Scheduler/Epilogue). * [​`SchedulerWorkIteratorSplitK`](./SchedulerWorkIteratorSplitK): Work iterator for Scheduler warp (split-K) - owns work\_info and both states. Throttle pipeline is obtained from the scheduler. * [​`TileScheduler`](./TileScheduler): * [​`WaitAndAdvanceContextSplitK`](./WaitAndAdvanceContextSplitK): Context for waiting on CLC barrier and advancing work iterator (Split-K). * [​`WorkInfo`](./WorkInfo): * [​`WorkIteratorSplitK`](./WorkIteratorSplitK): Per-warp work iterator for split-K that owns work\_info and pipeline state. Throttle pipeline is obtained from the scheduler. ## Functions * [​`get_num_tiles`](./get_num_tiles): * [​`get_required_locks_buffer_size_bytes`](./get_required_locks_buffer_size_bytes):
--- ## BlockwiseFP8TilePayload (Tile_types)
`@register_passable(trivial)` `struct BlockwiseFP8TilePayload[a_type: DType, b_type: DType, a_scales_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, a_scales_dim0: Int, a_scales_dim1: Int, num_pipeline_stages: Int]` TileTensor-based tile payload for blockwise FP8 matmul. Unlike BlockScaledTilePayload, this only stores A-scales in SMEM. B-scales are read directly from global memory during the epilogue phase. ## Parameters * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A matrix tiles. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for B matrix tiles. * ​a\_scales\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type for A scale tiles. * ​a\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for A tiles. * ​a\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for A tiles. * ​b\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for B tiles. * ​b\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for B tiles. * ​a\_scales\_dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension for A scale tiles. * ​a\_scales\_dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension for A scale tiles. * ​num\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of input pipeline stages. 
## Fields * ​a\_tiles (`BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATileArray`): * ​b\_tiles (`BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTileArray`): * ​a\_scales\_tiles (`BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTileArray`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TilePayload`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/TilePayload), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AScalesTile` `comptime AScalesTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTileArray.Tile` ### `AScalesTileArray` `comptime AScalesTileArray = SMemTileArray2DRowMajor[a_scales_type, a_scales_dim0, a_scales_dim1, num_pipeline_stages]` ### `ATile` `comptime ATile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]` ### `BTile` `comptime BTile = 
BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]` ## Methods ### `__init__` `__init__(a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages], a_scales_tiles: SMemTileArray2DRowMajor[a_scales_type, a_scales_dim0, a_scales_dim1, num_pipeline_stages]) -> Self` ### `get_tile` `get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATile, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTile, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTile]` Get A, B, A-scales tiles at the specified stage and k-group index. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `get_a_tile` `get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].ATile` Get A tile at the specified stage and k-group index. **Returns:** `BlockwiseFP8TilePayload` ### `get_b_tile` `get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].BTile` Get B tile at the specified stage and k-group index. 
**Returns:** `BlockwiseFP8TilePayload` ### `get_a_scales_tile` `get_a_scales_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_dim0, a_dim1, b_dim0, b_dim1, a_scales_dim0, a_scales_dim1, num_pipeline_stages].AScalesTile` Get A-scales tile at the specified stage and k-group index. **Returns:** `BlockwiseFP8TilePayload`
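For illustration, the accessors above might be driven from a pipeline loop like the following sketch; `payload`, `stage`, the loop bound `k_groups`, and the `k_group_size` of 2 are assumed caller-side values:

```
# Hypothetical inner loop over k-groups within one pipeline stage.
for k_idx in range(k_groups):
    var tiles = payload.get_tile[2](stage, k_idx)  # k_group_size = 2 (assumed)
    # tiles[0]: A tile, tiles[1]: B tile, tiles[2]: A-scales tile.
    # A/B feed the MMA; A-scales are applied during the epilogue, while
    # B-scales are read directly from global memory (per the note above).
```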
--- ## SMemTileArray
`@register_passable(trivial)` `struct SMemTileArray[dtype: DType, shape_types: Variadic[CoordLike], stride_types: Variadic[CoordLike], num_tiles: Int, alignment: Int = 128]` Array of TileTensor tiles with variadic shape/stride type parameters. This is the TileTensor equivalent of the LayoutTensor-based SMemTileArray in structuring.mojo. By taking shape\_types and stride\_types directly as variadic type parameters, this preserves full compile-time type information including swizzle patterns. Example:

```
comptime a_layout = internal_k_major[dtype, BM, BK, 128]
comptime ATileArray = SMemTileArray[
    dtype, a_layout.shape_types, a_layout.stride_types, num_pipeline_stages,
]
var array = ATileArray.stack_allocation()
var tile = array[0]  # Returns TileTensor with correct swizzled layout
```

## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Tile element data type. * ​shape\_types ([`Variadic`](/mojo/std/builtin/variadics/Variadic)): Variadic shape types from Layout (preserves compile-time info). * ​stride\_types ([`Variadic`](/mojo/std/builtin/variadics/Variadic)): Variadic stride types from Layout (preserves compile-time info). * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles in the array. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment (default 128 for shared memory).
## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_elements` `comptime num_elements = (SMemTileArray[dtype, shape_types, stride_types, num_tiles, alignment].tile_size * num_tiles)` ### `Storage` `comptime Storage = InlineArray[Scalar[dtype], SMemTileArray[dtype, shape_types, stride_types, num_tiles, alignment].num_elements]` ### `storage_size` `comptime storage_size = (SMemTileArray[dtype, shape_types, stride_types, num_tiles, alignment].num_elements * size_of[dtype]())` ### `Tile` `comptime Tile = TileTensor[dtype, Layout[shape_types, stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]` ### `tile_size` `comptime tile_size = Coord[shape_types].static_product` ### `TileLayout` `comptime TileLayout = Layout[shape_types, stride_types]` ## Methods ### `__init__` `__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[Scalar[dtype], SMemTileArray[dtype, shape_types, stride_types, num_tiles, alignment].num_elements]) -> Self` Initialize from inline storage. **Args:** * ​storage ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): The inline storage array. **Returns:** `Self`: A new SMemTileArray pointing to the storage. 
`__init__[mut: Bool, //, origin: Origin[mut=mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, origin=origin]) -> Self` Initialize with a shared memory pointer. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory storage. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemTileArray[dtype, shape_types, stride_types, num_tiles, alignment].Tile` Get tile at the given index. **Args:** * ​index (`T`): The tile index. **Returns:** [`SMemTileArray`](/mojo/kernels/linalg/structuring/SMemTileArray): A TileTensor with correct swizzled layout at the given index. ### `slice` `slice[length: Int](self, start: Int) -> SMemTileArray[dtype, shape_types, stride_types, length, alignment]` Get a slice of the array. **Parameters:** * ​length ([`Int`](/mojo/std/builtin/int/Int)): The length of the slice. **Args:** * ​start ([`Int`](/mojo/std/builtin/int/Int)): The starting index. **Returns:** [`SMemTileArray`](/mojo/kernels/linalg/structuring/SMemTileArray): A new SMemTileArray representing the slice. ### `stack_allocation` `static stack_allocation() -> Self` Allocate the array on the stack (in shared memory). **Returns:** `Self`: A new SMemTileArray backed by stack-allocated shared memory.
--- ## SMemTileArray2D
`@register_passable(trivial)` `struct SMemTileArray2D[dtype: DType, dim0: Int, dim1: Int, num_tiles: Int, swizzle_bytes: Int = 128, alignment: Int = 128]` Array of TileTensor tiles in shared memory with swizzled K-major layout. The tiles use `internal_k_major` layout with configurable swizzle, matching the SM100 TMA swizzle pattern. This preserves swizzle information in the TileTensor type while using simple dimension-based parameters. Note: For tiles without swizzle, use SMemTileArrayWithLayout with row\_major. Example:

```
comptime MyArray = SMemTileArray2D[DType.float16, 64, 32, 4, 128, 128]
var array = MyArray.stack_allocation()
var tile = array[0]  # Returns TileTensor with swizzled layout
```

## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Tile element data type. * ​dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension (rows, e.g., BM or BN). * ​dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension (columns, e.g., BK). * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles in the array. * ​swizzle\_bytes ([`Int`](/mojo/std/builtin/int/Int)): Swizzle size in bytes (128, 64, or 32). Must be > 0. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment (default 128 for shared memory).
## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_elements` `comptime num_elements = (SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].tile_size * num_tiles)` ### `Storage` `comptime Storage = InlineArray[Scalar[dtype], SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].num_elements]` ### `storage_size` `comptime storage_size = (SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].num_elements * size_of[dtype]())` ### `Tile` `comptime Tile = TileTensor[dtype, Layout[Coord[ComptimeInt[(dim0 // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]], MutAnyOrigin, address_space=AddressSpace.SHARED]` ### `tile_layout` `comptime tile_layout = Layout[Coord[ComptimeInt[(dim0 // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], 
ComptimeInt[0]]](Coord[Coord[ComptimeInt[(dim0 // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(dim0 // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]]](Coord[ComptimeInt[(dim0 // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(dim0 // 8)], ComptimeInt[8]](Idx[(dim0 // 8)](), Idx[8]())), Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]](Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((dim1 * size_of[dtype]()) // swizzle_bytes)]())))), Coord[Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]](Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))` ### `tile_size` `comptime tile_size = (dim0 * dim1)` ## Methods ### `__init__` 
`__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[Scalar[dtype], SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].num_elements]) -> Self` Initialize from inline storage. **Args:** * ​storage ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): The inline storage array. **Returns:** `Self`: A new SMemTileArray2D pointing to the storage. `__init__[mut: Bool, //, origin: Origin[mut=mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, origin=origin]) -> Self` Initialize with a shared memory pointer. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory storage. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].Tile` Get tile at the given index. **Args:** * ​index (`T`): The tile index. **Returns:** [`SMemTileArray2D`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2D): A TileTensor-based tile at the given index with swizzled layout. ### `get_with_layout` `get_with_layout[tile_layout: Layout[shape_types, stride_types], T: Intable](self, index: T) -> TileTensor[dtype, Layout[shape_types, stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]` Get tile at the given index with a specified layout. This method allows getting tiles with a swizzled layout for MMA operations, where the layout information is needed for correct K-iteration offsets. **Parameters:** * ​tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to use (e.g., swizzled layout for MMA). * ​T ([`Intable`](/mojo/std/builtin/int/Intable)): Index type (must be Intable). **Args:** * ​index (`T`): The tile index. **Returns:** [`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor): A TileTensor with the specified layout at the given index. 
### `slice` `slice[length: Int](self, start: Int) -> SMemTileArray2D[dtype, dim0, dim1, length, alignment]` Get a slice of the array. **Parameters:** * ​length ([`Int`](/mojo/std/builtin/int/Int)): The length of the slice. **Args:** * ​start ([`Int`](/mojo/std/builtin/int/Int)): The starting index. **Returns:** [`SMemTileArray2D`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2D): A new SMemTileArray2D representing the slice. ### `stack_allocation` `static stack_allocation() -> Self` Allocate the array on the stack (in shared memory). **Returns:** `Self`: A new SMemTileArray2D backed by stack-allocated shared memory.
--- ## SMemTileArray2DRowMajor
`@register_passable(trivial)` `struct SMemTileArray2DRowMajor[dtype: DType, dim0: Int, dim1: Int, num_tiles: Int, alignment: Int = 128]` Array of TileTensor tiles in shared memory with row\_major layout. Unlike SMemTileArray2D, which uses the swizzled internal\_k\_major layout, this type uses a simple row\_major layout. Suitable for 1D vectors (like A-scales) or output tiles where swizzling is not needed. Example:

```
comptime MyArray = SMemTileArray2DRowMajor[DType.float32, 1, 64, 4]
var array = MyArray.stack_allocation()
var tile = array[0]  # Returns TileTensor with row_major layout
```

## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Tile element data type. * ​dim0 ([`Int`](/mojo/std/builtin/int/Int)): First dimension (rows). * ​dim1 ([`Int`](/mojo/std/builtin/int/Int)): Second dimension (columns). * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles in the array. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment (default 128 for shared memory).
## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_elements` `comptime num_elements = (SMemTileArray2DRowMajor[dtype, dim0, dim1, num_tiles, alignment].tile_size * num_tiles)` ### `Storage` `comptime Storage = InlineArray[Scalar[dtype], SMemTileArray2DRowMajor[dtype, dim0, dim1, num_tiles, alignment].num_elements]` ### `storage_size` `comptime storage_size = (SMemTileArray2DRowMajor[dtype, dim0, dim1, num_tiles, alignment].num_elements * size_of[dtype]())` ### `Tile` `comptime Tile = TileTensor[dtype, Layout[ComptimeInt[dim0], ComptimeInt[dim1], ComptimeInt[dim1], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]` ### `tile_layout` `comptime tile_layout = row_major[dim0, dim1]()` ### `tile_size` `comptime tile_size = (dim0 * dim1)` ## Methods ### `__init__` `__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[Scalar[dtype], SMemTileArray2DRowMajor[dtype, dim0, dim1, num_tiles, alignment].num_elements]) -> Self` Initialize from inline storage. **Args:** * ​storage ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): The inline storage array. **Returns:** `Self`: A new SMemTileArray2DRowMajor pointing to the storage. 
`__init__[mut: Bool, //, origin: Origin[mut=mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, origin=origin]) -> Self` Initialize with a shared memory pointer. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory storage. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemTileArray2DRowMajor[dtype, dim0, dim1, num_tiles, alignment].Tile` Get tile at the given index. **Args:** * ​index (`T`): The tile index. **Returns:** [`SMemTileArray2DRowMajor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2DRowMajor): A TileTensor-based tile at the given index with row\_major layout. ### `slice` `slice[length: Int](self, start: Int) -> SMemTileArray2DRowMajor[dtype, dim0, dim1, length, alignment]` Get a slice of the array. **Parameters:** * ​length ([`Int`](/mojo/std/builtin/int/Int)): The length of the slice. **Args:** * ​start ([`Int`](/mojo/std/builtin/int/Int)): The starting index. **Returns:** [`SMemTileArray2DRowMajor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArray2DRowMajor): A new SMemTileArray2DRowMajor representing the slice. ### `stack_allocation` `static stack_allocation() -> Self` Allocate the array on the stack (in shared memory). **Returns:** `Self`: A new SMemTileArray2DRowMajor backed by stack-allocated shared memory.
--- ## SMemTileArrayWithLayout
`@register_passable(trivial)` `struct SMemTileArrayWithLayout[shape_types: Variadic[CoordLike], stride_types: Variadic[CoordLike], //, dtype: DType, tile_layout: Layout[shape_types, stride_types], num_tiles: Int, alignment: Int = 128]` Array of TileTensor tiles with explicit Layout (preserves swizzle info). Unlike SMemTileArray2D, which is fixed to the swizzled internal\_k\_major layout, this type takes the tile layout explicitly and preserves the full layout information from TMA swizzling, enabling .to\_layout\_tensor() to produce correctly swizzled LayoutTensors. Example:

```
comptime swizzled_layout = tile_layout_k_major[dtype, BM, BK, ...]()
comptime MyArray = SMemTileArrayWithLayout[dtype, swizzled_layout, 4]
var array = MyArray.stack_allocation()
var tile = array[0]  # Returns TileTensor with swizzled layout
var lt = tile.to_layout_tensor()  # Correctly swizzled!
```

## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Tile element data type. * ​tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The full layout including swizzle information. * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles in the array. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment (default 128 for shared memory).
## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_elements` `comptime num_elements = (SMemTileArrayWithLayout[dtype, tile_layout, num_tiles, alignment].tile_size * num_tiles)` ### `Storage` `comptime Storage = InlineArray[Scalar[dtype], SMemTileArrayWithLayout[dtype, tile_layout, num_tiles, alignment].num_elements]` ### `storage_size` `comptime storage_size = (SMemTileArrayWithLayout[dtype, tile_layout, num_tiles, alignment].num_elements * size_of[dtype]())` ### `Tile` `comptime Tile = TileTensor[dtype, Layout[shape_types, stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]` ### `tile_size` `comptime tile_size = tile_layout.product[shape_types, stride_types]()` ## Methods ### `__init__` `__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[Scalar[dtype], SMemTileArrayWithLayout[dtype, tile_layout, num_tiles, alignment].num_elements]) -> Self` Initialize from inline storage. **Args:** * ​storage ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): The inline storage array. **Returns:** `Self`: A new SMemTileArrayWithLayout pointing to the storage. 
`__init__[mut: Bool, //, origin: Origin[mut=mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, origin=origin]) -> Self` Initialize with a shared memory pointer. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory storage. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemTileArrayWithLayout[dtype, tile_layout, num_tiles, alignment].Tile` Get tile at the given index. **Args:** * ​index (`T`): The tile index. **Returns:** [`SMemTileArrayWithLayout`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArrayWithLayout): A TileTensor with correct swizzled layout at the given index. ### `slice` `slice[length: Int](self, start: Int) -> SMemTileArrayWithLayout[dtype, tile_layout, length, alignment]` Get a slice of the array. **Parameters:** * ​length ([`Int`](/mojo/std/builtin/int/Int)): The length of the slice. **Args:** * ​start ([`Int`](/mojo/std/builtin/int/Int)): The starting index. **Returns:** [`SMemTileArrayWithLayout`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArrayWithLayout): A new SMemTileArrayWithLayout representing the slice. ### `stack_allocation` `static stack_allocation() -> Self` Allocate the array on the stack (in shared memory). **Returns:** `Self`: A new SMemTileArrayWithLayout backed by stack-allocated shared memory.
--- ## tile_types
Native TileTensor types for SM100 structured kernels. This module provides TileTensor-based tile types for SM100 structured kernels. All SMEM storage uses TileTensor natively. Conversion to LayoutTensor only happens at external API boundaries (TMA, MMA) using explicit LayoutTensor construction from the tile pointer. Usage: from linalg.matmul.gpu.sm100\_structured.structured\_kernels.tile\_types import ( SMemTile, SMemTileArray2D, SMemTileArrayWithLayout ) ``` # Create tile with a layout comptime my_layout = row_major[64, 32]() comptime MyTile = SMemTile[DType.float16, my_layout] # At TMA/MMA boundaries, construct LayoutTensor from pointer comptime lt_type = LayoutTensor[dtype, layout, ...] tma_op.async_load(lt_type(tile.ptr), barrier, coords) ``` ## `comptime` values ### `internal_k_major` `comptime internal_k_major[dtype: DType, BM: Int, BK: Int, swizzle_bytes: Int] = Layout[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // swizzle_bytes)]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // swizzle_bytes)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // swizzle_bytes)]]](Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(BM // 8)], ComptimeInt[8]](Idx[(BM // 8)](), Idx[8]())), Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // swizzle_bytes)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BK 
* size_of[dtype]()) // swizzle_bytes)]](Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // swizzle_bytes)]())))), Coord[Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]](Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​BM ([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): * ​swizzle\_bytes ([`Int`](/std/builtin/int/Int)): ### `internal_k_major_128B` `comptime internal_k_major_128B[dtype: DType, BM: Int, BK: Int] = Layout[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 128)]], Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BM // 8) * (128 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 
128)]]](Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(BM // 8)], ComptimeInt[8]](Idx[(BM // 8)](), Idx[8]())), Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 128)]](Idx[(128 // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BM // 8) * (128 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BM // 8) * (128 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BM // 8) * (128 // size_of[dtype]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[dtype]())], ComptimeInt[((BM // 8) * (128 // size_of[dtype]()))]](Idx[(128 // size_of[dtype]())](), Idx[((BM // 8) * (128 // size_of[dtype]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​BM ([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): ### `internal_k_major_32B` `comptime internal_k_major_32B[dtype: DType, BM: Int, BK: Int] = Layout[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 32)]], Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BM // 8) * (32 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 32)]]](VariadicPack[True, 
MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 32)]]](Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(BM // 8)], ComptimeInt[8]](Idx[(BM // 8)](), Idx[8]())), Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 32)]](Idx[(32 // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // 32)]())))), Coord[Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BM // 8) * (32 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BM // 8) * (32 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BM // 8) * (32 // size_of[dtype]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(32 // size_of[dtype]())], ComptimeInt[((BM // 8) * (32 // size_of[dtype]()))]](Idx[(32 // size_of[dtype]())](), Idx[((BM // 8) * (32 // size_of[dtype]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​BM ([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): ### `internal_k_major_64B` `comptime internal_k_major_64B[dtype: DType, BM: Int, BK: Int] = Layout[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 64)]], Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BM // 8) * (64 // size_of[dtype]()))]], Coord[ComptimeInt[1], 
ComptimeInt[0]]](Coord[Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 64)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]], Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 64)]]](Coord[ComptimeInt[(BM // 8)], ComptimeInt[8]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(BM // 8)], ComptimeInt[8]](Idx[(BM // 8)](), Idx[8]())), Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 64)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BK * size_of[dtype]()) // 64)]](Idx[(64 // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // 64)]())))), Coord[Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BM // 8) * (64 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BM // 8) * (64 // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]](Coord[ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BM // 8) * (64 // size_of[dtype]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(64 // size_of[dtype]())], ComptimeInt[((BM // 8) * (64 // size_of[dtype]()))]](Idx[(64 // size_of[dtype]())](), Idx[((BM // 8) * (64 // size_of[dtype]()))]())), Coord[ComptimeInt[1], ComptimeInt[0]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0]](Idx[1](), Idx[0]())))))` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​BM ([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): ### `internal_k_major_none` `comptime internal_k_major_none[dtype: DType, BM: Int, BK: Int] = row_major[BM, BK]()` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​BM 
([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): ### `internal_sf_k_major` `comptime internal_sf_k_major[dim0: Int, dim1: Int] = Layout[Coord[ComptimeInt[32], ComptimeInt[(dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(dim1 // 16)]], Coord[ComptimeInt[16], ComptimeInt[(dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[Coord[ComptimeInt[32], ComptimeInt[(dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(dim1 // 16)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[32], ComptimeInt[(dim0 // 32)]], Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(dim1 // 16)]]](Coord[ComptimeInt[32], ComptimeInt[(dim0 // 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[32], ComptimeInt[(dim0 // 32)]](Idx[32](), Idx[(dim0 // 32)]())), Coord[Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(dim1 // 16)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[4], ComptimeInt[4]], ComptimeInt[(dim1 // 16)]](Coord[ComptimeInt[4], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[4], ComptimeInt[4]](Idx[4](), Idx[4]())), Idx[(dim1 // 16)]())))), Coord[Coord[ComptimeInt[16], ComptimeInt[(dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[16], ComptimeInt[(dim1 * 32)]], Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]]](Coord[ComptimeInt[16], ComptimeInt[(dim1 * 32)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[16], ComptimeInt[(dim1 * 32)]](Idx[16](), Idx[(dim1 * 32)]())), Coord[Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[1], ComptimeInt[4]], ComptimeInt[512]](Coord[ComptimeInt[1], ComptimeInt[4]](VariadicPack[True, MutExternalOrigin, True, CoordLike, 
ComptimeInt[1], ComptimeInt[4]](Idx[1](), Idx[4]())), Idx[512]())))))` #### Parameters * ​dim0 ([`Int`](/std/builtin/int/Int)): * ​dim1 ([`Int`](/std/builtin/int/Int)): ### `SMemTile` `comptime SMemTile[shape_types: Variadic[CoordLike], stride_types: Variadic[CoordLike], //, dtype: DType, layout: Layout[shape_types, stride_types], *, alignment: Int = 128] = TileTensor[dtype, Layout[shape_types, stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]` Shared memory tile using TileTensor with a Layout. The Layout parameter preserves swizzle information, enabling .to\_layout\_tensor() to produce correctly swizzled LayoutTensors. #### Parameters * ​shape\_types (`Variadic`): * ​stride\_types (`Variadic`): * ​dtype ([`DType`](/std/builtin/dtype/DType)): The data type of tile elements. * ​layout ([`Layout`](/kernels/layout/layout/Layout)): The full layout including swizzle information. * ​alignment ([`Int`](/std/builtin/int/Int)): Memory alignment (default 128 for shared memory). ### `SMemTile2D` `comptime SMemTile2D[dtype: DType, dim0: Int, dim1: Int, *, alignment: Int = 128] = TileTensor[dtype, Layout[ComptimeInt[dim0], ComptimeInt[dim1], ComptimeInt[dim1], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]` Backward-compatible alias for SMemTile with explicit 2D dimensions. #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​dim0 ([`Int`](/std/builtin/int/Int)): * ​dim1 ([`Int`](/std/builtin/int/Int)): * ​alignment ([`Int`](/std/builtin/int/Int)): ### `SMemTileShape` `comptime SMemTileShape[mut: Bool, dtype: DType, LayoutType: TensorLayout, origin: Origin[mut=mut], address_space: AddressSpace, linear_idx_type: DType, element_shape_types: Variadic[CoordLike], //, idx: Int, Tile: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]] = LayoutType.static_shape[idx]` Get compile-time shape value at index from a TileTensor type. 
Returns: The static shape value, or -1 if runtime-determined. #### Parameters * ​mut ([`Bool`](/std/builtin/bool/Bool)): * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​LayoutType ([`TensorLayout`](/kernels/layout/_layout/TensorLayout)): * ​origin ([`Origin`](/std/builtin/type_aliases/Origin)): * ​address\_space ([`AddressSpace`](/std/memory/pointer/AddressSpace)): * ​linear\_idx\_type ([`DType`](/std/builtin/dtype/DType)): * ​element\_shape\_types (`Variadic`): * ​idx ([`Int`](/std/builtin/int/Int)): The dimension index. * ​Tile ([`TileTensor`](/kernels/layout/_tile_tensor/TileTensor)): The TileTensor type (use type\_of(tile)). ### `SMemTileStride` `comptime SMemTileStride[mut: Bool, dtype: DType, LayoutType: TensorLayout, origin: Origin[mut=mut], address_space: AddressSpace, linear_idx_type: DType, element_shape_types: Variadic[CoordLike], //, idx: Int, Tile: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]] = LayoutType.static_stride[idx]` Get compile-time stride value at index from a TileTensor type. Returns: The static stride value, or -1 if runtime-determined. #### Parameters * ​mut ([`Bool`](/std/builtin/bool/Bool)): * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​LayoutType ([`TensorLayout`](/kernels/layout/_layout/TensorLayout)): * ​origin ([`Origin`](/std/builtin/type_aliases/Origin)): * ​address\_space ([`AddressSpace`](/std/memory/pointer/AddressSpace)): * ​linear\_idx\_type ([`DType`](/std/builtin/dtype/DType)): * ​element\_shape\_types (`Variadic`): * ​idx ([`Int`](/std/builtin/int/Int)): The dimension index. * ​Tile ([`TileTensor`](/kernels/layout/_tile_tensor/TileTensor)): The TileTensor type (use type\_of(tile)). 
### `swizzle_mode_to_bytes` `comptime swizzle_mode_to_bytes[swizzle_mode: TensorMapSwizzle] = 128 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_128B) else 64 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_64B) else 32 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_32B) else 0` Convert TensorMapSwizzle enum to swizzle size in bytes. Returns: The swizzle size in bytes (128, 64, 32, or 0 for no swizzle). #### Parameters * ​swizzle\_mode ([`TensorMapSwizzle`](/std/gpu/host/nvidia/tma/TensorMapSwizzle)): The TensorMapSwizzle enum value. ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`BlockwiseFP8TilePayload`](./BlockwiseFP8TilePayload): TileTensor-based tile payload for blockwise FP8 matmul. * [​`SMemTileArray`](./SMemTileArray): Array of TileTensor tiles with variadic shape/stride type parameters. * [​`SMemTileArray2D`](./SMemTileArray2D): Array of TileTensor tiles in shared memory with swizzled K-major layout. * [​`SMemTileArray2DRowMajor`](./SMemTileArray2DRowMajor): Array of TileTensor tiles in shared memory with row\_major layout. * [​`SMemTileArrayWithLayout`](./SMemTileArrayWithLayout): Array of TileTensor tiles with explicit Layout (preserves swizzle info).
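The `swizzle_mode_to_bytes` helper above is a plain compile-time lookup from swizzle mode to byte width. A minimal sketch of using it (hedged: the import path is inferred from the links above and may differ):

```
from gpu.host.nvidia.tma import TensorMapSwizzle

# Resolve the swizzle byte width at compile time for a given TMA mode.
comptime bytes_128B = swizzle_mode_to_bytes[TensorMapSwizzle.SWIZZLE_128B]  # 128
comptime bytes_none = swizzle_mode_to_bytes[TensorMapSwizzle.SWIZZLE_NONE]  # 0
```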
--- ## BlockScaledTmem
`@register_passable(trivial)` `struct BlockScaledTmem[accum_dtype: DType, MMA_M: Int, MMA_N: Int, num_accum_stages: Int, sf_dtype: DType, BM: Int, num_pipeline_stages: Int, *, cta_group: Int = 1, total_cols: Int = 512, num_sf_k_tiles: Int = 1]` TMEM region for block-scaled matmul with typed tile accessors. Manages the TMEM address space for block-scaled MMA operations, providing typed TmemTensor access to: * Accumulator tiles (one per output pipeline stage) * SFA scaling factor tiles (one per k-iteration) * SFB scaling factor tiles (one per k-iteration) Memory layout (512 columns total):

```
┌────────────────────┬─────────────────┬─────────────────┐
│ Accumulators       │ SFA Scales      │ SFB Scales      │
│ (stages × MMA_N)   │ (iters × cols)  │ (iters × cols)  │
└────────────────────┴─────────────────┴─────────────────┘
```

## Parameters * ​accum\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Accumulator data type (typically float32). * ​MMA\_M ([`Int`](/mojo/std/builtin/int/Int)): MMA M dimension. * ​MMA\_N ([`Int`](/mojo/std/builtin/int/Int)): MMA N dimension (also stage stride for accumulators). * ​num\_accum\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. * ​sf\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Scaling factor data type. * ​BM ([`Int`](/mojo/std/builtin/int/Int)): Block M dimension (for SFA sizing). * ​num\_pipeline\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of k-iteration pipeline stages. * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2). * ​total\_cols ([`Int`](/mojo/std/builtin/int/Int)): Total TMEM columns (512 for SM100). * ​num\_sf\_k\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Scaling factor tiles per K-iteration. MXFP8 uses 1 (one SF vector per K-tile). NVFP4 uses 4 (multiple SF vectors per K-tile). 
## Fields * ​base\_addr (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_layout` `comptime accum_layout = Layout.row_major(MMA_M, MMA_N)` ### `accum_offset` `comptime accum_offset = 0` ### `AccumArray` `comptime AccumArray = TmemArrayType[accum_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].accum_layout, num_accum_stages, cta_group=cta_group]` ### `AccumTile` `comptime AccumTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].AccumArray.Tile` ### `sfa_layout` `comptime sfa_layout = Layout.row_major(1, (num_sf_k_tiles * (BM // 32)))` ### `sfa_offset` `comptime sfa_offset = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].AccumArray.num_cols` ### `SFAArray` `comptime SFAArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].sfa_layout, num_pipeline_stages, cta_group=cta_group]` ### `SFATile` `comptime SFATile = 
BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFAArray.Tile` ### `sfb_layout` `comptime sfb_layout = Layout.row_major(1, (num_sf_k_tiles * (MMA_N // 32)))` ### `sfb_offset` `comptime sfb_offset = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].sfa_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFAArray.num_cols)` ### `SFBArray` `comptime SFBArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].sfb_layout, num_pipeline_stages, cta_group=cta_group]` ### `SFBTile` `comptime SFBTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFBArray.Tile` ### `used_cols` `comptime used_cols = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].sfb_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFBArray.num_cols)` ## Methods ### `__init__` `__init__(base_addr: Int) -> Self` Create TMEM region view at the given base address. `__init__(addr: TmemAddress) -> Self` Create TMEM region view from a TmemAddress. `__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols]) -> Self` Create TMEM region view from a TmemAllocation. 
### `accum_tiles` `accum_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].AccumArray` Get array of accumulator tiles. **Returns:** [`BlockScaledTmem`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/BlockScaledTmem) ### `sfa_tiles` `sfa_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFAArray` Get array of SFA scaling factor tiles. **Returns:** [`BlockScaledTmem`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/BlockScaledTmem) ### `sfb_tiles` `sfb_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFBArray` Get array of SFB scaling factor tiles. **Returns:** [`BlockScaledTmem`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/BlockScaledTmem) ### `accum` `accum[T: Intable](self, stage: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].AccumTile` Get accumulator tile for the given pipeline stage. **Returns:** [`BlockScaledTmem`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/BlockScaledTmem) ### `sfa` `sfa[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFATile` Get SFA scaling factor tile for the given k-iteration index. 
**Returns:** [`BlockScaledTmem`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/BlockScaledTmem) ### `sfb` `sfb[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles].SFBTile` Get SFB scaling factor tile for the given k-iteration index. **Returns:** [`BlockScaledTmem`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/BlockScaledTmem)
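Taken together, the accessors above carve one TMEM allocation into typed views. A hedged sketch of how a kernel might index such a region (all parameter values, and the `base_addr`/`stage`/`k_iter` names, are illustrative):

```
comptime Tmem = BlockScaledTmem[
    DType.float32, 128, 128, 2,  # accum_dtype, MMA_M, MMA_N, num_accum_stages
    DType.uint8, 128, 4,         # sf_dtype (illustrative), BM, num_pipeline_stages
]
var region = Tmem(base_addr)    # view over an existing TMEM allocation
var acc = region.accum(stage)   # accumulator tile for one pipeline stage
var sfa = region.sfa(k_iter)    # SFA scaling-factor tile for one k-iteration
var sfb = region.sfb(k_iter)    # SFB scaling-factor tile for one k-iteration
```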
--- ## TmemAddress
`@register_passable(trivial)` `struct TmemAddress` Simple TMEM address wrapper for load/store operations. Encapsulates TMEM address encoding for accumulator fragment access. SM100 MMA accumulators are organized as 32 rows, split into: * Upper fragment (rows 0-15): accessed via upper\_addr() * Lower fragment (rows 16-31): accessed via lower\_addr() The lower fragment address adds TMEM\_LOWER\_ROW\_OFFSET (16 << 16) to encode the row offset in the upper 16 bits of the address. Usage:

```
var tmem = TmemAddress(base_offset)

# Load operations
var upper = tmem.load_upper[dtype, size]()
var lower = tmem.load_lower[dtype, size]()
TmemAddress.wait_load()

# Store operations
tmem.store_upper[dtype, size](upper_frag)
tmem.store_lower[dtype, size](lower_frag)
TmemAddress.wait_store()

# Low-level address access for custom operations
raw_upper = tmem.upper_addr()
raw_lower = tmem.lower_addr()
```

## Fields * ​addr (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(addr: Int) -> Self` Create TmemAddress from integer column address. `__init__(addr: UInt32) -> Self` Create TmemAddress from hardware address (UInt32). ### `__add__` `__add__(self, offset: Int) -> Self` Create new TmemAddress with column offset added. ### `upper_addr` `upper_addr(self) -> UInt32` Raw address for upper fragment (rows 0-15). 
**Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `lower_addr` `lower_addr(self) -> UInt32` Raw address for lower fragment (rows 16-31). **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `load_upper` `load_upper[dtype: DType, width: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 1](self) -> SIMD[dtype, width]` Load upper accumulator fragment (rows 0-15). **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `load_lower` `load_lower[dtype: DType, width: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 1](self) -> SIMD[dtype, width]` Load lower accumulator fragment (rows 16-31). **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `store_upper` `store_upper[dtype: DType, width: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 1](self, data: SIMD[dtype, width])` Store upper accumulator fragment (rows 0-15). ### `store_lower` `store_lower[dtype: DType, width: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 1](self, data: SIMD[dtype, width])` Store lower accumulator fragment (rows 16-31). ### `wait_store` `static wait_store()` Wait for TMEM store operations to complete. ### `wait_load` `static wait_load()` Wait for TMEM load operations to complete.
--- ## TmemAllocation
`@register_passable(trivial)` `struct TmemAllocation[cta_group: Int, max_cols: Int = 512]` Handle to allocated Tensor Memory. Lifecycle: allocate() → use → release\_lock() → wait → deallocate() ## Parameters * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): Cooperating CTAs (1 or 2). * ​max\_cols ([`Int`](/mojo/std/builtin/int/Int)): TMEM columns (512 for SM100). ## Fields * ​addr (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SmemAddrStorage` `comptime SmemAddrStorage = SMemArray[UInt32, 1]` ## Methods ### `__init__` `__init__(addr: UInt32) -> Self` ### `allocate` `static allocate(smem_addr: SMemArray[UInt32, 1]) -> Self` Allocate TMEM (MMA warp). Address stored in smem for epilogue. ### `from_shared` `static from_shared(smem_addr: SMemArray[UInt32, 1]) -> Self` Get handle from existing allocation (epilogue warp). ### `release_lock` `release_lock(self)` Release allocation lock before waiting for epilogue. ### `deallocate` `deallocate(self)` Free TMEM after epilogue completion.
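The lifecycle line above spans two warps. A hedged sketch of the handoff, using only the methods listed here (the `smem_addr` variable and comments are illustrative):

```
# MMA warp: allocate TMEM, publishing the address via shared memory
var tmem = TmemAllocation[1].allocate(smem_addr)  # smem_addr: SMemArray[UInt32, 1]
# ... issue MMA operations against the allocation ...
tmem.release_lock()                               # let the epilogue take over

# Epilogue warp: attach to the same allocation through shared memory
var tmem_view = TmemAllocation[1].from_shared(smem_addr)
# ... load accumulators, write results ...
tmem_view.deallocate()                            # free TMEM when done
```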
--- ## TmemArrayType
`@register_passable(trivial)` `struct TmemArrayType[dtype: DType, layout: Layout, num_tiles: Int, *, cta_group: Int = 1]` Array of tiles in Tensor Memory (TMEM). Similar to SMemArray but for TMEM-resident tiles. Provides indexed access to a contiguous array of TmemTensor tiles. Compile-time constants: Tile: TmemTensor type for each tile. tile\_stride: Columns per tile (derived from layout.size()). num\_cols: Total TMEM columns used (num\_tiles × tile\_stride). ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Element dtype for tiles. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of each tile. * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles in the array. * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2). ## Fields * ​base\_addr (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_cols` `comptime num_cols = (num_tiles * TmemArrayType[dtype, layout, num_tiles, cta_group=cta_group].tile_stride)` ### `Tile` `comptime Tile = TmemTensor[dtype, layout, cta_group=cta_group]` ### `tile_stride` `comptime tile_stride = layout.shape[1].value()` ## Methods ### `__init__` `__init__(base_addr: Int) -> Self` Initialize array at the given TMEM base address. 
### `__getitem__` `__getitem__[T: Intable](self, index: T) -> TmemArrayType[dtype, layout, num_tiles, cta_group=cta_group].Tile` Get tile at the given index. **Returns:** `TmemArrayType`
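As a sketch of indexed tile access (the layout, tile count, and `base_addr` below are illustrative assumptions):

```mojo
# Four float32 tiles, each with a (64, 128) layout.
comptime tile_layout = Layout.row_major(64, 128)
comptime Tiles = TmemArrayType[DType.float32, tile_layout, 4]

# tile_stride = 128 columns per tile, so num_cols = 4 * 128 = 512.
var tiles = Tiles(base_addr)
var tile0 = tiles[0]  # TmemTensor at base_addr
var tile2 = tiles[2]  # TmemTensor offset by 2 * tile_stride columns
```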
--- ## TmemFragments
`@register_passable(trivial)` `struct TmemFragments[dtype: DType, frag_size: Int, *, is_lower_required: Bool = True, data_paths: Int = 16, bits: Int = 256]` Paired upper/lower accumulator fragments from TMEM. Encapsulates the SM100 TMEM row-split hardware detail: * Upper fragment: rows 0-15 (always present) * Lower fragment: rows 16-31 (only when is\_lower\_required=True) The is\_lower\_required flag is determined by: * False when cta\_group=1 and MMA\_M=64 (fits in 16 rows) * True otherwise (needs both halves) Example: # Load both fragments in one call var frags = TmemFragments\[DType.float32, 16].load(tmem\_addr) # Work with fragments frags.upper = process(frags.upper) frags.lower = process(frags.lower) # Store both fragments frags.store(tmem\_addr) TmemFragments.wait\_store() ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Fragment data type (typically float32). * ​frag\_size ([`Int`](/mojo/std/builtin/int/Int)): Elements per fragment (derived from data\_paths and bits). * ​is\_lower\_required ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether lower fragment is needed. * ​data\_paths ([`Int`](/mojo/std/builtin/int/Int)): SM100 data paths (typically 16). * ​bits ([`Int`](/mojo/std/builtin/int/Int)): Bits per fragment load (typically 256). 
## Fields * ​upper (`SIMD[dtype, frag_size]`): * ​lower (`SIMD[dtype, frag_size]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Initialize with zero fragments. `__init__(upper: SIMD[dtype, frag_size], lower: SIMD[dtype, frag_size]) -> Self` Initialize with provided fragments. ### `load` `static load[repeat: Int = 1](tmem: TmemAddress) -> TmemFragments[dtype, (frag_size * repeat), is_lower_required=is_lower_required]` Load fragments from TMEM address. Loads upper fragment always; loads lower only if required. **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the load pattern. **Args:** * ​tmem ([`TmemAddress`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemAddress)): TMEM address to load from. **Returns:** [`TmemFragments`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemFragments): TmemFragments containing upper and (optionally) lower data. ### `store` `store[repeat: Int = 1](self, tmem: TmemAddress)` Store fragments to TMEM address. Stores upper fragment always; stores lower only if required. **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the store pattern. 
**Args:** * ​tmem ([`TmemAddress`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemAddress)): TMEM address to store to. ### `cast` `cast[target_dtype: DType](self) -> TmemFragments[target_dtype, frag_size, is_lower_required=is_lower_required]` Cast fragments to a different dtype. **Returns:** [`TmemFragments`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemFragments) ### `wait_load` `static wait_load()` Wait for TMEM load operations to complete. ### `wait_store` `static wait_store()` Wait for TMEM store operations to complete.
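The inline example in the docstring above, reformatted as a block for readability (`tmem_addr` and `process` are placeholders carried over from the original):

```mojo
# Load both fragments in one call.
var frags = TmemFragments[DType.float32, 16].load(tmem_addr)

# Work with the fragments.
frags.upper = process(frags.upper)
frags.lower = process(frags.lower)

# Store both fragments, then wait for the stores to complete.
frags.store(tmem_addr)
TmemFragments.wait_store()
```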
--- ## TmemStage
`@register_passable(trivial)` `struct TmemStage[num_stages: Int, stage_stride: Int, cta_group: Int]` A pipeline stage within TMEM for accumulator buffering. Used by OutputTilePipeline to manage MMA→Epilogue synchronization. MMA writes to one stage while epilogue reads from another. Wraps TmemAddress with stage-specific offset calculation: * offset(): Column address for this stage (base + index \* stride) * address(): TmemAddress for this stage (for load/store ops) * tensor[layout](): Get typed TmemTensor view ## Parameters * ​num\_stages ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stages (typically 2-4). * ​stage\_stride ([`Int`](/mojo/std/builtin/int/Int)): Columns per stage (512 / num\_stages). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): Cooperating CTAs (1 or 2). ## Fields * ​base\_addr (`Int`): * ​index (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(base_addr: Int, index: Int) -> Self` `__init__(addr: TmemAddress, index: Int) -> Self` Create stage from TmemAddress and stage index. `__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols], index: Int) -> Self` Create stage from TmemAllocation and stage index. ### `from_offset` `static from_offset(offset: Int, index: Int) -> Self` Create stage from pre-computed offset (for legacy pipeline compatibility). 
Use this when the caller has already computed the TMEM offset (e.g., `base + stage * stride`) and just needs to wrap it. The index is preserved for barrier signaling, and we back-calculate the base\_addr such that offset() = base + index \* stride = offset. **Args:** * ​offset ([`Int`](/mojo/std/builtin/int/Int)): Pre-computed TMEM column offset for this stage. * ​index ([`Int`](/mojo/std/builtin/int/Int)): Pipeline stage index (for barrier signaling). **Returns:** `Self`: TmemStage with offset() returning the given value. ### `offset` `offset(self) -> Int` TMEM column address for this stage. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `address` `address(self) -> TmemAddress` Get TmemAddress for this stage's offset. **Returns:** [`TmemAddress`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemAddress) ### `tensor` `tensor[accum_dtype: DType, accum_layout: Layout](self) -> TmemTensor[accum_dtype, accum_layout, cta_group=cta_group]` Get typed TmemTensor view of this stage's accumulator. **Parameters:** * ​accum\_dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Accumulator data type. * ​accum\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Logical accumulator layout (M × N). **Returns:** [`TmemTensor`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemTensor): TmemTensor providing typed access to the accumulator. ### `load_upper` `load_upper[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> SIMD[dtype, frag_size]` Load upper accumulator fragment (rows 0-15). **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `load_lower` `load_lower[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> SIMD[dtype, frag_size]` Load lower accumulator fragment (rows 16-31). **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD) ### `wait_load` `static wait_load()` Wait for TMEM load operations to complete.
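A minimal sketch of the stage/offset arithmetic (`base` is an assumed TMEM base address; the stage count and stride are illustrative):

```mojo
# Two pipeline stages split 512 TMEM columns: stride = 512 / 2 = 256.
comptime Stage = TmemStage[2, 256, 1]

var stage0 = Stage(base, 0)   # offset() == base + 0 * 256
var stage1 = Stage(base, 1)   # offset() == base + 1 * 256

# address() wraps the offset for fragment load/store operations;
# tensor[...]() would give a typed TmemTensor view of the accumulator.
var addr = stage1.address()
```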
--- ## TmemTensor
`@register_passable(trivial)` `struct TmemTensor[dtype: DType, layout: Layout, *, cta_group: Int = 1]` Typed tensor view over Tensor Memory (TMEM) for MMA accumulators. Provides a LayoutTensor-like abstraction for TMEM with: * Type safety: dtype and layout known at compile time * Fragment access: upper (rows 0-15) and lower (rows 16-31) * MMA integration: offset() returns raw address for MMA operations The layout parameter captures the logical accumulator shape (M × N), enabling future extensions like custom tiling patterns or multi-tile accumulator management. Example: # Create typed TMEM view with (64, 128) accumulator layout comptime layout = Layout.row\_major(64, 128) var tmem = TmemTensor[DType.float32, layout](col_offset) # Use with MMA operations (returns raw UInt32 offset) mma\_op.mma(a\_tile, b\_tile, tmem.offset(), init\_c=True) # Load fragments for epilogue var upper = tmem.load\_upper[repeat=4]() var lower = tmem.load\_lower[repeat=4]() TmemTensor.wait\_load() ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Accumulator data type (typically float32). * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Logical layout of the accumulator tile (M × N). * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA cooperation level (1 or 2). 
## Fields * ​col\_addr (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `bits` `comptime bits = 256` ### `data_paths` `comptime data_paths = 16` ### `frag_size` `comptime frag_size = 4` ### `Fragments` `comptime Fragments = TmemFragments[dtype, 4, is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]` ### `is_lower_required` `comptime is_lower_required = (TmemTensor[dtype, layout, cta_group=cta_group].tile_m == 64) if (cta_group == 1)._mlir_value else (cta_group == 1).__bool__().__invert__()` ### `tile_m` `comptime tile_m = layout.shape[0].value()` ## Methods ### `__init__` `__init__(col_addr: Int) -> Self` Create TMEM tensor view at the given column address. `__init__(addr: TmemAddress) -> Self` Create TMEM tensor view from a TmemAddress. ### `offset` `offset(self) -> Int` TMEM column address for this tensor. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `address` `address(self) -> TmemAddress` Get TmemAddress for low-level fragment operations. **Returns:** [`TmemAddress`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemAddress) ### `load_upper` `load_upper[repeat: Int = 1](self) -> SIMD[dtype, (4 * repeat)]` Load upper accumulator fragment (rows 0-15). **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the load pattern. 
**Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): SIMD vector containing the upper fragment data. ### `load_lower` `load_lower[repeat: Int = 1](self) -> SIMD[dtype, (4 * repeat)]` Load lower accumulator fragment (rows 16-31). **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the load pattern. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): SIMD vector containing the lower fragment data. ### `store_upper` `store_upper[repeat: Int = 1](self, data: SIMD[dtype, (4 * repeat)])` Store upper accumulator fragment (rows 0-15). **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the store pattern. **Args:** * ​data ([`SIMD`](/mojo/std/builtin/simd/SIMD)): SIMD vector containing the data to store. ### `store_lower` `store_lower[repeat: Int = 1](self, data: SIMD[dtype, (4 * repeat)])` Store lower accumulator fragment (rows 16-31). **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the store pattern. **Args:** * ​data ([`SIMD`](/mojo/std/builtin/simd/SIMD)): SIMD vector containing the data to store. ### `load_fragments` `load_fragments[repeat: Int = 1](self) -> TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]` Load both upper and lower fragments in one call. Handles is\_lower\_required automatically based on layout. **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the load pattern. **Returns:** [`TmemFragments`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemFragments): TmemFragments containing upper and (conditionally) lower data. ### `store_fragments` `store_fragments[repeat: Int = 1](self, frags: TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required])` Store both upper and lower fragments in one call. 
Handles is\_lower\_required automatically based on layout. **Parameters:** * ​repeat ([`Int`](/mojo/std/builtin/int/Int)): Number of times to repeat the store pattern. **Args:** * ​frags ([`TmemFragments`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tmem/TmemFragments)): TmemFragments containing upper and (conditionally) lower data. ### `wait_load` `static wait_load()` Wait for TMEM load operations to complete. ### `wait_store` `static wait_store()` Wait for TMEM store operations to complete.
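The docstring example above, reformatted as a block (`col_offset`, `mma_op`, and the input tiles are placeholders from the original):

```mojo
# Create a typed TMEM view with a (64, 128) accumulator layout.
comptime layout = Layout.row_major(64, 128)
var tmem = TmemTensor[DType.float32, layout](col_offset)

# Use with MMA operations (offset() returns the raw column address).
mma_op.mma(a_tile, b_tile, tmem.offset(), init_c=True)

# Load fragments for the epilogue.
var upper = tmem.load_upper[repeat=4]()
var lower = tmem.load_lower[repeat=4]()
TmemTensor.wait_load()
```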
--- ## tmem
Tensor Memory (TMEM) abstractions for SM100 Blackwell GPUs. TMEM is dedicated memory for MMA accumulators, separate from registers and shared memory. This module provides type-safe abstractions: * TmemAllocation: Manages TMEM lifecycle (alloc/dealloc) * TmemTensor: Layout-parameterized typed view over TMEM accumulators * TmemStage: Represents a pipeline stage for accumulator buffering * TmemAddress: Simple address wrapper for TMEM load operations ## `comptime` values ### `TMEM_LOWER_ROW_OFFSET` `comptime TMEM_LOWER_ROW_OFFSET = 1048576` ## Structs * [​`BlockScaledTmem`](./BlockScaledTmem): TMEM region for block-scaled matmul with typed tile accessors. * [​`TmemAddress`](./TmemAddress): Simple TMEM address wrapper for load/store operations. * [​`TmemAllocation`](./TmemAllocation): Handle to allocated Tensor Memory. * [​`TmemArrayType`](./TmemArrayType): Array of tiles in Tensor Memory (TMEM). * [​`TmemFragments`](./TmemFragments): Paired upper/lower accumulator fragments from TMEM. * [​`TmemStage`](./TmemStage): A pipeline stage within TMEM for accumulator buffering. * [​`TmemTensor`](./TmemTensor): Typed tensor view over Tensor Memory (TMEM) for MMA accumulators.
--- ## EpilogueWarp
`struct EpilogueWarp[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]` Unified linear type for epilogue warp lifecycle. Works as both a linear type (direct use) and within context managers. Lifecycle: 1. Created via `create()` after Sync.wait() - reads TMEM address 2. Use `output_pipeline` or `acquire_k_stage_linear()` for epilogue stages 3. Must call `release()` to signal completion (compiler-enforced) IMPORTANT: Call Sync.wait() BEFORE create() to ensure TMEM address is visible. ## Parameters * ​num\_accum\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. * ​stage\_stride\_cols ([`Int`](/mojo/std/builtin/int/Int)): TMEM column stride between stages. * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2). * ​mma\_threads ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA threads. * ​epilogue\_threads ([`Int`](/mojo/std/builtin/int/Int)): Number of epilogue threads. ## Fields * ​tmem (`EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem`): * ​output\_pipeline (`EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline`): * ​dealloc\_barrier (`EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `Dealloc` `comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc` ### `Pipeline` `comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline` ### `Sync` `comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync` ### `Tmem` `comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem` 
## Methods ### `__init__` `__init__(out self, tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group])` ### `create` `static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self` Create Epilogue warp. Reads TMEM address from shared memory. IMPORTANT: Call Sync.wait() BEFORE this to ensure the address is visible. **Args:** * ​tmem\_addr\_storage ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Shared storage containing TMEM address. * ​accum\_barriers ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier array for accumulator pipeline. * ​dealloc\_mbar ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier for TMEM deallocation synchronization. * ​mma\_complete\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for MMA completion signaling. **Returns:** `Self`: Fully initialized EpilogueWarp that must be released. ### `per_k_stage` `per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[input_origin] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin._mlir_origin.pipeline), num_accum_stages, stage_stride_cols, cta_group, num_group_stages]` Get per-K stage context manager (for compatibility). Prefer acquire\_k\_stage\_linear() for flat code structure. **Returns:** `EpilogueKContext` ### `acquire_k_stage_linear` `acquire_k_stage_linear(mut self) -> EpilogueStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]` Acquire a per-K stage using linear types. Waits for MMA to complete the stage, returns a linear handle. 
Usage: var stage = epi\_handle.acquire\_k\_stage\_linear() process\_tmem(stage.tmem()) stage^.release() **Returns:** `EpilogueStage` ### `release` `release(deinit self)` Signal epilogue completion. This is the only way to destroy this linear type.
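Putting the lifecycle steps together, a hedged sketch of the linear-type flow. The parameter values, barrier storage names, `process_tmem`, and the exact `Sync.wait()` call are illustrative placeholders, not part of the documented API:

```mojo
comptime Epi = EpilogueWarp[2, 256, 1, 128, 128]

# The TMEM address must be visible before create() reads it.
Epi.Sync.wait()
var epi = Epi.create(
    tmem_addr_storage, accum_barriers, dealloc_mbar, mma_complete_mask
)

for _ in range(num_k_iters):
    var stage = epi.acquire_k_stage_linear()
    process_tmem(stage.tmem())
    stage^.release()

# Signal epilogue completion; the compiler enforces this call.
epi^.release()
```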
--- ## EpilogueWarpContext
`@register_passable(trivial)` `struct EpilogueWarpContext[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]` Epilogue warp context - consumes TMEM data, signals completion. IMPORTANT: Call Sync.wait() BEFORE constructing to ensure TMEM address is visible from shared memory. ## Fields * ​tmem (`EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem`): * ​output\_pipeline (`EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline`): * ​dealloc\_barrier (`EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Dealloc` `comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc` ### `Pipeline` `comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline` ### `Sync` `comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync` ### `Tmem` `comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem` ## Methods ### `__init__` `__init__(tmem: 
TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group]) -> Self` ### `create` `static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self` Create Epilogue warp context with all necessary components. Reads TMEM address from shared memory and creates output pipeline. IMPORTANT: Call Sync.wait() BEFORE calling this to ensure TMEM address is visible. **Args:** * ​tmem\_addr\_storage ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Shared storage containing TMEM address. * ​accum\_barriers ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier array for accumulator pipeline. * ​dealloc\_mbar ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier for TMEM deallocation synchronization. * ​mma\_complete\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for MMA completion signaling. **Returns:** `Self`: Fully initialized EpilogueWarpContext. ### `__enter__` `__enter__(self) -> Self` ### `__exit__` `__exit__(self)` ### `per_k_stage` `per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[input_origin] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin._mlir_origin.pipeline), num_accum_stages, stage_stride_cols, cta_group, num_group_stages]` Get per-K stage context for blockwise FP8 epilogue. Bundles output pipeline (MMA→Epilogue sync) and input pipeline (A-scales consumption) into a single context manager. Example: for k\_iter in range(num\_iters): with epi\_ctx.per\_k\_stage(input\_pipeline) as epi\_stage: accum.promote(epi\_stage, ...) 
\# Both pipelines signaled automatically **Args:** * ​input\_pipeline ([`InputTilePipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_pipeline/InputTilePipeline)): The InputTilePipeline (extracts .pipeline internally). **Returns:** `EpilogueKContext`: EpilogueKContext context manager that handles both pipelines.
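The docstring example above as a readable block (`epi_ctx`, `accum`, and `input_pipeline` are placeholders from the original; the elided arguments are kept as-is):

```mojo
for k_iter in range(num_iters):
    with epi_ctx.per_k_stage(input_pipeline) as epi_stage:
        accum.promote(epi_stage, ...)
    # Both pipelines are signaled automatically on exit.
```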
--- ## MmaWarp
`struct MmaWarp[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]` Unified linear type for MMA warp TMEM lifecycle. Works as both a linear type (direct use) and within context managers. Lifecycle: 1. Created via `create()` - allocates TMEM, signals sync barrier 2. Use `output_pipeline` or `acquire_k_stage_linear()` for MMA stages 3. Must call `release()` to wait for epilogue and deallocate (compiler-enforced) ## Parameters * ​num\_accum\_stages ([`Int`](/mojo/std/builtin/int/Int)): Number of accumulator pipeline stages. * ​stage\_stride\_cols ([`Int`](/mojo/std/builtin/int/Int)): TMEM column stride between stages. * ​cta\_group ([`Int`](/mojo/std/builtin/int/Int)): CTA group size (1 or 2). * ​mma\_threads ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA threads. * ​epilogue\_threads ([`Int`](/mojo/std/builtin/int/Int)): Number of epilogue threads. ## Fields * ​tmem (`MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem`): * ​output\_pipeline (`MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline`): * ​dealloc\_barrier (`MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `Dealloc` `comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc` ### `Pipeline` `comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline` ### `Sync` `comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync` ### `Tmem` `comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem` ## Methods ### `__init__` `__init__(out self, tmem: TmemAllocation[cta_group], 
output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group])` ### `create` `static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self` Create MMA warp with TMEM allocation. Allocates TMEM and signals the warp group sync barrier. **Args:** * ​tmem\_addr\_storage ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Shared storage for TMEM address communication. * ​accum\_barriers ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier array for accumulator pipeline. * ​dealloc\_mbar ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier for TMEM deallocation synchronization. * ​mma\_complete\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for MMA completion signaling. **Returns:** `Self`: Fully initialized MmaWarp that must be released. ### `per_k_stage` `per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]` Get per-K stage context manager (for compatibility). Prefer acquire\_k\_stage\_linear() for flat code structure. **Returns:** `MmaKStage` ### `acquire_k_stage_linear` `acquire_k_stage_linear(mut self) -> MmaStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]` Acquire a per-K stage using linear types. Waits for epilogue to free the stage, returns a linear handle. Usage: var stage = mma\_handle.acquire\_k\_stage\_linear() mma\_op.mma(a, b, stage.tmem\_offset()) mma\_op.commit(stage.mbar()) stage^.release() **Returns:** `MmaStage` ### `release` `release(deinit self)` Wait for epilogue and deallocate TMEM. This is the only way to destroy this linear type.
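The usage snippet for `acquire_k_stage_linear()` above, reformatted with the surrounding lifecycle calls added as a hedged sketch (the parameter values, `mma_op`, and the tile names are illustrative):

```mojo
var mma_warp = MmaWarp[2, 256, 1, 128, 128].create(
    tmem_addr_storage, accum_barriers, dealloc_mbar, mma_complete_mask
)

# Acquire a stage once the epilogue frees it; issue the MMA and commit.
var stage = mma_warp.acquire_k_stage_linear()
mma_op.mma(a, b, stage.tmem_offset())
mma_op.commit(stage.mbar())
stage^.release()

# Wait for the epilogue and deallocate TMEM (compiler-enforced).
mma_warp^.release()
```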
--- ## MmaWarpContext
`@register_passable(trivial)` `struct MmaWarpContext[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]` MMA warp context - owns TMEM lifecycle and output pipeline. `__enter__`: Signals epilogue that TMEM is allocated. `__exit__`: Waits for epilogue, deallocates TMEM. ## Fields * ​tmem (`MmaWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem`): * ​output\_pipeline (`MmaWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline`): * ​dealloc\_barrier (`MmaWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Dealloc` `comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc` ### `Pipeline` `comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline` ### `Sync` `comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync` ### `Tmem` `comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem` ## Methods ### `__init__` `__init__(tmem: 
OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group]) -> Self` ### `create` `static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self` Create MMA warp context with all necessary components. Allocates TMEM and creates output pipeline internally. **Args:** * ​tmem\_addr\_storage ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Shared storage for TMEM address communication. * ​accum\_barriers ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier array for accumulator pipeline. * ​dealloc\_mbar ([`SMemArray`](/mojo/kernels/linalg/structuring/SMemArray)): Barrier for TMEM deallocation synchronization. * ​mma\_complete\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Multicast mask for MMA completion signaling. **Returns:** `Self`: Fully initialized MmaWarpContext. ### `__enter__` `__enter__(self) -> Self` ### `__exit__` `__exit__(self)` ### `per_k_stage` `per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]` Get per-K stage for blockwise FP8 MMA loop. Returns a context manager that acquires an output stage and signals mma\_arrive on exit. Example: for i in range(num\_iters): with mma\_ctx.per\_k\_stage() as mma\_stage: mma(input\_tiles, mma\_op, AccumTensor(mma\_stage.tmem.offset())) \# `__exit__` signals mma\_arrive automatically **Returns:** `MmaKStage`
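The docstring example above as a readable block (`mma_ctx`, `input_tiles`, `mma_op`, and `AccumTensor` are placeholders from the original):

```mojo
for i in range(num_iters):
    with mma_ctx.per_k_stage() as mma_stage:
        mma(input_tiles, mma_op, AccumTensor(mma_stage.tmem.offset()))
    # __exit__ signals mma_arrive automatically.
```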
--- ## warp_context
RAII warp context managers for the SM100 matmul kernel. MmaWarpContext: MMA warp - allocates TMEM, deallocates on exit EpilogueWarpContext: Epilogue warp - consumes TMEM, signals completion on exit ## Structs * [​`EpilogueWarp`](./EpilogueWarp): Unified linear type for epilogue warp lifecycle. * [​`EpilogueWarpContext`](./EpilogueWarpContext): Epilogue warp context - consumes TMEM data, signals completion. * [​`MmaWarp`](./MmaWarp): Unified linear type for MMA warp TMEM lifecycle. * [​`MmaWarpContext`](./MmaWarpContext): MMA warp context - owns TMEM lifecycle and output pipeline.
--- ## create_matmul_configs_ampere
`create_matmul_configs_ampere[key: String, a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool]() -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## get_dispatch_table
`get_dispatch_table[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool]() -> Dict[String, MatmulConfig[a_type, b_type, c_type, transpose_b], default_comp_time_hasher]` **Returns:** `Dict`
--- ## dispatch (Dispatch)
## Functions * [​`create_matmul_configs_ampere`](./create_matmul_configs_ampere): * [​`get_dispatch_table`](./get_dispatch_table):
--- ## sm80
Provides the Nvidia Ampere backend implementations for matmuls. ## Modules * [​`dispatch`](./dispatch/):
--- ## MatmulConfig (Config)
`@register_passable(trivial)` `struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True]` Static configuration of SM90 GPU matmul. ## Fields * ​block\_tile\_shape (`IndexList[3]`): * ​mma\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​num\_pipeline\_stages (`UInt`): * ​num\_k\_partitions (`UInt`): * ​num\_consumer (`UInt`): * ​partitioned\_multicast (`Bool`): * ​k\_group\_size (`UInt`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`Hashable`](/mojo/std/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_shape: IndexList[3], num_pipeline_stages: Scalar[DType.uint], num_k_partitions: Scalar[DType.uint], num_consumer: Scalar[DType.uint], partitioned_multicast: Bool, pdl_level: PDLLevel, k_group_size: Scalar[DType.uint]) -> Self` Initialize MatmulConfig with explicit values for all fields. 
`__init__(m: Int, n: Int, k: Int, num_k_partitions: Scalar[DType.uint] = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel.OFF, k_groups: Optional[UInt] = None, consumer_groups: Optional[Int] = None, swapAB: Bool = False) -> Self` Initialize MatmulConfig by computing optimal values from M, N, K. **Args:** * ​m ([`Int`](/mojo/std/builtin/int/Int)): The M dimension of the matmul. * ​n ([`Int`](/mojo/std/builtin/int/Int)): The N dimension of the matmul. * ​k ([`Int`](/mojo/std/builtin/int/Int)): The K dimension of the matmul. * ​num\_k\_partitions ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of K partitions. * ​partitioned\_multicast ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to use partitioned multicast. * ​pdl\_level ([`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel)): PDL level for grid controls. * ​k\_groups ([`Optional`](/mojo/std/collections/optional/Optional)): How many pipeline stages (loads and stores) are grouped together. * ​consumer\_groups ([`Optional`](/mojo/std/collections/optional/Optional)): The number of consumer groups. * ​swapAB ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to swap A and B. ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `adjust_kgroup_size` `static adjust_kgroup_size(mma_m: Scalar[DType.uint], mma_n: Scalar[DType.uint], K: Scalar[DType.uint], BK: Scalar[DType.uint], num_pipeline_stages: Scalar[DType.uint]) -> UInt` **Returns:** `UInt` ### `pdl_level` `pdl_level(self) -> PDLLevel` **Returns:** [`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel) ### `to_base_config` `to_base_config(self) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` Convert to base MatmulConfig from utils\_gpu. 
**Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. **Parameters:** * ​H ([`Hasher`](/mojo/std/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance.
--- ## build_configs (Config)
`build_configs[a_type: DType, b_type: DType, c_type: DType, N: Int, K: Int, transpose_b: Bool = True, num_k_partitions: Scalar[DType.uint] = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel.OFF, k_groups: Optional[UInt] = None, consumer_groups: Optional[Int] = None, swapAB: Bool = False]() -> Set[MatmulConfig[a_type, b_type, c_type, transpose_b]]` **Returns:** `Set`
--- ## build_configs_generic
`build_configs_generic[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, //, M_start: Int, M_end: Int, config_fn: fn(Int) capturing -> MatmulConfig[a_type, b_type, c_type, transpose_b]]() -> Set[MatmulConfig[a_type, b_type, c_type, transpose_b]]` **Returns:** `Set`
--- ## config (4)
## Structs * [​`MatmulConfig`](./MatmulConfig): Static configuration of SM90 GPU matmul. ## Functions * [​`build_configs`](./build_configs): * [​`build_configs_generic`](./build_configs_generic): * [​`swapAB_largeM_clustered`](./swapAB_largeM_clustered): Config for m in \[129, 240] range with cluster=(2,1,1). * [​`swapAB_midM_linear`](./swapAB_midM_linear): Config for m in \[65, 128] range with linear BN pattern. * [​`swapAB_smallM`](./swapAB_smallM): * [​`swapAB_smallM_ceildiv`](./swapAB_smallM_ceildiv): Config for m < 41 range with BN = ceildiv(m, 8) \* 8 pattern.
--- ## swapAB_largeM_clustered
`swapAB_largeM_clustered[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True](m: Scalar[DType.uint], pdl_level: PDLLevel) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` Config for m in the \[129, 240] range with cluster=(2,1,1). Pattern:

* BN = 72 + ((m - 129) // 16) \* 8
* Stages: 12 for m <= 160, 10 for m <= 224, 8 otherwise
* cluster = (2,1,1), k\_group\_size = 2, swapAB = True

**Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## swapAB_midM_linear
`swapAB_midM_linear[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True](m: Scalar[DType.uint], pdl_level: PDLLevel) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` Config for m in the \[65, 128] range with a linear BN pattern. Pattern:

* BN = 40 + ((m - 65) // 16) \* 8
* stages = 8, cluster = (1,1,1), swapAB = True

**Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## swapAB_smallM
`swapAB_smallM[a_type: DType, b_type: DType, c_type: DType, prioritize_compute_over_ctas: Bool = False, transpose_b: Bool = True](m: Scalar[DType.uint], n: Scalar[DType.uint], k: Scalar[DType.uint], cluster_shape: IndexList[3], num_k_partitions: Scalar[DType.uint], num_consumer: Scalar[DType.uint], partitioned_multicast: Bool, pdl_level: PDLLevel, k_group_size: Scalar[DType.uint] = 0, num_pipeline_stages: Scalar[DType.uint] = 0) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## swapAB_smallM_ceildiv
`swapAB_smallM_ceildiv[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True](m: Scalar[DType.uint], pdl_level: PDLLevel) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` Config for the m < 41 range with the BN = ceildiv(m, 8) \* 8 pattern. Pattern:

* BN = ceildiv(m, 8) \* 8 (rounds up to the next multiple of 8)
* stages = 12, cluster = (1,1,1), swapAB = True

**Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## dispatch (3)
## `comptime` values ### `DISPATCH_HIT` `comptime DISPATCH_HIT = 1` ### `DISPATCH_MISS` `comptime DISPATCH_MISS = 0` ### `llama_405b_fp8_list` `comptime llama_405b_fp8_list = List[TuningConfigSM90](TuningConfigSM90(64, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(128, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(128, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(256, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(512, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(64, 2304, 16384, IndexList[3, DType.int64](64, 48, 32, Tuple[]()), Index(64, 48, 128), 8, Index(1, 1, 1), 1, False, 
OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(128, 2304, 16384, IndexList[3, DType.int64](64, 48, 32, Tuple[]()), Index(64, 48, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(256, 2304, 16384, IndexList[3, DType.int64](64, 96, 32, Tuple[]()), Index(64, 96, 128), 4, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(512, 2304, 16384, IndexList[3, DType.int64](64, 144, 32, Tuple[]()), Index(128, 144, 128), 4, Index(1, 1, 1), 2, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 2304, 16384, IndexList[3, DType.int64](64, 144, 32, Tuple[]()), Index(128, 144, 128), 4, Index(1, 1, 1), 2, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(2048, 2304, 16384, IndexList[3, DType.int64](64, 144, 32, Tuple[]()), Index(128, 144, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(16, 8)), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 2304, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(64, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(128, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), 
OptionalReg[RasterOrder](None)), TuningConfigSM90(128, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(256, 13312, 16384, IndexList[3, DType.int64](64, 208, 32, Tuple[]()), Index(128, 208, 128), 4, Index(1, 2, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(512, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(64, 16384, 6656, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(128, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 16384, 6656, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 16384, 6656, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), 
MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), Tuple[]())` ### `llama_405b_fp8_table` `comptime llama_405b_fp8_table = Table[TuningConfigSM90](llama_405b_fp8_list, "llama_405b_fp8")` ### `llama_8b_fp8_list` `comptime llama_8b_fp8_list = List[TuningConfigSM90](TuningConfigSM90(128, -1, -1, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, -1, -1, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 6, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, -1, -1, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 6, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), Tuple[]())` ### `llama_8b_fp8_table` `comptime llama_8b_fp8_table = Table[TuningConfigSM90](llama_8b_fp8_list, "llama_8b_fp8")` ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `MAX_M` `comptime MAX_M = Int.MAX` ## Functions * [​`matmul_dispatch_sm90`](./matmul_dispatch_sm90): * [​`matmul_dispatch_sm90_bf16_fp32`](./matmul_dispatch_sm90_bf16_fp32): * [​`matmul_dispatch_sm90_fp8`](./matmul_dispatch_sm90_fp8):
--- ## matmul_dispatch_sm90
`matmul_dispatch_sm90[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## matmul_dispatch_sm90_bf16_fp32
`matmul_dispatch_sm90_bf16_fp32[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## matmul_dispatch_sm90_fp8
`matmul_dispatch_sm90_fp8[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## default_config_sm90
`default_config_sm90[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, wgmma_shape: IndexList[3]]() -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## grouped_matmul_sm90
`grouped_matmul_sm90[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool = True, wgmma_shape: IndexList[3] = Index(64, 256, 16), config: MatmulConfig[a_type, b_type, c_type, transpose_b] = default_config_sm90[a_type, b_type, c_type, transpose_b, wgmma_shape](), elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul (3)
## Functions * [​`default_config_sm90`](./default_config_sm90): * [​`grouped_matmul_sm90`](./grouped_matmul_sm90):
--- ## sm90
Provides the Nvidia Hopper backend implementations for matmuls. ## Modules * [​`config`](./config/): * [​`dispatch`](./dispatch/): * [​`grouped_matmul`](./grouped_matmul/): * [​`matmul`](./matmul/): * [​`matmul_kernel_persistent`](./matmul_kernel_persistent/): * [​`matmul_kernels`](./matmul_kernels/): * [​`matmul_output`](./matmul_output/): * [​`testbed`](./testbed/): * [​`testbed_swapAB`](./testbed_swapAB/): Testbed for comparing swapAB vs normal matmul execution. * [​`tile_loader`](./tile_loader/): TileLoader module for efficient tile loading in GPU matrix multiplication. * [​`tile_writer`](./tile_writer/): TileWriter module for efficient tile writing in GPU matrix multiplication. * [​`tuning_configs`](./tuning_configs/):
--- ## matmul (5)
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`warp_specialize_gemm_with_multicasting`](./warp_specialize_gemm_with_multicasting): Unified dispatcher for all matmul kernel variants. * [​`warp_specialize_gemm_with_multicasting_splitk`](./warp_specialize_gemm_with_multicasting_splitk):
--- ## warp_specialize_gemm_with_multicasting
`warp_specialize_gemm_with_multicasting[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], grid_shape: OptionalReg[IndexList[2]] = None, use_tma_store: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, schedule: MatmulSchedule = MatmulSchedule.NONE, hilbert_swizzle: Bool = False, splits: Int = 0, raster_order: RasterOrder = RasterOrder.AlongM, swapAB: Bool = False](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)` Unified dispatcher for all matmul kernel variants.
--- ## warp_specialize_gemm_with_multicasting_splitk
`warp_specialize_gemm_with_multicasting_splitk[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], splits: Int, raster_order: RasterOrder, use_tma_store: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)`
--- ## matmul_kernel_persistent
--- ## HopperMatmulSM90Kernel
`struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_smem_layout: Layout, block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Int, num_threads: Int = 128, transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = 1, pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, hilbert_swizzle: Bool = False, k_group_size: Int = 1, swapAB: Bool = False]` Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs. This kernel implements a highly optimized matrix multiplication (GEMM) using:

* Tensor Memory Accelerator (TMA) for efficient global-to-shared memory transfers
* Warp Group Matrix Multiply Accumulate (WGMMA) instructions for tensor cores
* Multi-stage software pipelining for overlapping compute and memory operations
* Producer-consumer model with separate warp groups for loading and computing

Template Parameters:

* a\_type, b\_type, c\_type: Data types for input and output matrices
* a\_layout, b\_layout, c\_layout: Memory layouts for matrices
* c\_smem\_layout: Shared memory layout for the output tile
* block\_tile\_shape: Tile dimensions \[M, N, K] processed by each thread block
* wgmma\_shape: Dimensions for each WGMMA instruction \[M, N, K]
* cluster\_shape: Thread block cluster dimensions for distributed shared memory
* num\_pipeline\_stages: Number of stages in the software pipeline (typically 3-7)
* num\_threads: Number of threads per block (must be a multiple of 128)
* transpose\_b: Whether the B matrix is transposed (required to be True)
* a\_swizzle, b\_swizzle: Memory swizzling for bank-conflict-free access
* c\_swizzle: Swizzling for output writes
* partitioned\_multicast: Enable partitioned multicast for large tiles
* use\_tma\_store: Use TMA for storing output (vs regular stores)
* promotion\_frequency: How often to promote FP8 accumulation to higher precision
* pdl\_level: Programmatic Dependency Launch (PDL) level
* elementwise\_lambda\_fn: Optional epilogue function
* elementwise\_compute\_lambda\_fn: Optional compute function
* hilbert\_swizzle: Use Hilbert curve for thread block scheduling

## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, a_swizzle]()` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `AccumRegTile` `comptime AccumRegTile = LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, 
promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `adjusted_num_pipeline_stages` `comptime adjusted_num_pipeline_stages = (num_pipeline_stages // k_group_size)` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN, HopperMatmulSM90Kernel[a_type, b_type, 
c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, b_swizzle]()` ### `BK` `comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `c_frag_size` `comptime c_frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)` ### `cluster_size` `comptime cluster_size = Int.__init__[Int32](((cluster_shape.__getitem__[Int32, 3](0) * cluster_shape.__getitem__[Int32, 3](1)) * cluster_shape.__getitem__[Int32, 3](2)))` ### `num_consumer` `comptime num_consumer = ((num_threads // 128) - 1)` ### `num_consumer_threads` `comptime num_consumer_threads = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer * 128)` ### `num_m_mmas` `comptime num_m_mmas = ((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // HopperMatmulSM90Kernel[a_type, b_type, c_type, 
a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer)` ### `num_n_mmas` `comptime num_n_mmas = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // wgmma_shape.__getitem__[3, DType.int64, Int](1))` ### `SMem` `comptime SMem = HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, 
elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, c_smem_layout.shape[0].value(), c_smem_layout.shape[1].value(), num_pipeline_stages, k_group_size]` ### `TMABarrier` `comptime TMABarrier = TMABarrierHandler[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].SMem.ATileArray.storage_size + HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].SMem.BTileArray.storage_size) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages)]` ### `WgmmaOp` `comptime WgmmaOp = TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]` ## Methods ### 
`validate_constraints` `static validate_constraints()` Validate common constraints for all kernel variants. ### `pipeline_init` `static pipeline_init()` Initialize pipeline synchronization barriers. This function ensures that all pipeline initialization (barriers, shared memory) is visible to all thread blocks in the cluster before proceeding. This is critical for correct producer-consumer synchronization. For multi-cluster configurations, this uses a fence and cluster sync; for a single block, it uses a simple barrier. ### `finalize_kernel` `static finalize_kernel()` Common finalization for all kernel variants. ### `multicast_mask` `static multicast_mask(rank_m: Scalar[DType.uint], rank_n: Scalar[DType.uint]) -> Tuple[Int32, Int32]` **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `common_kernel_init` `static common_kernel_init() -> Tuple[UInt, UInt, UInt, UInt, UInt, Bool]` Common initialization for all kernel variants. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple): Tuple of (warp\_group\_idx, warp\_group\_thread\_idx, rank\_m, rank\_n, warp\_id, lane\_predicate). ### `setup_producer` `static setup_producer() -> Int` Set up the producer warp group by deallocating registers. **Returns:** [`Int`](/mojo/std/builtin/int/Int): Number of registers deallocated. 
### `setup_consumer` `static setup_consumer(warp_group_idx: Scalar[DType.uint]) -> Tuple[UInt, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].AccumRegTile, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].AccumRegTile]` Setup consumer warp group. **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple): Tuple of (local\_warp\_group\_idx, c\_reg\_tile, final\_c\_reg\_tile). ### `consumer_arrive_empty_barriers` `static consumer_arrive_empty_barriers(warp_group_thread_idx: Scalar[DType.uint], mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages])` Signal initial empty barrier arrival for all pipeline stages. Must be called by consumer warp groups before the main loop so the producer knows it can start filling stages. 
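The handshake described above (consumers signal every stage's "empty" barrier before the main loop, the producer fills stages and signals "full", consumers drain and re-signal "empty") can be sketched in Python with per-stage semaphores standing in for the kernel's shared-memory barriers. This is a conceptual sketch only; `NUM_STAGES` and `NUM_TILES` are illustrative names, not part of the kernel API.

```python
import threading

NUM_STAGES = 4   # analogous to num_pipeline_stages
NUM_TILES = 10   # number of K tiles streamed through the pipeline

# One "empty" and one "full" semaphore per pipeline stage.
empty = [threading.Semaphore(0) for _ in range(NUM_STAGES)]
full = [threading.Semaphore(0) for _ in range(NUM_STAGES)]
stages = [None] * NUM_STAGES
results = []

def producer():
    for k in range(NUM_TILES):
        s = k % NUM_STAGES
        empty[s].acquire()   # wait until the consumer freed this stage
        stages[s] = k * k    # "load" a tile into the stage buffer
        full[s].release()    # signal the tile is ready

def consumer():
    # Like consumer_arrive_empty_barriers: mark every stage empty
    # before the main loop so the producer can start filling.
    for s in range(NUM_STAGES):
        empty[s].release()
    for k in range(NUM_TILES):
        s = k % NUM_STAGES
        full[s].acquire()          # wait for the producer's tile
        results.append(stages[s])  # "consume" the tile (the wgmma work)
        empty[s].release()         # free the stage for reuse

t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start(); t2.start()
t1.join(); t2.join()
print(results)  # tiles arrive in order: [0, 1, 4, 9, ...]
```

With more stages than in-flight tiles needed, the producer runs ahead of the consumer, which is exactly how the kernel overlaps TMA loads with tensor-core math.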
### `get_block_swizzle` `static get_block_swizzle(lut_ptr: LegacyUnsafePointer[UInt32] = LegacyUnsafePointer[True, UInt32, AddressSpace.GENERIC, MutAnyOrigin]()) -> IndexList[2, element_type=DType.uint32]` Calculate block swizzle for better L2 cache locality. **Args:** * ​lut\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Lookup table for Hilbert curve block scheduling (optional). **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): Swizzled block indices. ### `consumer_output` `static consumer_output[custom_elementwise_lambda_fn: Optional[elementwise_epilogue_type] = elementwise_lambda_fn](c_tma_op: TMATensorTile[c_type, layout, desc_layout], c: LayoutTensor[c_type, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tile: TileTensor[c_type, Layout[ComptimeInt[c_smem_layout.shape[0].value()], ComptimeInt[c_smem_layout.shape[1].value()], ComptimeInt[c_smem_layout.shape[1].value()], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED], output_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * 
HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], warp_group_thread_idx: Scalar[DType.uint], local_warp_group_idx: Scalar[DType.uint], local_thread_idx: Scalar[DType.uint], block_y: Int, block_x: Int)` Handle consumer output by writing GEMM results to global memory. 
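`get_block_swizzle` remaps the linear block index so that blocks scheduled close together in time touch nearby tiles of A and B, improving L2 hit rates. One common scheme (shown here purely as an illustration of the idea, not the kernel's exact mapping) walks the grid in groups of rows:

```python
def swizzle_block(bid: int, grid_m: int, grid_n: int, group_m: int = 8):
    """Remap a linear block id to (m, n) so consecutive ids stay within
    a band of `group_m` rows, so B tiles loaded into L2 get reused."""
    blocks_per_group = group_m * grid_n
    group = bid // blocks_per_group
    first_m = group * group_m
    # The last group may have fewer than group_m rows.
    rows = min(grid_m - first_m, group_m)
    m = first_m + (bid % blocks_per_group) % rows
    n = (bid % blocks_per_group) // rows
    return m, n

# With an 8x4 grid and group_m=2, ids 0..7 cover rows 0-1 across all
# columns before any block in rows 2-3 is scheduled.
mapping = [swizzle_block(i, 8, 4, group_m=2) for i in range(32)]
```

The Hilbert-curve lookup table (`lut_ptr`) serves the same goal with a space-filling curve instead of fixed-height row bands.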
### `build_tma_loaders` `static build_tma_loaders[a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, //](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], rank_m: Scalar[DType.uint], rank_n: Scalar[DType.uint]) -> Tuple[TileLoaderTMA[a_tma_op, a_type, a_tile_layout, a_desc_layout, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape.__getitem__[Int32, 3](0), use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[b_tma_op, b_type, b_tile_layout, b_desc_layout, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape.__getitem__[Int32, 3](1), use_partitioned_multicast=partitioned_multicast]]` **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `build_cpasync_loaders` `static build_cpasync_loaders[k_align: Int, vector_size: Int = (k_align // size_of[a_type]()), num_threads_per_row: Int = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, 
hilbert_swizzle, k_group_size, swapAB].BK // vector_size), thread_layout: Layout = Layout.row_major((WARPGROUP_SIZE // num_threads_per_row), num_threads_per_row)](a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin]) -> Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]` **Returns:** [`Tuple`](/mojo/std/builtin/tuple/Tuple) ### `producer_main_loop_pipeline` `static producer_main_loop_pipeline[a_loader_type: TileLoader, b_loader_type: TileLoader, barrier_handler_type: BarrierHandler, //, num_k_iters: Int](m_coord: Scalar[DType.uint], n_coord: Scalar[DType.uint], k_coord: Scalar[DType.uint], a_loader: a_loader_type, b_loader: b_loader_type, barrier_handler: barrier_handler_type, mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, 
transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]](Coord[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, 
promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]]](Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, 
k_group_size, swapAB].BM // 8)]](Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]())), Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]](Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // 
size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]](Coord[ComptimeInt[(128 // 
size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]](Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, 
elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]](Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]], 
Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]](Coord[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[8], 
ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]]](Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]](Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, 
b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]())), Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]](Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, 
a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]](Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]](Idx[(128 // size_of[b_type]())](), Idx[(8 * 
(128 // size_of[b_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]](Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, 
block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))))), num_pipeline_stages])` ### `run` `static run[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], lut_ptr: LegacyUnsafePointer[UInt32])` Main kernel entry point for matrix multiplication. This kernel implements a producer-consumer pattern where: * One warp group (producer) loads tiles from global memory using TMA * Multiple warp groups (consumers) perform matrix multiplication using tensor cores The kernel uses software pipelining to overlap memory transfers with computation, achieving high throughput on Hopper GPUs. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix A. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix B. 
* ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix C. * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix A. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix B. * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C. * ​lut\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Lookup table for Hilbert curve block scheduling (optional). ### `run_splitk` `static run_splitk[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, splits: Int, raster_order: RasterOrder](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], workspace_buffer: NDBuffer[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, 3, MutAnyOrigin], locks_ptr: LegacyUnsafePointer[UInt8], problem_shape: IndexList[3])` Split-K variant of the kernel for better load balancing on small problems. 
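The split-K idea can be sketched in a few lines of NumPy (an illustrative model only — the actual kernel partitions K across thread blocks and reduces partial results through the `workspace_buffer` guarded by `locks_ptr`; the `matmul_splitk` helper below is hypothetical):

```python
import numpy as np

def matmul_splitk(a, b, splits):
    # Partition the K (reduction) dimension into `splits` chunks.
    # Each chunk computes a partial product, mimicking one split's
    # workspace contribution; the final sum models the reduction step.
    k = a.shape[1]
    bounds = np.linspace(0, k, splits + 1, dtype=int)
    partials = [a[:, s:e] @ b[s:e, :] for s, e in zip(bounds[:-1], bounds[1:])]
    return sum(partials)
```

Splitting K creates more independent units of work than the output tiling alone provides, which is why it improves load balancing when the output matrix yields too few tiles to occupy the GPU.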
### `run_grouped` `static run_grouped[a_tile_layout: Layout, b_tile_layout: Layout, c_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tile_layout, c_desc_layout], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])` Grouped matmul variant for MoE (Mixture of Experts) models. This variant handles multiple experts where each expert processes a subset of tokens. The a\_offsets array indicates token boundaries for each expert. ### `consumer_main_loop_pipeline` `static consumer_main_loop_pipeline[num_k_iters: Int](wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Scalar[DType.uint], final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, 
c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, 
num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, 
elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]](Coord[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 
128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]]](Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]](Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, 
wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]())), Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]](Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, 
block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]](Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[a_type]())], 
ComptimeInt[(8 * (128 // size_of[a_type]()))]](Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]](Idx[1](), Idx[0 if 
(((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if 
(((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]](Coord[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, 
num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]]](Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]](Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, 
hilbert_swizzle, k_group_size, swapAB].BN // 8)]())), Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]](Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())))), Coord[Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, 
elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]](Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // size_of[b_type]()))]](Idx[(128 // size_of[b_type]())](), Idx[(8 * (128 // size_of[b_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, 
a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]](Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, 
partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))))), num_pipeline_stages], warp_group_thread_idx: Scalar[DType.uint])` Pipeline-based consumer loop using ProducerConsumerPipeline. This is an alternative implementation of consumer\_main\_loop that uses the SM100 ProducerConsumerPipeline for synchronization instead of RingBuffer. **Args:** * ​wgmma\_op ([`TensorCoreAsync`](/mojo/kernels/layout/tensor_core_async/TensorCoreAsync)): Tensor core operator for matrix multiplication. * ​local\_warp\_group\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Index of this consumer warp group (0-based). * ​final\_c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Final accumulation register tile (for FP8 promotion). * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Working accumulation register tile. * ​pipeline ([`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline)): ProducerConsumerPipeline for synchronized tile access. * ​a\_tiles ([`SMemTileArrayWithLayout`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArrayWithLayout)): Tile array for A matrix in shared memory. 
* ​b\_tiles ([`SMemTileArrayWithLayout`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/tile_types/SMemTileArrayWithLayout)): Tile array for B matrix in shared memory. * ​warp\_group\_thread\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Thread index within the warp group. ### `promote_to_cuda_cores` `static promote_to_cuda_cores(c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, 
address_space=AddressSpace.LOCAL], final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])` Promote FP8 accumulation to higher precision using CUDA cores. When using FP8 data types, tensor cores accumulate in limited precision. To maintain accuracy over many accumulations, we periodically add the intermediate results to a higher-precision accumulator using CUDA cores. 
This technique is commonly used in production libraries like cuBLAS to achieve both high performance (from FP8 tensor cores) and good accuracy. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Current accumulation from tensor cores. * ​final\_c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Higher-precision accumulator (updated in place). ### `wgmma` `static wgmma(wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Scalar[DType.uint], a_tile: TileTensor[a_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[a_type]())], ComptimeInt[(8 * (128 // 
size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]]], MutAnyOrigin, address_space=AddressSpace.SHARED], b_tile: TileTensor[b_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]], Coord[ComptimeInt[(128 // size_of[b_type]())], ComptimeInt[(8 * (128 // 
size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1)._mlir_value else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]]], MutAnyOrigin, address_space=AddressSpace.SHARED], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, 
b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])`
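The periodic-promotion technique behind `promote_to_cuda_cores` (described above) can be illustrated outside the kernel. This is a hedged Python sketch, not the Mojo implementation: it simulates limited-precision tensor-core accumulation by truncating to bfloat16-like precision, and folds the running low-precision total into a higher-precision accumulator every `promotion_frequency` steps.

```python
import struct

def round_to_bf16(x: float) -> float:
    """Simulate limited-precision accumulation by truncating a float
    to bfloat16-like precision (sign + 8 exponent + 7 mantissa bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

def accumulate_with_promotion(partials, promotion_frequency=4):
    """Accumulate `partials` in low precision, periodically promoting
    the running low-precision total into a higher-precision
    accumulator, analogous to the FP8 accumulation promotion above."""
    high = 0.0  # higher-precision accumulator (CUDA-core adds)
    low = 0.0   # limited-precision accumulator (tensor-core adds)
    for i, p in enumerate(partials, start=1):
        low = round_to_bf16(low + p)
        if i % promotion_frequency == 0:
            high += low  # promote: fold into high precision
            low = 0.0    # reset the low-precision accumulator
    return high + low
```

With `partials = [1.0] * 1024`, the promoted sum stays exact, while a purely low-precision running sum stalls once the accumulator's magnitude exceeds what its mantissa can resolve.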
--- ## HopperMatmulSM90Kernel_SMem
`struct HopperMatmulSM90Kernel_SMem[a_type: DType, b_type: DType, c_type: DType, BM: Int, BN: Int, BK: Int, WG_BM: Int, WG_BN: Int, num_pipeline_stages: Int, k_group_size: Int, swizzle_bytes: Int = 128]` Shared memory layout for Hopper SM90 matrix multiplication kernel. This struct manages the shared memory allocation for: * Input tiles (A and B matrices) with multi-stage pipelining * Output tile (C matrix) for accumulation * Synchronization barriers for producer-consumer coordination The memory is organized to support asynchronous loads and efficient bank-conflict-free access patterns for tensor core operations. All tiles use TileTensor-based types from tile\_types.mojo. At TMA/WGMMA boundaries, pass {tile.ptr} to construct LayoutTensor. ## Fields * ​a\_tiles\_storage (`HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].ATileArray.Storage`): * ​b\_tiles\_storage (`HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].BTileArray.Storage`): * ​c\_tile\_storage (`HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].CTileArray.Storage`): * ​barriers (`BarrierPair[(num_pipeline_stages // k_group_size)]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `ATileArray` `comptime ATileArray = SMemTileArrayWithLayout[a_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(BM // 8)]], Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[((BK * size_of[a_type]()) // swizzle_bytes)]], Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[a_type]()) // 
swizzle_bytes) == 1)._mlir_value else (BM * (swizzle_bytes // size_of[a_type]()))]]](Coord[Coord[ComptimeInt[8], ComptimeInt[(BM // 8)]], Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[((BK * size_of[a_type]()) // swizzle_bytes)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[8], ComptimeInt[(BM // 8)]], Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[((BK * size_of[a_type]()) // swizzle_bytes)]]](Coord[ComptimeInt[8], ComptimeInt[(BM // 8)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[8], ComptimeInt[(BM // 8)]](Idx[8](), Idx[(BM // 8)]())), Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[((BK * size_of[a_type]()) // swizzle_bytes)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[((BK * size_of[a_type]()) // swizzle_bytes)]](Idx[(swizzle_bytes // size_of[a_type]())](), Idx[((BK * size_of[a_type]()) // swizzle_bytes)]())))), Coord[Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[a_type]()) // swizzle_bytes) == 1)._mlir_value else (BM * (swizzle_bytes // size_of[a_type]()))]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[a_type]()) // swizzle_bytes) == 1)._mlir_value else (BM * (swizzle_bytes // size_of[a_type]()))]]](Coord[ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[a_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[a_type]()))]](Idx[(swizzle_bytes // size_of[a_type]())](), Idx[(8 * (swizzle_bytes // size_of[a_type]()))]())), 
Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[a_type]()) // swizzle_bytes) == 1)._mlir_value else (BM * (swizzle_bytes // size_of[a_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[a_type]()) // swizzle_bytes) == 1)._mlir_value else (BM * (swizzle_bytes // size_of[a_type]()))]](Idx[1](), Idx[0 if (((BK * size_of[a_type]()) // swizzle_bytes) == 1)._mlir_value else (BM * (swizzle_bytes // size_of[a_type]()))]()))))), num_pipeline_stages]` ### `BTileArray` `comptime BTileArray = SMemTileArrayWithLayout[b_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(BN // 8)]], Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[((BK * size_of[b_type]()) // swizzle_bytes)]], Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[b_type]()) // swizzle_bytes) == 1)._mlir_value else (BN * (swizzle_bytes // size_of[b_type]()))]]](Coord[Coord[ComptimeInt[8], ComptimeInt[(BN // 8)]], Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[((BK * size_of[b_type]()) // swizzle_bytes)]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[8], ComptimeInt[(BN // 8)]], Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[((BK * size_of[b_type]()) // swizzle_bytes)]]](Coord[ComptimeInt[8], ComptimeInt[(BN // 8)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[8], ComptimeInt[(BN // 8)]](Idx[8](), Idx[(BN // 8)]())), Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[((BK * size_of[b_type]()) // swizzle_bytes)]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[((BK * size_of[b_type]()) // swizzle_bytes)]](Idx[(swizzle_bytes // size_of[b_type]())](), Idx[((BK * size_of[b_type]()) // swizzle_bytes)]())))), 
Coord[Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[b_type]()) // swizzle_bytes) == 1)._mlir_value else (BN * (swizzle_bytes // size_of[b_type]()))]]](VariadicPack[True, MutExternalOrigin, True, CoordLike, Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[b_type]()) // swizzle_bytes) == 1)._mlir_value else (BN * (swizzle_bytes // size_of[b_type]()))]]](Coord[ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[(swizzle_bytes // size_of[b_type]())], ComptimeInt[(8 * (swizzle_bytes // size_of[b_type]()))]](Idx[(swizzle_bytes // size_of[b_type]())](), Idx[(8 * (swizzle_bytes // size_of[b_type]()))]())), Coord[ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[b_type]()) // swizzle_bytes) == 1)._mlir_value else (BN * (swizzle_bytes // size_of[b_type]()))]](VariadicPack[True, MutExternalOrigin, True, CoordLike, ComptimeInt[1], ComptimeInt[0 if (((BK * size_of[b_type]()) // swizzle_bytes) == 1)._mlir_value else (BN * (swizzle_bytes // size_of[b_type]()))]](Idx[1](), Idx[0 if (((BK * size_of[b_type]()) // swizzle_bytes) == 1)._mlir_value else (BN * (swizzle_bytes // size_of[b_type]()))]()))))), num_pipeline_stages]` ### `CTile` `comptime CTile = HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].CTileArray.Tile` ### `CTileArray` `comptime CTileArray = SMemTileArray2DRowMajor[c_type, WG_BM, WG_BN, 1]` ## Methods ### `a_tiles` `a_tiles(ref[AddressSpace._value._mlir_value] self) -> HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].ATileArray` Get A tile array 
accessor (TileTensor-based). **Returns:** `HopperMatmulSM90Kernel_SMem` ### `b_tiles` `b_tiles(ref[AddressSpace._value._mlir_value] self) -> HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].BTileArray` Get B tile array accessor (TileTensor-based). **Returns:** `HopperMatmulSM90Kernel_SMem` ### `c_tile` `c_tile(ref[AddressSpace._value._mlir_value] self) -> HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, BM, BN, BK, WG_BM, WG_BN, num_pipeline_stages, k_group_size, swizzle_bytes].CTile` Get C tile accessor (TileTensor-based). **Returns:** `HopperMatmulSM90Kernel_SMem` ### `create_pipeline` `create_pipeline(ref[AddressSpace._value._mlir_value] self) -> ProducerConsumerPipeline[(num_pipeline_stages // k_group_size)]` Create producer-consumer pipeline from barrier storage. **Returns:** [`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/structured_kernels/pipeline/ProducerConsumerPipeline) ### `pipeline_storage_size` `static pipeline_storage_size() -> Int` Calculate the memory size for all pipeline stages. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `output_storage_size` `static output_storage_size() -> Int` Calculate the memory size for output tile. **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `storage_size` `static storage_size() -> Int` Calculate the total storage size. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## find_K_alignment_upto_16B
`find_K_alignment_upto_16B(row_bytes_arg: Int) -> Int` Find alignment among 1B, 2B, 4B, 8B, and 16B based on the row's bytes. This function determines the largest power-of-2 alignment (up to 16 bytes) that evenly divides the given row size. This is used to determine the optimal vector size for cp.async operations when the K dimension's alignment doesn't meet TMA requirements. **Args:** * ​row\_bytes\_arg ([`Int`](/mojo/std/builtin/int/Int)): Number of bytes in a row (K \* sizeof(element)). **Returns:** [`Int`](/mojo/std/builtin/int/Int): Alignment in bytes (1, 2, 4, 8, or 16).
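The selection logic reduces to picking the largest power of two (capped at 16) that divides the row's byte count. A hedged Python equivalent of the documented behavior (not the Mojo implementation itself):

```python
def find_k_alignment_upto_16b(row_bytes: int) -> int:
    """Largest power-of-2 alignment (1, 2, 4, 8, or 16 bytes) that
    evenly divides row_bytes, mirroring the documented behavior of
    find_K_alignment_upto_16B."""
    for alignment in (16, 8, 4, 2):
        if row_bytes % alignment == 0:
            return alignment
    return 1  # odd byte counts fall back to 1-byte alignment
```

For example, a row of 24 bytes yields 8, so cp.async would use 8-byte vector loads for that K extent.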
--- ## matmul_kernels
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`HopperMatmulSM90Kernel`](./HopperMatmulSM90Kernel): Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs. * [​`HopperMatmulSM90Kernel_SMem`](./HopperMatmulSM90Kernel_SMem): Shared memory layout for Hopper SM90 matrix multiplication kernel. ## Functions * [​`find_K_alignment_upto_16B`](./find_K_alignment_upto_16B): Find alignment among 1B, 2B, 4B, 8B, and 16B based on the row's bytes.
--- ## MatmulTileWriter
`@register_passable(trivial)` `struct MatmulTileWriter[dtype: DType, layout: Layout, address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, smem_tile_layout: Layout, //, *, BM: Int, BN: Int, swizzle: TensorMapSwizzle, wgmma_shape: IndexList[3], num_consumer: Int = 1, use_tma_store: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, swapAB: Bool = False]` ## Fields * ​tensor (`MatmulTileWriter[BM=BM, BN=BN, swizzle=swizzle, wgmma_shape=wgmma_shape, num_consumer=num_consumer, use_tma_store=use_tma_store, elementwise_lambda_fn=elementwise_lambda_fn, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, swapAB=swapAB].CTensorType`): * ​smem\_tile (`LayoutTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]`): * ​warp\_group\_thread\_idx (`UInt`): * ​local\_warp\_group\_idx (`UInt`): * ​local\_thread\_idx (`UInt`): * ​block\_y (`Int`): * ​block\_x (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CTensorType` `comptime CTensorType = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, 
alignment=alignment]` ### `frag_size` `comptime frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // WARPGROUP_SIZE)` ### `lambda_type` `comptime lambda_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], mut SIMD[dtype, width]) capturing -> None` ### `N` `comptime N = layout.shape[1].value()` ### `num_consumer_threads` `comptime num_consumer_threads = (num_consumer * WARPGROUP_SIZE)` ### `num_m_mmas` `comptime num_m_mmas = ((BM // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // num_consumer)` ### `num_n_mmas` `comptime num_n_mmas = (BN // wgmma_shape.__getitem__[3, DType.int64, Int](1))` ### `simd_size` `comptime simd_size = simd_width_of[dtype]()` ### `WG_BM` `comptime WG_BM = smem_tile_layout.shape[0].value()` ### `WG_BN` `comptime WG_BN = smem_tile_layout.shape[1].value()` ## Methods ### `__init__` `__init__(tensor: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], smem_tile: LayoutTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: Scalar[DType.uint], local_warp_group_idx: Scalar[DType.uint], local_thread_idx: Scalar[DType.uint], block_y: Int, block_x: Int) -> Self` ### `write_tile` `write_tile[tma_layout: Layout, desc_layout: Layout, accum_type: DType, reg_tile_layout: Layout, //](self, tma_op: TMATensorTile[dtype, tma_layout, desc_layout], reg_tile: LayoutTensor[accum_type, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])` Write output from registers to global memory. Selects optimized st.matrix path for bf16 when constraints are met, otherwise uses general register-to-global path.
--- ## matmul_output
## Structs * [​`MatmulTileWriter`](./MatmulTileWriter):
--- ## testbed
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`test_matmul_sm90`](./test_matmul_sm90):
--- ## test_matmul_sm90
`test_matmul_sm90[a_type: DType, b_type: DType, c_type: DType, cluster_shape: IndexList[3], block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], num_consumer: Int = 1, num_pipeline_stages: Int = 4, transpose_b: Bool = True, partitioned_multicast: Bool = False, grid_shape: OptionalReg[IndexList[2]] = None, use_tma_store: Bool = False, schedule: MatmulSchedule = MatmulSchedule.NONE, default_epilogue: Bool = False, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, measure_threshold: Optional[Float64] = None, backend: Backend = Backend.CUBLAS, k_group_size: Int = 1](ctx: DeviceContext, m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim])`
--- ## testbed_swapAB
Testbed for comparing swapAB vs normal matmul execution. swapAB is an internal optimization where: * A and B operands are swapped inside the kernel * The output C tile is transposed on write-out * The final result C\[M,N] should be identical to the normal kernel Both kernels compute: C\[M,N] = A\[M,K] @ B\[N,K]^T The swapAB version computes B @ A^T and stores the result transposed: (B @ A^T)^T = A @ B^T ## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`test_matmul_sm90_swapAB_comparison`](./test_matmul_sm90_swapAB_comparison): Compare matmul results between normal execution and swapAB execution. * [​`test_matmul_sm90_swapAB_comparison_v2`](./test_matmul_sm90_swapAB_comparison_v2): Compare matmul results between normal execution and swapAB execution.
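The swapAB identity, (B @ A^T)^T = A @ B^T, can be checked numerically. A minimal pure-Python sketch (the `matmul`/`transpose` helpers are hypothetical and not part of the testbed):

```python
def matmul(X, Y):
    """Naive row-major matrix multiply: X[m][k] @ Y[k][n]."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def transpose(X):
    return [list(row) for row in zip(*X)]

# Both kernels compute C[M, N] = A[M, K] @ B[N, K]^T.
A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # M=3, K=2
B = [[7.0, 8.0], [9.0, 10.0]]             # N=2, K=2

normal = matmul(A, transpose(B))             # A @ B^T
swapab = transpose(matmul(B, transpose(A)))  # (B @ A^T)^T
assert normal == swapab
```

The equality holds for any shapes where the K extents agree, which is why the testbed expects bit-identical semantics (up to accumulation order) between the two kernel paths.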
--- ## test_matmul_sm90_swapAB_comparison
`test_matmul_sm90_swapAB_comparison[a_type: DType, b_type: DType, c_type: DType, config: MatmulConfig[a_type, b_type, c_type], config_swapAB: MatmulConfig[a_type, b_type, c_type]](ctx: DeviceContext, m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim])` Compare matmul results between normal execution and swapAB execution. Both compute: C\[M,N] = A\[M,K] @ B\[N,K]^T swapAB internally swaps A/B and transposes C on store, but result should match. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The device context. * ​m ([`ValOrDim`](/mojo/kernels/internal_utils/_utils/ValOrDim)): The M dimension (can be static or dynamic). * ​n ([`ValOrDim`](/mojo/kernels/internal_utils/_utils/ValOrDim)): The N dimension (can be static or dynamic). * ​k ([`ValOrDim`](/mojo/kernels/internal_utils/_utils/ValOrDim)): The K dimension (can be static or dynamic).
--- ## test_matmul_sm90_swapAB_comparison_v2
`test_matmul_sm90_swapAB_comparison_v2[a_type: DType, b_type: DType, c_type: DType, BM: Int, BN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, num_pipeline_stages: Scalar[DType.uint], num_consumer: Scalar[DType.uint], k_group_size: Scalar[DType.uint] = 1, num_k_partitions: Scalar[DType.uint] = 1, partitioned_multicast: Bool = False, BM_SWAPAB: Int = BM, BN_SWAPAB: Int = BN, BK_SWAPAB: Int = BK, MMA_M_SWAPAB: Int = MMA_M, MMA_N_SWAPAB: Int = MMA_N, MMA_K_SWAPAB: Int = MMA_K, num_pipeline_stages_swapAB: Scalar[DType.uint] = num_pipeline_stages, num_consumer_swapAB: Scalar[DType.uint] = num_consumer, k_group_size_swapAB: Scalar[DType.uint] = k_group_size, num_k_partitions_swapAB: Scalar[DType.uint] = num_k_partitions, partitioned_multicast_swapAB: Bool = partitioned_multicast, use_vendor_reference: Bool = False, default_epilogue: Bool = False, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None](ctx: DeviceContext, m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim])` Compare matmul results between normal execution and swapAB execution. This version accepts config parameters directly as compile-time values and builds configs internally. Both compute: C\[M,N] = A\[M,K] @ B\[N,K]^T swapAB internally swaps A/B and transposes C on store, but result should match. **Parameters:** * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of matrix A. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of matrix B. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of output matrix C. * ​BM ([`Int`](/mojo/std/builtin/int/Int)): Block tile M dimension for normal kernel. * ​BN ([`Int`](/mojo/std/builtin/int/Int)): Block tile N dimension for normal kernel. * ​BK ([`Int`](/mojo/std/builtin/int/Int)): Block tile K dimension for normal kernel. * ​MMA\_M ([`Int`](/mojo/std/builtin/int/Int)): MMA M dimension for normal kernel. * ​MMA\_N ([`Int`](/mojo/std/builtin/int/Int)): MMA N dimension for normal kernel. 
* ​MMA\_K ([`Int`](/mojo/std/builtin/int/Int)): MMA K dimension for normal kernel. * ​num\_pipeline\_stages ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of pipeline stages for normal kernel. * ​num\_consumer ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of consumers for normal kernel. * ​k\_group\_size ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K group size for normal kernel. * ​num\_k\_partitions ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of K partitions for normal kernel. * ​partitioned\_multicast ([`Bool`](/mojo/std/builtin/bool/Bool)): Partitioned multicast for normal kernel. * ​BM\_SWAPAB ([`Int`](/mojo/std/builtin/int/Int)): Block tile M dimension for swapAB kernel. * ​BN\_SWAPAB ([`Int`](/mojo/std/builtin/int/Int)): Block tile N dimension for swapAB kernel. * ​BK\_SWAPAB ([`Int`](/mojo/std/builtin/int/Int)): Block tile K dimension for swapAB kernel. * ​MMA\_M\_SWAPAB ([`Int`](/mojo/std/builtin/int/Int)): MMA M dimension for swapAB kernel. * ​MMA\_N\_SWAPAB ([`Int`](/mojo/std/builtin/int/Int)): MMA N dimension for swapAB kernel. * ​MMA\_K\_SWAPAB ([`Int`](/mojo/std/builtin/int/Int)): MMA K dimension for swapAB kernel. * ​num\_pipeline\_stages\_swapAB ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of pipeline stages for swapAB kernel. * ​num\_consumer\_swapAB ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of consumers for swapAB kernel. * ​k\_group\_size\_swapAB ([`Scalar`](/mojo/std/builtin/simd/#scalar)): K group size for swapAB kernel. * ​num\_k\_partitions\_swapAB ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Number of K partitions for swapAB kernel. * ​partitioned\_multicast\_swapAB ([`Bool`](/mojo/std/builtin/bool/Bool)): Partitioned multicast for swapAB kernel. * ​use\_vendor\_reference ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, use vendor matmul (cuBLAS) as reference instead of normal kernel. 
* ​default\_epilogue ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, use default epilogue function that stores directly to output tensor. * ​elementwise\_compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional compute lambda function to apply to each element before storing. **Args:** * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): The device context. * ​m ([`ValOrDim`](/mojo/kernels/internal_utils/_utils/ValOrDim)): The M dimension (can be static or dynamic). * ​n ([`ValOrDim`](/mojo/kernels/internal_utils/_utils/ValOrDim)): The N dimension (can be static or dynamic). * ​k ([`ValOrDim`](/mojo/kernels/internal_utils/_utils/ValOrDim)): The K dimension (can be static or dynamic).
--- ## BarrierHandler
Handles barrier lifecycle for different transfer mechanisms. Separates barrier management from tile loading: * prepare\_stage: Called once before loading tiles for a stage. * complete\_stage: Called once after all tiles for a stage are loaded. TMA: prepare sets expected bytes, complete is noop (hardware signals). cp.async: prepare is noop, complete commits copies and signals arrival. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. 
The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `prepare_stage` `prepare_stage(self: _Self, mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` Prepare barrier for incoming transfers. For TMA: sets expected transaction bytes. For cp.async: noop. **Args:** * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): The stage's memory barrier. ### `complete_stage` `complete_stage(self: _Self, mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` Signal that all transfers for this stage are done. For TMA: noop (hardware auto-signals). For cp.async: commits pending copies and signals thread arrival. **Args:** * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): The stage's memory barrier. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
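The prepare/complete protocol can be sketched as a helper that is generic over any `BarrierHandler` (a sketch only; the function name and the commented tile loads are illustrative, not part of this API):

```mojo
# Sketch: where prepare_stage/complete_stage sit for one pipeline stage.
# The tile-load calls are shown as comments because their LayoutTensor
# arguments are elided here.
fn load_one_stage[handler_t: BarrierHandler](
    handler: handler_t,
    barrier: LegacyUnsafePointer[
        SharedMemBarrier, address_space = AddressSpace.SHARED
    ],
):
    handler.prepare_stage(barrier)   # TMA: set expected bytes; cp.async: no-op
    # a_loader.load_tile(a_smem, barrier, (row_tile, k_tile))
    # b_loader.load_tile(b_smem, barrier, (k_tile, col_tile))
    handler.complete_stage(barrier)  # TMA: no-op; cp.async: commit + arrive
```

Because the handler is a parameter, the same producer code works for both transfer mechanisms; only the concrete handler type changes.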
--- ## CPAsyncBarrierHandler
`@register_passable(trivial)` `struct CPAsyncBarrierHandler` The cp.async barrier handler: noop on prepare, arrives on complete. Initializes the pipeline on construction (phase=0, barrier counts). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`BarrierHandler`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_loader/BarrierHandler), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__[num_stages: Int](mut pipeline: ProducerConsumerPipeline[num_stages], num_consumers: Int, cluster_size: Int) -> Self` ### `prepare_stage` `prepare_stage(self, mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` ### `complete_stage` `complete_stage(self, mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])`
--- ## TMABarrierHandler
`@register_passable(trivial)` `struct TMABarrierHandler[expected_bytes: Int]` TMA barrier handler: sets expected bytes on prepare, noop on complete. Initializes the pipeline on construction (phase=0, barrier counts). ## Parameters * ​expected\_bytes ([`Int`](/mojo/std/builtin/int/Int)): Total bytes expected per stage across all loaders. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`BarrierHandler`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_loader/BarrierHandler), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__[num_stages: Int](mut pipeline: ProducerConsumerPipeline[num_stages], num_consumers: Int, cluster_size: Int) -> Self` ### `prepare_stage` `prepare_stage(self, mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` ### `complete_stage` `complete_stage(self, mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])`
--- ## TileLoader
Base trait for tile loading mechanisms in matrix multiplication. This trait defines the interface for loading tiles from global memory to shared memory, abstracting over different hardware mechanisms. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `load_tile` `load_tile(self: _Self, dst: LayoutTensor[_Self._dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], coords: Tuple[UInt, UInt])` Load a tile from global memory to shared memory. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tile in shared memory (must be 128-byte aligned). * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile coordinates (row, column) in the source matrix. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## TileLoaderCPAsync
`@register_passable(trivial)` `struct TileLoaderCPAsync[dtype: DType, src_layout: Layout, thread_layout: Layout, swizzle_mode: TensorMapSwizzle, vector_size: Int]` Software-based tile loader using cp.async instructions. This loader uses CUDA's cp.async instructions for asynchronous memory transfers with manual bounds checking and shared memory swizzling for optimal bank conflict avoidance. ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the elements being loaded. * ​src\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the source matrix in global memory. * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread arrangement for distributed copying. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/std/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling pattern for shared memory access. * ​vector\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of elements loaded per thread. ## Fields * ​src (`LayoutTensor[dtype, src_layout, MutAnyOrigin]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TileLoader`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_loader/TileLoader), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(src: LayoutTensor[dtype, src_layout, MutAnyOrigin]) -> Self` Initialize the cp.async tile loader. 
**Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor in global memory. ### `load_tile` `load_tile(self, dst: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], coords: Tuple[UInt, UInt])` Load a tile using cp.async instructions. Extracts a tile from the source tensor and performs an asynchronous copy to shared memory with bounds checking and swizzling. Note: Unlike TMA, this method expects tile indices and handles the conversion to element offsets internally via the tile() method. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tile in shared memory. * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Memory barrier for synchronization (currently unused). * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile indices (row\_tile, col\_tile) in the source matrix.
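For illustration, a cp.async loader might be set up and driven like this (a sketch: `a_layout`, `a_gmem`, `a_smem_tile`, `barrier_ptr`, and every parameter value are assumptions, not documented defaults):

```mojo
# Sketch: construct a TileLoaderCPAsync over operand A and load one tile.
comptime thread_layout = Layout.row_major(16, 8)  # assumed thread arrangement
var a_loader = TileLoaderCPAsync[
    DType.bfloat16,                 # dtype
    a_layout,                       # global-memory layout of A (assumed)
    thread_layout,
    TensorMapSwizzle.SWIZZLE_128B,  # shared-memory swizzle (assumed)
    8,                              # vector_size: elements per thread (assumed)
](a_gmem)
# coords are tile indices; the loader converts them to element offsets.
a_loader.load_tile(a_smem_tile, barrier_ptr, (row_tile, k_tile))
```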
--- ## TileLoaderTMA
`@register_passable(trivial)` `struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tile_layout: Layout, desc_layout: Layout, /, *, BK: Scalar[DType.uint], cluster_size: Int32, use_partitioned_multicast: Bool]` TMA-based tile loader for hardware-accelerated memory transfers. This loader uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from global to shared memory, with optional multicast support for multi-block clusters. ## Parameters * ​tma\_origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): Origin type for the TMA operation. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the elements being loaded. * ​tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the complete tile in shared memory. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout described by the TMA descriptor (may be smaller). * ​BK ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Block size in the K dimension (for coordinate conversion). * ​cluster\_size ([`Int32`](/mojo/std/builtin/simd/#int32)): Number of blocks in the cluster (1 for no clustering). * ​use\_partitioned\_multicast ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to use partitioned multicast loading. 
## Fields * ​tma\_op (`TileLoaderTMA[tma_origin, dtype, tile_layout, desc_layout, BK=BK, cluster_size=cluster_size, use_partitioned_multicast=use_partitioned_multicast].TMATensorTilePtr`): * ​rank (`UInt`): * ​multicast\_mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TileLoader`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_loader/TileLoader), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TMATensorTilePtr` `comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin]` ## Methods ### `__init__` `__init__(tma_op: Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin], rank: Scalar[DType.uint], multicast_mask: UInt16) -> Self` Initialize the TMA tile loader. **Args:** * ​tma\_op ([`Pointer`](/mojo/std/memory/pointer/Pointer)): Pointer to the TMA tensor descriptor. * ​rank ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Rank of this block within the cluster. * ​multicast\_mask ([`UInt16`](/mojo/std/builtin/simd/#uint16)): Bit mask for multicast targets. 
### `load_tile` `load_tile(self, dst: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], _coords: Tuple[UInt, UInt])` Load a tile using TMA hardware acceleration. Converts tile indices to element coordinates and initiates a TMA transfer. For clusters, uses multicast to share data across blocks. Note: Coordinates are converted from (row, col) tile indices to (k\_elements, row/col\_elements) for TMA's K-major ordering. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tile in shared memory. * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Memory barrier for synchronization. * ​\_coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile coordinates (row\_tile\_idx, col\_tile\_idx).
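A minimal setup might look like the following (a sketch: `a_tma_op`, the layouts, `a_smem_tile`, `barrier_ptr`, and the parameter values are assumptions; with `cluster_size=1` the multicast mask is effectively unused):

```mojo
# Sketch: wrap a TMA descriptor in a TileLoaderTMA and load one tile.
var tma_loader = TileLoaderTMA[
    __origin_of(a_tma_op), DType.bfloat16, tile_layout, desc_layout,
    BK=64,                           # K block size, for coordinate conversion
    cluster_size=1,                  # single block: no multicast
    use_partitioned_multicast=False,
](Pointer(to=a_tma_op), rank=0, multicast_mask=0)
# Tile indices are converted internally to TMA's K-major element coordinates.
tma_loader.load_tile(a_smem_tile, barrier_ptr, (row_tile, k_tile))
```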
--- ## async_copy_with_bound_check
`async_copy_with_bound_check[dtype: DType, src_layout: Layout, dst_layout: Layout, //, thread_layout: Layout, swizzle_mode: TensorMapSwizzle](src: LayoutTensor[dtype, src_layout, MutAnyOrigin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Helper function for cp.async with boundary checking. This function performs element-wise async copies with per-element boundary checking. Out-of-bounds accesses are automatically zero-filled, ensuring safe operation near matrix edges. It also handles shared memory swizzling to avoid bank conflicts and maximize memory bandwidth utilization. **Parameters:** * ​dtype: Data type of the elements. * ​src\_layout: Layout of the source tile. * ​dst\_layout: Layout of the destination tile. * ​thread\_layout: Thread arrangement for distributed copying. * ​swizzle\_mode: Swizzling pattern for bank conflict avoidance. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor fragment in global memory. * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor fragment in shared memory.
--- ## tile_loader
TileLoader module for efficient tile loading in GPU matrix multiplication. This module provides utilities for loading matrix tiles from global memory to shared memory using two different mechanisms: 1. TMA (Tensor Memory Accelerator): Hardware-accelerated loads that can efficiently transfer 2D tiles with multicast support for multi-block clusters. 2. cp.async: Software-based asynchronous copy instructions with manual bounds checking and swizzling for optimal shared memory access patterns. The TileLoader trait abstracts these loading mechanisms to provide a unified interface for the matmul kernel's producer threads. ## Structs * [​`CPAsyncBarrierHandler`](./CPAsyncBarrierHandler): The cp.async barrier handler: noop on prepare, arrives on complete. * [​`TileLoaderCPAsync`](./TileLoaderCPAsync): Software-based tile loader using cp.async instructions. * [​`TileLoaderTMA`](./TileLoaderTMA): TMA-based tile loader for hardware-accelerated memory transfers. * [​`TMABarrierHandler`](./TMABarrierHandler): TMA barrier handler: sets expected bytes on prepare, noop on complete. ## Traits * [​`BarrierHandler`](./BarrierHandler): Handles barrier lifecycle for different transfer mechanisms. * [​`TileLoader`](./TileLoader): Base trait for tile loading mechanisms in matrix multiplication. ## Functions * [​`async_copy_with_bound_check`](./async_copy_with_bound_check): Helper function for cp.async with boundary checking.
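Putting the pieces together, the producer loop this module enables looks roughly like this (a sketch under assumed names: `pipeline`, `handler`, the loaders, the staged tile arrays, and the stage helpers are all illustrative):

```mojo
# Sketch: one BarrierHandler plus two TileLoaders drive each pipeline stage.
for k_tile in range(num_k_tiles):
    var stage = pipeline.next_producer_stage()  # hypothetical stage index
    var barrier = stage_barriers + stage        # this stage's SharedMemBarrier
    handler.prepare_stage(barrier)              # expect bytes (TMA) or no-op
    a_loader.load_tile(a_tiles[stage], barrier, (block_m, UInt(k_tile)))
    b_loader.load_tile(b_tiles[stage], barrier, (UInt(k_tile), block_n))
    handler.complete_stage(barrier)             # commit + arrive, or no-op
```

Swapping TMA for cp.async then means substituting the loader and handler types while the loop body stays unchanged.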
--- ## FragmentToSMemWriter
`@register_passable(trivial)` `struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, half_tile: Bool, WG_BM: Int, WG_BN: Int, sub_wg_id: Int, swapAB: Bool = False]` Writes WGMMA accumulator results from registers to shared memory using st.matrix. Stores 16-byte fragments with swizzling to avoid bank conflicts. Sub-warp groups divide N-dimension work, each handling a portion of WG\_BN output tiles. ## Parameters * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Output data type (must be bfloat16 for st.matrix). * ​c\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the entire shared memory region. * ​tile\_n\_size ([`Int`](/mojo/std/builtin/int/Int)): Width of each output tile (typically TMA\_BN). * ​num\_m\_mmas ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA operations in M dimension. * ​num\_consumer ([`Int`](/mojo/std/builtin/int/Int)): Number of consumer warp groups. * ​half\_tile ([`Bool`](/mojo/std/builtin/bool/Bool)): Special mode for handling partial tiles. * ​WG\_BM ([`Int`](/mojo/std/builtin/int/Int)): Warp group tile height. * ​WG\_BN ([`Int`](/mojo/std/builtin/int/Int)): Warp group tile width. * ​sub\_wg\_id ([`Int`](/mojo/std/builtin/int/Int)): Which portion of WG\_BN this instance handles. * ​swapAB ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to swap the A and B matrices. 
## Fields * ​c\_tile (`LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]`): * ​warp\_group\_thread\_idx (`UInt`): * ​local\_warp\_group\_idx (`UInt`): * ​st\_matrix\_rt\_layout (`FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_rt_layout_type`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/RegTileWriter), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `st_matrix_layout` `comptime st_matrix_layout = Layout.row_major(WG_BM, tile_n_size) if swapAB.__invert__()._mlir_value else Layout.row_major(tile_n_size, WG_BN)` ### `st_matrix_layout_regular` `comptime st_matrix_layout_regular = st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer]()` ### `st_matrix_layout_transpose` `comptime st_matrix_layout_transpose = st_matrix_m_layout[c_type, tile_n_size, num_m_mmas, num_consumer]()` ### `st_matrix_rt_layout_type` `comptime st_matrix_rt_layout_type = RuntimeLayout[FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_layout_regular if swapAB.__invert__()._mlir_value else FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_layout_transpose, element_type=DType.int32, 
linear_idx_type=DType.int32]` ### `st_matrix_swizzle` `comptime st_matrix_swizzle = make_ldmatrix_swizzle[c_type, tile_n_size if swapAB.__invert__()._mlir_value else WG_BN, log2_floor((16 // size_of[c_type]()))]()` ## Methods ### `__init__` `__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: Scalar[DType.uint], local_warp_group_idx: Scalar[DType.uint]) -> Self` Initialize the fragment writer. **Args:** * ​c\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile to write to. * ​warp\_group\_thread\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Thread index within the warp group. * ​local\_warp\_group\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Sub-warp group index (divides N-dimension work). ### `write_tile` `write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Write accumulator tile from registers to shared memory. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Register tile containing MMA results. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile position (row\_idx, col\_idx) in output.
--- ## RegTileWriter
Base trait for tile writing mechanisms in matrix multiplication. This trait defines the interface for writing register tiles to memory (either shared memory or global memory). ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `write_tile` `write_tile(self: _Self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Write a register tile to memory. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source register tile containing accumulator values. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile coordinates (row, column) in the destination matrix. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
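Because `FragmentToSMemWriter` and `RegisterToGMemWriter` both implement this trait, epilogue code can stay generic over the destination (a sketch; the register-tile argument is elided and the function name is illustrative):

```mojo
# Sketch: drain WGMMA accumulators through any RegTileWriter.
fn store_results[writer_t: RegTileWriter](
    writer: writer_t, m_tile: UInt, n_tile: UInt
):
    # c_reg_tile: the LOCAL-address-space accumulator tile (assumed in scope)
    writer.write_tile(c_reg_tile, (m_tile, n_tile))
```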
--- ## RegisterToGMemWriter
`@register_passable(trivial)` `struct RegisterToGMemWriter[c_type: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, wgmma_shape: IndexList[3], num_consumer: Int, N: Int, epilogue_fn: Optional[elementwise_epilogue_type] = None, compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, check_runtime_bounds: Bool = False, swapAB: Bool = False]` Writer for transferring accumulator registers directly to global memory. This writer handles the direct copy from register tiles to global memory tiles, with proper thread distribution and alignment. It supports optional epilogue processing, compute lambda transformations, and bounds checking. ## Parameters * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Output data type. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the destination tensor. * ​dst\_address\_space ([`AddressSpace`](/mojo/std/memory/pointer/AddressSpace)): Address space of the destination tensor. * ​dst\_element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Element layout of the destination tensor. * ​dst\_layout\_int\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Integer type for destination layout indices. * ​dst\_linear\_idx\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Linear index type for destination tensor. * ​dst\_masked ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the destination tensor is masked. * ​dst\_alignment ([`Int`](/mojo/std/builtin/int/Int)): Alignment requirement for destination tensor. * ​wgmma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): Shape of the WGMMA operation \[M, N, K]. * ​num\_consumer ([`Int`](/mojo/std/builtin/int/Int)): Number of consumer warp groups. * ​N ([`Int`](/mojo/std/builtin/int/Int)): Matrix N dimension. 
* ​epilogue\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional epilogue function (mutates value in place). * ​compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional compute lambda function (returns new value). * ​check\_runtime\_bounds ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to perform bounds checking on N dimension. * ​swapAB ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to swap the A and B matrices. ## Fields * ​thread\_info (`ThreadInfo`): * ​dst (`RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds, swapAB].DstType`): * ​num\_m\_mmas (`Int`): * ​tile\_coords (`OptionalReg[TileCoordinates]`): * ​max\_row (`OptionalReg[UInt32]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/RegTileWriter), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `c_frag_size` `comptime c_frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // WARPGROUP_SIZE)` ### `DstType` `comptime DstType = LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]` ### `num_frag_mats` `comptime num_frag_mats = 
(RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds, swapAB].num_n_frag_mat * RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds, swapAB].num_m_frag_mat)` ### `num_m_frag_mat` `comptime num_m_frag_mat = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) // 4) // 8)` ### `num_n_frag_mat` `comptime num_n_frag_mat = (wgmma_shape.__getitem__[3, DType.int64, Int](1) // 8)` ## Methods ### `__init__` `__init__(dst: LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: Scalar[DType.uint], num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, max_row: OptionalReg[UInt32] = None) -> Self` Initialize the register-to-global-memory writer. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in global memory. * ​warp\_group\_thread\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Thread index within the warp group. * ​num\_m\_mmas ([`Int`](/mojo/std/builtin/int/Int)): Number of MMA tiles in M dimension. * ​tile\_coords ([`OptionalReg`](/mojo/std/collections/optional/OptionalReg)): Optional tile coordinates for epilogue processing. * ​max\_row ([`OptionalReg`](/mojo/std/collections/optional/OptionalReg)): Optional maximum valid M coordinate (for epilogue). ### `write_tile` `write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Write a single MMA tile from registers to global memory. 
**Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Register tile containing accumulator values. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile coordinates (row, column) in the destination matrix.
--- ## SMemTileWriter
Base trait for tile writing mechanisms in matrix multiplication. This trait defines the interface for writing tiles from shared memory to global memory, abstracting over different hardware mechanisms. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `write_tile` `write_tile(self: _Self, src: LayoutTensor[_Self._dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Write a tile from shared memory to global memory. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tile in shared memory (must be 128-byte aligned). * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile coordinates (row, column) in the destination matrix. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## ThreadInfo
`@register_passable(trivial)` `struct ThreadInfo` Thread identification within the warp group. ## Fields * ​warp\_id (`UInt`): * ​lane\_id (`UInt`): * ​lane\_row (`UInt32`): * ​lane\_col (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(warp_id: Scalar[DType.uint], lane_id: Scalar[DType.uint], lane_row: UInt32, lane_col: UInt32) -> Self` ### `from_warp_group_idx` `static from_warp_group_idx(warp_group_thread_idx: Scalar[DType.uint]) -> Self` Create ThreadInfo from a warp group thread index. **Args:** * ​warp\_group\_thread\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Thread index within the warp group. **Returns:** `Self`: ThreadInfo struct with computed warp\_id, lane\_id, lane\_row, and lane\_col.
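`from_warp_group_idx` computes all four fields from a single warp-group thread index. One plausible decomposition, sketched in Python rather than Mojo and assuming 32-thread warps and the 4-thread "quad" row layout typical of WGMMA accumulator fragments (assumptions, not the library's definition):

```python
# Hypothetical model of ThreadInfo.from_warp_group_idx.
def thread_info_from_warp_group_idx(warp_group_thread_idx):
    warp_id = warp_group_thread_idx // 32  # warp within the 128-thread warp group
    lane_id = warp_group_thread_idx % 32   # lane within the warp
    lane_row = lane_id // 4                # row of the 8x4 lane grid
    lane_col = lane_id % 4                 # column within the 4-lane quad
    return warp_id, lane_id, lane_row, lane_col

print(thread_info_from_warp_group_idx(37))  # -> (1, 5, 1, 1)
```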
--- ## TileCoordinates
`@register_passable(trivial)` `struct TileCoordinates` Helper struct for managing tile coordinate offsets. This struct encapsulates corner and split coordinates used in epilogue processing and provides a clean interface for coordinate transformations. ## Fields * ​corner (`IndexList[2]`): * ​split (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(corner: IndexList[2], split: IndexList[2]) -> Self` Initialize tile coordinates. **Args:** * ​corner ([`IndexList`](/mojo/std/utils/index_/IndexList)): Corner coordinates offset. * ​split ([`IndexList`](/mojo/std/utils/index_/IndexList)): Split coordinates offset. ### `adjust` `adjust(self, base_coords: IndexList[2]) -> IndexList[2]` Add corner and split offsets to base coordinates. **Args:** * ​base\_coords ([`IndexList`](/mojo/std/utils/index_/IndexList)): Base tile coordinates. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): Adjusted coordinates with corner and split offsets applied.
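Per the description of `adjust`, the struct adds its corner and split offsets to a base coordinate pair. A minimal Python sketch of that transformation (the real struct operates on `IndexList[2]` values; element-wise addition is an assumption consistent with the docstring):

```python
# Minimal model of TileCoordinates.adjust: element-wise addition of
# corner and split offsets to the base coordinates.
def adjust(corner, split, base_coords):
    return tuple(b + c + s for b, c, s in zip(base_coords, corner, split))

print(adjust(corner=(128, 0), split=(0, 64), base_coords=(2, 3)))  # -> (130, 67)
```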
--- ## TileWriterTMA
`@register_passable(trivial)` `struct TileWriterTMA[tma_origin: ImmutOrigin, dtype: DType, tma_layout: Layout, desc_layout: Layout, //]` TMA-based tile writer for hardware-accelerated memory transfers. This writer uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from shared to global memory. ## Parameters * ​tma\_origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): Origin type for the TMA operation. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the elements being written. * ​tma\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the TMA tile for async store operations. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout described by the TMA descriptor. ## Fields * ​tma\_op (`TileWriterTMA.TMATensorTilePtr`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`SMemTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/SMemTileWriter), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TMATensorTilePtr` `comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tma_layout, desc_layout], tma_origin]` ## Methods ### `__init__` `__init__(tma_op: Pointer[TMATensorTile[dtype, tma_layout, desc_layout], tma_origin]) -> Self` Initialize the TMA tile writer. **Args:** * ​tma\_op ([`Pointer`](/mojo/std/memory/pointer/Pointer)): Pointer to the TMA tensor descriptor. 
### `write_tile` `write_tile(self, src: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Write a tile using TMA hardware acceleration. Performs an asynchronous TMA store from shared memory to global memory. The operation includes proper fencing and synchronization. Note: Coordinates are expected in (N, M) order for column-major output. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tile in shared memory. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile coordinates (col, row) in element space.
--- ## TileWriterThreadwise
`@register_passable(trivial)` `struct TileWriterThreadwise[dtype: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, thread_layout: Layout, simd_size: Int, half_tile: Bool = False, swapAB: Bool = False]` ## Fields * ​dst (`TileWriterThreadwise[thread_layout, simd_size, half_tile, swapAB].DstType`): * ​thread\_idx (`UInt`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`SMemTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/SMemTileWriter), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DstType` `comptime DstType = LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]` ## Methods ### `__init__` `__init__(dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], thread_idx: Scalar[DType.uint]) -> Self` Initialize the threadwise tile writer. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in global memory. 
* ​thread\_idx ([`Scalar`](/mojo/std/builtin/simd/#scalar)): Thread index within the consumer warp group. ### `write_tile` `write_tile(self, src: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Write a tile using thread-distributed stores. Each thread writes a portion of the tile with proper swizzling for optimal memory access patterns. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tile in shared memory. * ​coords ([`Tuple`](/mojo/std/builtin/tuple/Tuple)): Tile indices (row\_tile, col\_tile) in the destination matrix.
--- ## tile_writer
TileWriter module for efficient tile writing in GPU matrix multiplication. This module provides utilities for writing tiles to memory using different mechanisms and destinations: 1. Register → Shared Memory: Uses the st.matrix hardware instruction for efficient storage of WGMMA accumulator results to shared memory with swizzling. 2. Register → Global Memory: Direct stores from register tiles to global memory with optional epilogue processing and bounds checking. 3. Shared Memory → Global Memory: Hardware-accelerated TMA stores or regular stores for efficient 2D tile transfers from shared to global memory. Two main traits abstract these writing mechanisms: * SMemTileWriter: For shared memory → global memory transfers * RegTileWriter: For register → memory (shared or global) transfers ## Structs * [​`FragmentToSMemWriter`](./FragmentToSMemWriter): Writes WGMMA accumulator results from registers to shared memory using st.matrix. * [​`RegisterToGMemWriter`](./RegisterToGMemWriter): Writer for transferring accumulator registers directly to global memory. * [​`ThreadInfo`](./ThreadInfo): Thread identification within the warp group. * [​`TileCoordinates`](./TileCoordinates): Helper struct for managing tile coordinate offsets. * [​`TileWriterThreadwise`](./TileWriterThreadwise): Tile writer using thread-distributed stores from shared to global memory. * [​`TileWriterTMA`](./TileWriterTMA): TMA-based tile writer for hardware-accelerated memory transfers. ## Traits * [​`RegTileWriter`](./RegTileWriter): Base trait for register-to-memory tile writing mechanisms. * [​`SMemTileWriter`](./SMemTileWriter): Base trait for shared-to-global-memory tile writing mechanisms.
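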
--- ## TuningConfigSM90
`@register_passable(trivial)` `struct TuningConfigSM90` ## Fields * ​M (`Int`): * ​N (`Int`): * ​K (`Int`): * ​mma\_shape (`IndexList[3]`): * ​block\_tile\_shape (`IndexList[3]`): * ​num\_pipeline\_stages (`UInt`): * ​cluster\_shape (`IndexList[3]`): * ​num\_consumer (`UInt`): * ​partitioned\_multicast (`Bool`): * ​grid\_shape (`OptionalReg[IndexList[2]]`): * ​schedule (`MatmulSchedule`): * ​splits (`OptionalReg[Int]`): * ​raster\_order (`OptionalReg[RasterOrder]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`TuningConfig`](/mojo/kernels/internal_utils/dispatch_utils/TuningConfig) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(M: Int, N: Int, K: Int, mma_shape: IndexList[3], block_tile_shape: IndexList[3], num_pipeline_stages: Scalar[DType.uint], cluster_shape: IndexList[3], num_consumer: Scalar[DType.uint], partitioned_multicast: Bool, grid_shape: OptionalReg[IndexList[2]] = None, schedule: MatmulSchedule = MatmulSchedule.NONE, splits: OptionalReg[Int] = None, raster_order: OptionalReg[RasterOrder] = None) -> Self` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String)
--- ## tuning_configs
## Structs * [​`TuningConfigSM90`](./TuningConfigSM90):
--- ## split_k_reduce
`split_k_reduce[c_type: DType, work_space_type: DType, c_layout: Layout, work_space_layout: Layout, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None](c: LayoutTensor[c_type, c_layout, origin], work_space: LayoutTensor[work_space_type, work_space_layout, origin], ctx: DeviceContext)`
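`split_k_reduce` combines the partial products that a split-K matmul accumulates in `work_space` into the final `c` tensor. A conceptual Python model of that reduction, assuming the workspace holds one partial matrix per K-split and that the optional elementwise lambda is applied after summation (a sketch, not the Mojo kernel):

```python
# Conceptual model of a split-K reduction: each K-split contributes a
# partial product; the output tile is their element-wise sum, optionally
# passed through an elementwise epilogue.
def split_k_reduce(work_space, epilogue=None):
    rows, cols = len(work_space[0]), len(work_space[0][0])
    c = [[sum(part[i][j] for part in work_space) for j in range(cols)]
         for i in range(rows)]
    if epilogue is not None:
        c = [[epilogue(x) for x in row] for row in c]
    return c

partials = [[[1.0, 2.0]], [[3.0, 4.0]]]  # two K-splits of a 1x2 tile
print(split_k_reduce(partials))          # -> [[4.0, 6.0]]
```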
--- ## MatmulSchedule
`@register_passable(trivial)` `struct MatmulSchedule` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DS_SCHEDULER` `comptime DS_SCHEDULER = MatmulSchedule(3)` ### `NONE` `comptime NONE = MatmulSchedule(0)` ### `TILE1D` `comptime TILE1D = MatmulSchedule(1)` ### `TILE2D` `comptime TILE2D = MatmulSchedule(2)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## RasterOrder (Tile_scheduler)
`@register_passable(trivial)` `struct RasterOrder` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`Hashable`](/mojo/std/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AlongM` `comptime AlongM = RasterOrder(1)` ### `AlongN` `comptime AlongN = RasterOrder(0)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## TileScheduler (4)
`@register_passable(trivial)` `struct TileScheduler[problem_shape: IndexList[3], tile_shape: IndexList[3], grid_shape: IndexList[2], cluster: IndexList[3] = Index(1, 1, 1), raster_dim: UInt32 = 1, schedule: MatmulSchedule = MatmulSchedule.TILE2D]` ## Fields * ​idx (`UInt32`): * ​prob\_shape (`IndexList[3]`): * ​num\_waves\_m (`UInt32`): * ​num\_waves\_n (`UInt32`): * ​log\_num\_waves\_n (`FastDiv[DType.uint32]`): * ​current\_iter (`Int`): * ​num\_aligned\_m\_blocks (`UInt32`): * ​num\_blocks (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `kNum1DBlocksPerGroup` `comptime kNum1DBlocksPerGroup = 16` ### `kNumNBlocks` `comptime kNumNBlocks = SIMD[DType.uint32, 1](ceildiv(problem_shape.__getitem__[3, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](1)))` ### `num_grids` `comptime num_grids = SIMD[DType.uint32, 1]((grid_shape.__getitem__[2, DType.int64, Int](0) * grid_shape.__getitem__[2, DType.int64, Int](1)))` ### `wave_shape` `comptime wave_shape = Index[dtype=DType.uint32]((tile_shape.__getitem__[3, DType.int64, Int](0) * grid_shape.__getitem__[2, DType.int64, Int](1)), (tile_shape.__getitem__[3, DType.int64, Int](1) * grid_shape.__getitem__[2, DType.int64, Int](0)))` ## Methods ### `__init__` `__init__(prob_shape: IndexList[3]) -> Self` ### `get_current_work_info` `get_current_work_info(mut self) -> WorkInfo` 
**Returns:** `WorkInfo` ### `advance` `advance(mut self)` ### `fetch_next_work` `fetch_next_work(mut self) -> WorkInfo` **Returns:** `WorkInfo` ### `num_output_tiles` `num_output_tiles(self) -> UInt` **Returns:** `UInt` ### `fetch_next_work_ds` `fetch_next_work_ds(mut self) -> WorkInfo` **Returns:** `WorkInfo`
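The scheduler's comptime geometry follows directly from the parameters: `kNumNBlocks` is `ceildiv(N, tile_N)`, `num_grids` is the product of the grid dimensions, and `wave_shape` covers `tile_M * grid_shape[1]` rows by `tile_N * grid_shape[0]` columns. A Python sketch of those expressions (illustrative values; shapes here are hypothetical):

```python
# Model of the TileScheduler comptime members kNumNBlocks, wave_shape,
# and num_grids, using ceiling division for the block count.
def ceildiv(a, b):
    return -(-a // b)

def scheduler_geometry(problem_shape, tile_shape, grid_shape):
    _m, n, _k = problem_shape
    num_n_blocks = ceildiv(n, tile_shape[1])
    wave_shape = (tile_shape[0] * grid_shape[1], tile_shape[1] * grid_shape[0])
    num_grids = grid_shape[0] * grid_shape[1]
    return num_n_blocks, wave_shape, num_grids

print(scheduler_geometry((4096, 4096, 4096), (128, 256, 64), (8, 16)))
# -> (16, (2048, 2048), 128)
```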
--- ## WorkInfo (4)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​num\_k\_tiles (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `INVALID_WORK_INFO` `comptime INVALID_WORK_INFO = WorkInfo(0, 0, 0, 0, False)` ## Methods ### `__init__` `__init__() -> Self` ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `is_final_split` `is_final_split(self, k_tiles_per_output_tile: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `get_k_start` `get_k_start(self) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## tile_scheduler (3)
## Structs * [​`MatmulSchedule`](./MatmulSchedule): * [​`RasterOrder`](./RasterOrder): * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo):
--- ## ReductionMode
`@register_passable(trivial)` `struct ReductionMode` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Deterministic` `comptime Deterministic = ReductionMode(0)` ### `Nondeterministic` `comptime Nondeterministic = ReductionMode(1)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## SplitKTileScheduler
`@register_passable(trivial)` `struct SplitKTileScheduler[problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode.Deterministic]` ## Fields * ​prob\_shape (`IndexList[3]`): * ​block\_id\_in\_cluster (`IndexList[2]`): * ​blocks\_per\_problem (`UInt32`): * ​current\_work\_linear\_idx (`UInt32`): * ​log\_cluster\_shape\_major (`UInt32`): * ​log\_cluster\_shape\_minor (`UInt32`): * ​cluster\_blk\_major (`UInt32`): * ​locks\_ptr (`LegacyUnsafePointer[Int32]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `k_tiles_per_output_tile` `comptime k_tiles_per_output_tile = SIMD[DType.uint32, 1](ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2)))` ### `k_tiles_per_split` `comptime k_tiles_per_split = (SIMD[DType.uint32, 1](ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))) // splits)` ### `log_cluster_size` `comptime log_cluster_size = log2_floor((cluster_shape.__getitem__[2, DType.int64, Int](0) * cluster_shape.__getitem__[2, DType.int64, Int](1)))` ### `WorkTileType` `comptime WorkTileType[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin]` #### 
Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ## Methods ### `__init__` `__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self` ### `get_sm_num` `get_sm_num(self) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `get_problem_blocks_shape` `static get_problem_blocks_shape(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList) ### `initial_work_tile_info` `initial_work_tile_info(mut self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `get_current_work_info` `get_current_work_info(mut self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `get_worktile_m_n_idx` `get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)` ### `assign_work` `assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)` ### `get_k_start_and_linear_tile_id` `get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `fetch_next_work` `fetch_next_work(mut self, mut work_tile_info: WorkInfo) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `requires_reduction` `requires_reduction(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `advance_to_next_work` `advance_to_next_work(mut self)` ### `is_last_split` `is_last_split(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `get_grid_shape` `static get_grid_shape(dyn_cluster_shape: IndexList[3], dyn_raster_order: RasterOrder = RasterOrder.AlongN) -> IndexList[3]` **Returns:** 
[`IndexList`](/mojo/std/utils/index_/IndexList) ### `get_num_tiles` `static get_num_tiles(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `get_required_locks_buffer_size_bytes` `static get_required_locks_buffer_size_bytes[accum_type: DType, dyn_num_consumer: UInt32](problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `get_linear_idx_from_m_and_n` `get_linear_idx_from_m_and_n(self, tile_m: UInt32, tile_n: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `output_tile_index` `output_tile_index(self, work_tile_info: WorkInfo) -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `reduction` `reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)` ### `wait_eq` `static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)` ### `wait_lt` `static wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)` ### `arrive_set` `static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)` ### `store_accumulator` `store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: 
UInt32, warp_group_thread_idx: UInt32)` ### `reduce_add` `reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)`
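The split-K comptime members encode how the K dimension is partitioned: `k_tiles_per_output_tile` is `ceildiv(K, tile_K)`, and `k_tiles_per_split` divides that evenly among the `splits` partial products. A small Python model of those two expressions (example shapes are hypothetical):

```python
# Model of SplitKTileScheduler's k_tiles_per_output_tile and
# k_tiles_per_split comptime expressions.
def ceildiv(a, b):
    return -(-a // b)

def split_k_tiles(problem_shape_nk, tile_shape, splits):
    _n, k = problem_shape_nk
    k_tiles_per_output_tile = ceildiv(k, tile_shape[2])
    k_tiles_per_split = k_tiles_per_output_tile // splits
    return k_tiles_per_output_tile, k_tiles_per_split

print(split_k_tiles((4096, 4096), (128, 256, 64), splits=4))  # -> (64, 16)
```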
--- ## tile_scheduler_splitk
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`ReductionMode`](./ReductionMode): * [​`SplitKTileScheduler`](./SplitKTileScheduler):
--- ## matmul (6)
Provides the backend implementation for matmuls. ## Packages * [​`cpu`](./cpu/): Provides the CPU backend implementations for matmuls. * [​`gpu`](./gpu/): * [​`vendor`](./vendor/): Provides the Vendor backend implementations for matmuls. ## Functions * [​`matmul`](./matmul):
--- ## matmul (7)
`matmul[transpose_a: Bool = False, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, _trace_description: StringSlice[StaticConstantOrigin] = "", target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: Optional[DeviceContext])` `matmul[transpose_a: Bool = False, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, _trace_description: StringSlice[StaticConstantOrigin] = "", target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], ctx: DeviceContextPtr = DeviceContextPtr())` `matmul[transpose_a: Bool = False, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, _trace_description: StringSlice[StaticConstantOrigin] = "", target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[dtype, 2, 
origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], ctx: Optional[DeviceContext])`
--- ## Backend
`@register_passable(trivial)` `struct Backend` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`Equatable`](/mojo/std/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AUTOMATIC` `comptime AUTOMATIC = Backend(0)` ### `CUBLAS` `comptime CUBLAS = Backend(1)` ### `CUBLASLT` `comptime CUBLASLT = Backend(2)` ### `HIPBLASLT` `comptime HIPBLASLT = Backend(4)` ### `ROCBLAS` `comptime ROCBLAS = Backend(3)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__is__` `__is__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__isnot__` `__isnot__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__int__` `__int__(self) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## Handle
`struct Handle[backend: Backend = _resolve_backend[Backend.AUTOMATIC]()]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = _all_trivial_copyinit[cublasHandle_t, Handle, hipblasLtHandle_t]()` ### `__del__is_trivial` `comptime __del__is_trivial = _all_trivial_del[cublasHandle_t, Handle, hipblasLtHandle_t]()` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = _all_trivial_moveinit[cublasHandle_t, Handle, hipblasLtHandle_t]()` ### `resolved_backend` `comptime resolved_backend = _resolve_backend[backend]()` ### `type` `comptime type = Variant[cublasHandle_t, Handle, hipblasLtHandle_t]` ## Methods ### `__init__` `__init__(out self)` ### `__is__` `__is__(self, other: Backend) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__isnot__` `__isnot__(self, other: Backend) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `__enter__` `__enter__(self) -> Self` ### `__exit__` `__exit__(mut self)`
--- ## blas
## `comptime` values ### `OpaquePointer` `comptime OpaquePointer = LegacyUnsafePointer[NoneType]` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`Backend`](./Backend): * [​`Handle`](./Handle): ## Functions * [​`matmul`](./matmul): Matmul using the vendor BLAS library, with a global handle.
--- ## matmul (Blas)
`matmul[use_tf32: Bool = False, *, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)` Matmul using the vendor BLAS library. With a global handle. `matmul[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, *, use_tf32: Bool = False, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, c_tensor: LayoutTensor[c_type, c_layout, origin], a_tensor: LayoutTensor[a_type, a_layout, origin], b_tensor: LayoutTensor[b_type, b_layout, origin], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)` `matmul[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, use_tf32: Bool = False, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, handle: Handle[backend], c_tensor: LayoutTensor[c_type, c_layout, origin], a_tensor: LayoutTensor[a_type, a_layout, origin], b_tensor: LayoutTensor[b_type, b_layout, origin], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin]] = None, b_scales: 
OptionalReg[LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)` `matmul[use_tf32: Bool = False](ctx: DeviceContext, handle: Handle[backend], c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], *, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0)`
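The `alpha`, `beta`, `transpose_a`, and `transpose_b` parameters appear to follow the standard BLAS GEMM convention, `C = alpha * op(A) @ op(B) + beta * C` (an assumption based on the vendor BLAS backend; verify against the actual kernel). A minimal NumPy sketch of that convention:

```python
import numpy as np

def gemm(a, b, c, *, transpose_a=False, transpose_b=False, alpha=1.0, beta=0.0):
    """Reference GEMM under the standard BLAS convention:
    c = alpha * op(a) @ op(b) + beta * c (illustrative, not the Mojo API)."""
    op_a = a.T if transpose_a else a
    op_b = b.T if transpose_b else b
    return alpha * (op_a @ op_b) + beta * c

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.arange(12, dtype=np.float32).reshape(3, 4)
c = np.ones((2, 4), dtype=np.float32)
out = gemm(a, b, c, alpha=2.0, beta=1.0)  # 2*(a @ b) + c
```

With the defaults `alpha=1, beta=0` the call reduces to a plain matrix product that overwrites `c`.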
--- ## vendor (Vendor)
Provides the Vendor backend implementations for matmuls. This backend is used for testing and evaluation. ## Modules * [​`blas`](./blas/): * [​`matmul`](./matmul/):
--- ## matmul (8)
## Functions * [​`matmul`](./matmul): Implements the matmul kernel for the Blackwell architecture. Note that there are currently no pure Mojo kernels targeting Blackwell, so this function calls into the cuBLAS library instead.
--- ## matmul (9)
`matmul[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = False, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, config: Optional[MatmulConfig[a_type, b_type, c_type, transpose_b]] = None](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext)` Implements the matmul kernel for the Blackwell architecture. Note that there are currently no pure Mojo kernels targeting Blackwell, so this function calls into the cuBLAS library instead.
--- ## matrix_band_part
The module implements matrix band part functions. ## Functions * [​`matrix_band_part`](./matrix_band_part):
--- ## matrix_band_part (Matrix_band_part)
`matrix_band_part[dtype: DType, int_type: DType, cond_type: DType, rank: Int, input_0_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], simd_width: Int, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[rank], num_lower: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_upper: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], exclude: LayoutTensor[cond_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)`
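The signature above does not spell out the band-part semantics; assuming the conventional definition (as in TensorFlow's `matrix_band_part`), element `[i, j]` is kept when `i - j <= num_lower` and `j - i <= num_upper`, with a negative bound meaning "keep the whole triangle". A NumPy sketch of that assumed behavior:

```python
import numpy as np

def matrix_band_part(x, num_lower, num_upper):
    """Zero everything outside the band: keep x[i, j] when
    (num_lower < 0 or i - j <= num_lower) and
    (num_upper < 0 or j - i <= num_upper).
    Illustrative semantics only, not the Mojo kernel itself."""
    m, n = x.shape
    i = np.arange(m)[:, None]
    j = np.arange(n)[None, :]
    keep_lower = (num_lower < 0) | (i - j <= num_lower)
    keep_upper = (num_upper < 0) | (j - i <= num_upper)
    return np.where(keep_lower & keep_upper, x, 0)

x = np.ones((4, 4))
tri_lower = matrix_band_part(x, -1, 0)  # lower-triangular part
band = matrix_band_part(x, 1, 1)        # tridiagonal band
```

For example, `(num_lower, num_upper) = (-1, 0)` yields the lower triangle and `(0, 0)` the main diagonal.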
--- ## BTileGenerator
`struct BTileGenerator[config: KernelConfig, a_type: DType, b_type: DType, c_type: DType, shape: DimList, transpose_b: Bool, b_packed: Bool, origin: ImmutOrigin]` Struct to encapsulate a tile of B that supports prepacking. If b\_packed is true, calls to get\_tile will return a buffer view from B. Otherwise, calls to get\_tile will copy a tile from B into a stack-allocated scratch buffer and return a view of that. ## Fields * ​b (`NDBuffer[b_type, 2, origin, shape]`): * ​b\_tile\_stack\_ptr (`LegacyUnsafePointer[Scalar[b_type]]`): * ​tile\_n\_k (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `get` `static get(b: NDBuffer[b_type, 2, origin, shape], tile_n_k: IndexList[2]) -> Self` ### `get_tile` `get_tile[inner_size: Int](self, global_offset: GemmShape, tile_dim_nk: IndexList[2], valid_data_dim_nk: IndexList[2]) -> NDBuffer[b_type, 3, MutAnyOrigin, config.packed_shape]` Get a packed matrix (B) tile. valid\_data\_dim\_nk is ignored for pre-packing, where the tile is padded to have the shape of tile\_dim\_nk. **Args:** * ​global\_offset ([`GemmShape`](/mojo/kernels/linalg/utils/GemmShape)): Offset in the global M, N, K dimensions. * ​tile\_dim\_nk ([`IndexList`](/mojo/std/utils/index_/IndexList)): Tile shape based on cache size and matrix dimensions. * ​valid\_data\_dim\_nk ([`IndexList`](/mojo/std/utils/index_/IndexList)): The upper bounds for N and K dimensions. **Returns:** [`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer): A view of the packed tile.
--- ## PackMatrixCols
`struct PackMatrixCols[original_mut: Bool, //, original_shape: DimList, packed_shape: DimList, dtype: DType, simd_size: Int, column_inner_size: Int, use_vnni: Bool, use_i8mm: Bool, packed_origin: MutOrigin, original_origin: Origin[mut=original_mut]]` Pack columns from a matrix into the mlas packed layout and extract inner vectors of columns into the packed inner dimension, e.g. extracts \[X, Y] and packs as \[Yo]\[X]\[Yi]. ## Fields * ​packed\_matrix (`NDBuffer[dtype, 3, packed_origin, packed_shape]`): * ​original\_matrix (`NDBuffer[dtype, 2, original_origin, original_shape]`): * ​global\_offset (`IndexList[2]`): * ​pack\_tile\_dim (`IndexList[2]`): * ​valid\_data\_dim (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `run` `static run(packed_matrix: NDBuffer[dtype, 3, MutAnyOrigin, packed_shape], original_matrix: NDBuffer[dtype, 2, MutAnyOrigin, original_shape], global_offset: IndexList[2], pack_tile_dim: IndexList[2], valid_data_dim: IndexList[2])` Interface function to run the packing routine. **Args:** * ​packed\_matrix (`NDBuffer`): Pre-allocated buffer space for packed data. * ​original\_matrix (`NDBuffer`): Data buffer containing the original matrix to pack. * ​global\_offset (`IndexList`): Offset to use when indexing the original matrix. * ​pack\_tile\_dim (`IndexList`): 2D dimension tuple describing the size of the packed tile. * ​valid\_data\_dim (`IndexList`): 2D dimension tuple describing the amount of valid data on the global buffer starting from the offset.
--- ## PackMatrixRows
`struct PackMatrixRows[original_mut: Bool, //, original_shape: DimList, packed_shape: DimList, dtype: DType, simd_size: Int, row_inner_size: Int, packed_origin: MutOrigin, original_origin: Origin[mut=original_mut]]` Pack rows from a matrix into the mlas packed layout and extract inner vectors of rows into the packed inner dimension, e.g. extract tile \[X, Y] and pack into \[Xo]\[Y]\[Xi]. ## Fields * ​packed\_matrix (`NDBuffer[dtype, 3, packed_origin, packed_shape]`): * ​original\_matrix (`NDBuffer[dtype, 2, original_origin, original_shape]`): * ​global\_offset (`IndexList[2]`): * ​pack\_tile\_dim (`IndexList[2]`): * ​valid\_data\_dim (`IndexList[2]`): * ​valid\_simd\_dim (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `run` `static run(packed_matrix: NDBuffer[dtype, 3, packed_origin, packed_shape], original_matrix: NDBuffer[dtype, 2, original_origin, original_shape], global_offset: IndexList[2], pack_tile_dim: IndexList[2], valid_data_dim: IndexList[2])` Interface function to run the packing routine. **Args:** * ​packed\_matrix (`NDBuffer`): Pre-allocated buffer space for packed data. * ​original\_matrix (`NDBuffer`): Data buffer containing the original matrix to pack. * ​global\_offset (`IndexList`): Offset to use when indexing the original matrix. * ​pack\_tile\_dim (`IndexList`): 2D dimension tuple describing the size of the packed tile. * ​valid\_data\_dim (`IndexList`): 2D dimension tuple describing the amount of valid data on the global buffer starting from the offset.
--- ## packing
## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`BTileGenerator`](./BTileGenerator): Struct to encapsulate a tile of B that supports prepacking. * [​`PackMatrixCols`](./PackMatrixCols): Pack columns from a matrix into the mlas packed layout and extract inner vectors of columns into the packed inner dimension, e.g. extracts \[X, Y] and packs as \[Yo]\[X]\[Yi]. * [​`PackMatrixRows`](./PackMatrixRows): Pack rows from a matrix into the mlas packed layout and extract inner vectors of rows into the packed inner dimension, e.g. extract tile \[X, Y] and pack into \[Xo]\[Y]\[Xi]. ## Functions * [​`pack_b`](./pack_b): Utility function to pack the entire B matrix, such that each \[tile\_n // inner\_size, tile\_k, inner\_size] tile of src is contiguous in dst. * [​`pack_b_ndbuffer`](./pack_b_ndbuffer): * [​`pack_matmul_b_shape_func`](./pack_matmul_b_shape_func): * [​`pack_transposed_b_ndbuffer`](./pack_transposed_b_ndbuffer):
--- ## pack_b
`pack_b[transpose_b: Bool, simd_size: Int, inner_size: Int, a_type: DType, b_type: DType, c_type: DType, src_shape: DimList, dst_shape: DimList](dst: NDBuffer[b_type, 2, origin, dst_shape], src: NDBuffer[b_type, 2, origin, src_shape], tile_n: Int, tile_k: Int)` Utility function to pack the entire B matrix, such that each \[tile\_n // inner\_size, tile\_k, inner\_size] tile of src is contiguous in dst. Tiles (not tile contents) are stored in row major order, so tile\[i, j] is tile\_n \* tile\_k bytes away from tile\[i, j+1].
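The per-tile layout described above can be sketched with NumPy reshapes. This is an illustrative model of the `[tile_n // inner_size, tile_k, inner_size]` arrangement for a single tile (the helper name `pack_tile` is hypothetical, not part of this module):

```python
import numpy as np

def pack_tile(tile, inner_size):
    """Rearrange a (tile_k, tile_n) tile of B into the packed
    [tile_n // inner_size, tile_k, inner_size] layout, so each group of
    inner_size columns becomes contiguous row by row. Illustrative only."""
    tile_k, tile_n = tile.shape
    assert tile_n % inner_size == 0
    # Split columns into (outer, inner) groups, then move the
    # outer-column axis to the front.
    return tile.reshape(tile_k, tile_n // inner_size, inner_size).transpose(1, 0, 2).copy()

b = np.arange(4 * 8).reshape(4, 8)   # tile_k=4, tile_n=8
packed = pack_tile(b, inner_size=4)  # shape (2, 4, 4)
# element (k, n) of the source tile lands at packed[n // 4, k, n % 4]
```

Under this model, `packed[0]` holds the first `inner_size` columns of the tile and `packed[1]` the next group, matching the "each tile of src is contiguous in dst" property.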
--- ## pack_b_ndbuffer
`pack_b_ndbuffer[b_mut: Bool, //, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, c_type: DType, c_shape: DimList, b_origin: Origin[mut=b_mut], output_origin: MutOrigin](b_input: NDBuffer[b_type, 2, b_origin, b_shape], output_buffer: NDBuffer[b_type, 2, output_origin])`
--- ## pack_matmul_b_shape_func
`pack_matmul_b_shape_func[a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, c_type: DType, c_shape: DimList, transpose_in_0: Bool, single_thread_blocking_override: Bool](b_input: NDBuffer[b_type, 2, origin, b_shape]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## pack_transposed_b_ndbuffer
`pack_transposed_b_ndbuffer[a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, c_type: DType, c_shape: DimList](b_input: NDBuffer[b_type, 2, origin, b_shape], output_buffer: NDBuffer[b_type, 2, origin])`
--- ## apply_q
`apply_q[dtype: DType, element_layout: Layout](sigma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], A: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], X: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Applies the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` to the `X` matrix. See `qr_factorization` for more details on the construction of the Householder reflector.
--- ## form_q
`form_q[dtype: DType, element_layout: Layout](sigma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], A: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], Q: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Forms the Q factor from the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` and stores the result in `Q`.
--- ## qr_factorization
## Functions * [​`apply_q`](./apply_q): Applies the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` to the `X` matrix. * [​`form_q`](./form_q): Forms the Q factor from the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` and stores the result in `Q`. * [​`qr_factorization`](./qr_factorization): Performs QR factorization of a matrix `A` using the Householder reflector method.
--- ## qr_factorization (Qr_factorization)
`qr_factorization[dtype: DType, element_layout: Layout](sigma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], A: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Performs QR factorization of a matrix `A` using the Householder reflector method. This function computes the QR factorization of matrix `A` in-place using Householder reflections. The result is stored directly in the input matrix `A`, with scaling factors in `sigma`. The implementation follows the LAPACK algorithm for generating Householder reflectors in-place. Algorithm: The Householder reflector is defined as:

```
U = I - σww^H
```

where:

```
w = (x + νe₁)/ξ
σ = ξ/ν
ξ = x₀ + ν
ν = sign(x₀)‖x‖₂
```

This ensures that U^H x = -νe₁ and U^H U = I. References: \[1] Lehoucq, R. B. (1996). The computation of elementary unitary matrices. ACM Transactions on Mathematical Software, 22(4), 393-400. Note: There is a typo in reference \[1] (lawn72). The correct result is U^H x = -νe₁.
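The reflector formulas above can be checked numerically. A NumPy sketch for the real case (assumes x₀ ≠ 0; LAPACK additionally handles the x₀ = 0 corner case):

```python
import numpy as np

def householder(x):
    """Build U = I - sigma * w w^T from the formulas in the docstring
    (real case), so that U @ x = -nu * e1. Illustrative, not the Mojo kernel."""
    nu = np.sign(x[0]) * np.linalg.norm(x)   # nu = sign(x0) * ||x||_2
    xi = x[0] + nu                           # xi = x0 + nu
    w = (x + nu * np.eye(len(x))[0]) / xi    # w = (x + nu*e1) / xi
    sigma = xi / nu                          # sigma = xi / nu
    return np.eye(len(x)) - sigma * np.outer(w, w), nu

x = np.array([3.0, 4.0])
U, nu = householder(x)
# U @ x == [-nu, 0] and U^T @ U == I, as stated above
```

Here `nu` is 5 for `x = [3, 4]`, and `U` annihilates every entry of `x` below the first, which is exactly the step the in-place factorization applies column by column.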
--- ## IteratorScatterGatherAmd
`struct IteratorScatterGatherAmd[thread_layout: Layout, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1]` Iterator-based AMD scatter-gather for DRAM-register data movement. ## Parameters * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread organization layout. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total threads (defaults to thread\_layout size). * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Thread execution scope (block or warp). * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): Number of block dimensions. ## Fields * ​buffer (`AMDBufferResource`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], tensor_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])` Initialize with tensor and iterator. **Args:** * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Layout tensor for bounds. * ​tensor\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): Iterator for AMD buffer resource. 
### `copy` `copy(self, dst_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_gmem_tile_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])` Copy DRAM to registers via iterator. **Args:** * ​dst\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination register tile. * ​src\_gmem\_tile\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): Source memory iterator.
--- ## NVIDIASharedMemoryBasePtr
`struct NVIDIASharedMemoryBasePtr[name: StringSlice[StaticConstantOrigin] = "extern_ptr_syml", memory_alignment: Int = 8]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`SharedMemoryBasePtr`](/mojo/kernels/linalg/structuring/SharedMemoryBasePtr) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `alignment` `comptime alignment = 128` ## Methods ### `ptr` `static ptr() -> LegacyUnsafePointer[Int8, address_space=AddressSpace.SHARED]` **Returns:** [`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)
--- ## SMemArray
`@register_passable(trivial)` `struct SMemArray[type: __TypeOfAllTypes, size: Int]` Shared memory array of fixed size. ## Parameters * ​type ([`__TypeOfAllTypes`](/mojo/std/builtin/type_aliases/#__typeofalltypes)): Element type. * ​size ([`Int`](/mojo/std/builtin/int/Int)): Number of elements. ## Fields * ​ptr (`SMemArray[type, size].ptr_type`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ptr_type` `comptime ptr_type = LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]` ### `Storage` `comptime Storage = InlineArray[type, size]` ### `storage_size` `comptime storage_size = (size * size_of[type]())` ## Methods ### `__init__` `__init__(unsafe_ptr: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]) -> Self` Initialize with shared memory pointer. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Shared memory pointer. `__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[type, size]) -> Self` Initialize from Storage. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemArray[type, size].ptr_type` Get a pointer to the element at index. **Args:** * ​index (`T`): Element index. **Returns:** `SMemArray`: Pointer to element. ### `len` `static len() -> Int` Get array length in bytes. 
**Returns:** [`Int`](/mojo/std/builtin/int/Int): Total size in bytes. ### `stack_allocation` `static stack_allocation[alignment: Int = align_of[type]()]() -> Self`
--- ## SMemTileArray (Structuring)
`@register_passable(trivial)` `struct SMemTileArray[dtype: DType, layout: Layout, num_tiles: Int, alignment: Int]` Array of tiles in shared memory. ## Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): Tile data type. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Tile layout configuration. * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): Number of tiles. * ​alignment ([`Int`](/mojo/std/builtin/int/Int)): Memory alignment. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_elements` `comptime num_elements = (layout.size() * num_tiles)` ### `Storage` `comptime Storage = InlineArray[Scalar[dtype], SMemTileArray[dtype, layout, num_tiles, alignment].num_elements]` ### `storage_size` `comptime storage_size = (SMemTileArray[dtype, layout, num_tiles, alignment].num_elements * size_of[dtype]())` ### `Tile` `comptime Tile = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` ## Methods ### `__init__` `__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[Scalar[dtype], SMemTileArray[dtype, layout, num_tiles, alignment].num_elements]) -> Self` Initialize with Storage. **Args:** * ​storage ([`InlineArray`](/mojo/std/collections/inline_array/InlineArray)): Storage. 
`__init__[mut: Bool, //, origin: Origin[mut=mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, origin=origin]) -> Self` Initialize with shared memory pointer. **Args:** * ​unsafe\_ptr ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Shared memory pointer. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemTileArray[dtype, layout, num_tiles, alignment].Tile` Get tile at index. **Args:** * ​index (`T`): Tile index. **Returns:** `SMemTileArray`: Tile at index. ### `slice` `slice[length: Int](self, start: Int) -> SMemTileArray[dtype, layout, length, alignment]` **Returns:** `SMemTileArray` ### `stack_allocation` `static stack_allocation() -> Self`
--- ## ScatterGatherAmd
`struct ScatterGatherAmd[thread_layout: Layout, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1]` AMD tile-based scatter-gather for DRAM-register data movement. ## Parameters * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread organization layout. * ​num\_threads ([`Int`](/mojo/std/builtin/int/Int)): Total threads (defaults to thread\_layout size). * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Thread execution scope (block or warp). * ​block\_dim\_count ([`Int`](/mojo/std/builtin/int/Int)): Number of block dimensions. ## Fields * ​buffer (`AMDBufferResource`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Initialize with a tensor. **Args:** * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Layout tensor for AMD buffer resource creation. ### `copy` `copy(self, dst_reg_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_gmem_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], offset: Optional[UInt] = None)` Copy DRAM to registers. **Args:** * ​dst\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination register tile. 
* ​src\_gmem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source global memory tile. * ​offset ([`Optional`](/mojo/std/collections/optional/Optional)): Optional copy offset. `copy(self, dst_gmem_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_reg_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Copy registers to DRAM. **Args:** * ​dst\_gmem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination global memory tile. * ​src\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source register tile.
--- ## SharedMemoryBasePtr
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `alignment` `comptime alignment` ## Required methods ### `ptr` `static ptr() -> LegacyUnsafePointer[Int8, address_space=AddressSpace.SHARED]` **Returns:** [`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)
--- ## SharedMemoryManager
`struct SharedMemoryManager[SMBP: SharedMemoryBasePtr]` ## Fields * ​base\_ptr (`LegacyUnsafePointer[Int8, address_space=AddressSpace.SHARED]`): * ​offset (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `Array` `comptime Array[type: __TypeOfAllTypes, size: Int] = SMemArray[type, size]` #### Parameters * ​type ([`__TypeOfAllTypes`](/mojo/std/builtin/type_aliases/#__typeofalltypes)): * ​size ([`Int`](/mojo/std/builtin/int/Int)): ### `Tile` `comptime Tile[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=SMBP.alignment]` #### Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ### `TileArray` `comptime TileArray[dtype: DType, layout: Layout, num_tiles: Int] = SMemTileArray[dtype, layout, num_tiles, SMBP.alignment]` #### Parameters * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): * ​num\_tiles ([`Int`](/mojo/std/builtin/int/Int)): ## Methods ### `__init__` `__init__(out self)` Initialize the shared memory manager. ### `build` `build[dtype: DType, layout: Layout, //, T: AnyStruct[LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=SMBP.alignment]]](mut self) -> T` Allocate a single tile. **Returns:** `T`: Allocated tile. `build[dtype: DType, layout: Layout, num_tiles: Int, //, T: AnyStruct[SMemTileArray[dtype, layout, num_tiles, SMBP.alignment]]](mut self) -> T` Allocate a tile array. **Returns:** `T`: Allocated tile array. `build[type: __TypeOfAllTypes, size: Int, //, T: AnyStruct[SMemArray[type, size]]](mut self) -> T` Allocate a regular array. **Returns:** `T`: Allocated array.
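The `base_ptr`/`offset` fields suggest that `build` carves successive allocations out of one shared-memory arena, bump-pointer style. A toy Python sketch of that pattern (all names here are illustrative, not the Mojo API):

```python
class BumpAllocator:
    """Toy bump allocator over a single arena, mirroring the
    base-pointer + offset pattern suggested by SharedMemoryManager's
    fields. Hypothetical model, not the actual implementation."""

    def __init__(self, capacity, alignment=128):
        self.capacity = capacity
        self.alignment = alignment
        self.offset = 0

    def alloc(self, nbytes):
        # Round the current offset up to the required alignment,
        # then bump past the new allocation.
        start = -(-self.offset // self.alignment) * self.alignment
        if start + nbytes > self.capacity:
            raise MemoryError("arena exhausted")
        self.offset = start + nbytes
        return start

arena = BumpAllocator(capacity=1024)
a = arena.alloc(100)  # first allocation starts at offset 0
b = arena.alloc(64)   # next allocation starts at the next aligned offset
```

The appeal of this scheme on a GPU is that every allocation is a constant-time pointer bump and the whole arena is released at once when the kernel exits.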
--- ## structuring
## `comptime` values ### `eval` `comptime eval[T: AnyType, //, val: T] = val` Helper alias to force evaluation of expressions at compile time. #### Parameters * ​T ([`AnyType`](/std/builtin/anytype/AnyType)): * ​val (`T`): ### `NVIDIASharedMemoryManager` `comptime NVIDIASharedMemoryManager = SharedMemoryManager[NVIDIASharedMemoryBasePtr]` ### `PipelineBarrier` `comptime PipelineBarrier[num_pipeline_stages: Int] = SMemArray[SharedMemBarrier, num_pipeline_stages]` Type alias for shared memory pipeline barrier array. #### Parameters * ​num\_pipeline\_stages ([`Int`](/std/builtin/int/Int)): ### `RegTile` `comptime RegTile[_dtype: DType, layout: Layout, /, *, element_layout: Layout = Layout(IntTuple(1), IntTuple(1)), layout_int_type: DType = _get_layout_type(layout, AddressSpace.LOCAL), linear_idx_type: DType = _get_index_type(layout, AddressSpace.LOCAL), masked: Bool = False, alignment: Int = align_of[_dtype]()] = LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for register (local memory) tile tensors. #### Parameters * ​\_dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): * ​element\_layout ([`Layout`](/kernels/layout/layout/Layout)): * ​layout\_int\_type ([`DType`](/std/builtin/dtype/DType)): * ​linear\_idx\_type ([`DType`](/std/builtin/dtype/DType)): * ​masked ([`Bool`](/std/builtin/bool/Bool)): * ​alignment ([`Int`](/std/builtin/int/Int)): ### `SMemBarrier` `comptime SMemBarrier = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` Type alias for shared memory barrier pointer. 
### `SMemPtr` `comptime SMemPtr[type: AnyType] = LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]` #### Parameters * ​type ([`AnyType`](/std/builtin/anytype/AnyType)): ### `SMemTile` `comptime SMemTile[_dtype: DType, layout: Layout, /, *, element_layout: Layout = Layout(IntTuple(1), IntTuple(1)), layout_int_type: DType = _get_layout_type(layout, AddressSpace.SHARED), linear_idx_type: DType = _get_index_type(layout, AddressSpace.SHARED), masked: Bool = False, alignment: Int = align_of[_dtype]()] = LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for shared memory tile tensors. #### Parameters * ​\_dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): * ​element\_layout ([`Layout`](/kernels/layout/layout/Layout)): * ​layout\_int\_type ([`DType`](/std/builtin/dtype/DType)): * ​linear\_idx\_type ([`DType`](/std/builtin/dtype/DType)): * ​masked ([`Bool`](/std/builtin/bool/Bool)): * ​alignment ([`Int`](/std/builtin/int/Int)): ### `SMemTileIter` `comptime SMemTileIter[dtype: DType, layout: Layout] = LayoutTensorIter[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`IteratorScatterGatherAmd`](./IteratorScatterGatherAmd): Iterator-based AMD scatter-gather for DRAM-register data movement. * [​`NVIDIASharedMemoryBasePtr`](./NVIDIASharedMemoryBasePtr): * [​`ScatterGatherAmd`](./ScatterGatherAmd): AMD tile-based scatter-gather for DRAM-register data movement. * [​`SharedMemoryManager`](./SharedMemoryManager): * [​`SMemArray`](./SMemArray): Shared memory array of fixed size. 
* [​`SMemTileArray`](./SMemTileArray): Array of tiles in shared memory. ## Traits * [​`SharedMemoryBasePtr`](./SharedMemoryBasePtr):
--- ## transpose
This module implements transpose functions. ## `comptime` values ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Functions * [​`transpose`](./transpose): Permute the axes of `input` based on `perms`, and place the result in `output`. * [​`transpose_2d`](./transpose_2d): * [​`transpose_3d_swap_inner`](./transpose_3d_swap_inner): * [​`transpose_3d_swap_outer`](./transpose_3d_swap_outer): * [​`transpose_4d_swap_middle`](./transpose_4d_swap_middle): * [​`transpose_inplace`](./transpose_inplace): * [​`transpose_strided`](./transpose_strided): * [​`transpose_trivial_memcpy`](./transpose_trivial_memcpy):
--- ## transpose (Transpose)
`transpose[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape], perms: LegacyUnsafePointer[Scalar[DType.int]])` Permute the axes of `input` based on `perms`, and place the result in `output`. Example: ```mojo transpose(output, input, [2, 0, 1]) # guarantees output[x, y, z] = input[z, x, y] ``` **Parameters:** * ​rank ([`Int`](/mojo/std/builtin/int/Int)): The rank of the input and output buffers. * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of buffer elements. **Args:** * ​output ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The output buffer. * ​input ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The input buffer. * ​perms ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Permutation of the input axes.
--- ## transpose_2d
`transpose_2d[rank: Int, output_shape: DimList, input_shape: DimList, dtype: DType](output: NDBuffer[dtype, rank, origin, output_shape], input: NDBuffer[dtype, rank, origin, input_shape], perms: LegacyUnsafePointer[Scalar[DType.int]], simplified_input_shape: IndexList[rank], simplified_rank: Int, offset: Int)`
--- ## transpose_3d_swap_inner
`transpose_3d_swap_inner[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape], perms: LegacyUnsafePointer[Scalar[DType.int]], simplified_input_shape: IndexList[rank], simplified_rank: Int)`
--- ## transpose_3d_swap_outer
`transpose_3d_swap_outer[rank: Int, output_shape: DimList, input_shape: DimList, dtype: DType](output: NDBuffer[dtype, rank, origin, output_shape], input: NDBuffer[dtype, rank, origin, input_shape], perms: LegacyUnsafePointer[Scalar[DType.int]], simplified_input_shape: IndexList[rank], simplified_rank: Int)`
--- ## transpose_4d_swap_middle
`transpose_4d_swap_middle[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape, strides], perms: LegacyUnsafePointer[Scalar[DType.int]], simplified_input_shape: IndexList[rank], simplified_rank: Int)`
--- ## transpose_inplace
`transpose_inplace[rows: Int, cols: Int, dtype: DType](buf: NDBuffer[dtype, 2, origin, DimList.__init__[Int, Int](rows, cols)])` `transpose_inplace[rows: Int, cols: Int, dtype: DType](buf: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## transpose_strided
`transpose_strided[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape], perms: LegacyUnsafePointer[Scalar[DType.int]])`
--- ## transpose_trivial_memcpy
`transpose_trivial_memcpy[rank: Int, output_shape: DimList, input_shape: DimList, dtype: DType](output: NDBuffer[dtype, rank, origin, output_shape], input: NDBuffer[dtype, rank, origin, input_shape])`
--- ## GemmShape
`@register_passable(trivial)` `struct GemmShape` Helper class to unpack gemm dimension and layout. ## Fields * ​M (`Int`): * ​N (`Int`): * ​K (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(index: IndexList[3]) -> Self` Constructor of a gemm shape record from an index tuple. **Args:** * ​index ([`IndexList`](/mojo/std/utils/index_/IndexList)): The index tuple containing (m, n, k). ### `__getitem__` `__getitem__(self, idx: Int) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `__setitem__` `__setitem__(mut self, idx: Int, value: Int)` ### `__add__` `__add__(self, rhs: Self) -> Self` Coordinate-wise addition of two gemm shape records. **Args:** * ​rhs (`Self`): Another gemm shape record to add. ### `__sub__` `__sub__(self, rhs: Self) -> Self` Coordinate-wise subtraction of two gemm shape records. **Args:** * ​rhs (`Self`): Another gemm shape record to subtract.
### `get` `static get[transpose_b: Bool](c: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive]) -> Self` Constructor of a gemm shape record from input buffers. M, N, and K are intentionally calculated using `a` and `c` ONLY. This is because `b` may be padded to a multiple of the tile size if it has been pre-packed. **Args:** * ​c ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): NDBuffer with allocated output space. * ​a ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): NDBuffer containing matrix operand A. * ​b ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): NDBuffer containing matrix operand B. `static get[transpose_b: Bool, layout_c: Layout, layout_a: Layout, layout_b: Layout](c: LayoutTensor[dtype, layout_c, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout_a, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout_b, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> Self` Constructor of a gemm shape record from input buffers. M, N, and K are intentionally calculated using `a` and `c` ONLY. This is because `b` may be padded to a multiple of the tile size if it has been pre-packed. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor with allocated output space. 
* ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor containing matrix operand A. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor containing matrix operand B. ### `as_index` `as_index(self) -> IndexList[3]` Utility to convert the underlying data to an index tuple, so that utilities such as elementwise add can be used. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The constructed index tuple.
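As a small illustration (all values are hypothetical), the record composes naturally with `IndexList` utilities:

```mojo
var shape = GemmShape(Index(1024, 768, 512))  # M=1024, N=768, K=512
var tile = GemmShape(Index(128, 128, 32))
var remaining = shape - tile                  # coordinate-wise: (896, 640, 480)
var idx = remaining.as_index()                # IndexList[3] for elementwise utilities
```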
--- ## InnerKernelID
`@register_passable(trivial)` `struct InnerKernelID` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DEFAULT` `comptime DEFAULT = InnerKernelID(0)` ### `I8MM` `comptime I8MM = InnerKernelID(3)` ### `NEON` `comptime NEON = InnerKernelID(2)` ### `VNNI` `comptime VNNI = InnerKernelID(1)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## KernelConfig (Utils)
`struct KernelConfig` Static configuration of the matmul inner kernel. ## Fields * ​kernel\_rows (`Int`): * ​kernel\_cols (`Int`): * ​simd\_size (`Int`): * ​packed\_shape (`DimList`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, *, kernel_rows: Int, kernel_cols: Int, simd_size: Int, packed_shape: DimList)`
--- ## MicroKernelShape
`@register_passable(trivial)` `struct MicroKernelShape` Record describing the inner kernel shape. ## Fields * ​simd\_rows (`Int`): * ​simd\_cols (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(rows: Int, cols: Int) -> Self`
--- ## SubMatmulConfig
`struct SubMatmulConfig` Static configuration of sub-matrices in parallel matmul. ## Fields * ​offset (`IndexList[3]`): * ​shape (`IndexList[3]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## apply_epilogue
`apply_epilogue[elementwise_lambda: elementwise_epilogue_type, dst_layout: Layout, dst_element_layout: Layout = Layout(IntTuple(1), IntTuple(1))](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], offset: Int)`
--- ## calculate_tile_n_k
`calculate_tile_n_k[a_type: DType, b_type: DType, c_type: DType, kernel_cols: Int](n: Int, k: Int) -> IndexList[2]` Helper heuristic function to decide on tile size to partition the matmul given the cache size and desired data layout. **Parameters:** * ​a\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of the A tensor. * ​b\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of the B tensor. * ​c\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The dtype of the C tensor. * ​kernel\_cols ([`Int`](/mojo/std/builtin/int/Int)): The number of columns of the micro kernel. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The calculated tile size to partition the matmul as (TileN, TileK). `calculate_tile_n_k[a_type: DType, b_type: DType, c_type: DType, kernel_cols: Int](global_tile_shape: GemmShape) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
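A hedged call sketch (the dtypes, micro-kernel width, and problem sizes are illustrative, not recommended values):

```mojo
# Choose (TileN, TileK) for a float32 matmul with a 16-column micro kernel.
var tile = calculate_tile_n_k[
    DType.float32, DType.float32, DType.float32, 16
](4096, 4096)
var tile_n = tile[0]
var tile_k = tile[1]
```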
--- ## dispatch_get_kernel_type
`dispatch_get_kernel_type[func: fn[x: Bool]() raises capturing -> None](m: Int, n: Int, k: Int)` `dispatch_get_kernel_type[func: fn[x: Bool]() capturing -> None](m: Int, n: Int, k: Int)`
--- ## get_kernel_config
`get_kernel_config[a_type: DType, b_type: DType, c_type: DType, *, kernel_type: Bool = False]() -> KernelConfig` Utility function to extract matmul configuration parameters for exported functions. TODO: Add target-dependent configuration parameters. **Returns:** [`KernelConfig`](/mojo/kernels/linalg/utils/KernelConfig)
--- ## get_kernel_type
`get_kernel_type(m: Int, n: Int, k: Int) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## get_matmul_arch_factor
`get_matmul_arch_factor[use_vnni: Bool, use_i8mm: Bool]() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## get_matmul_kernel_shape
`get_matmul_kernel_shape[a_type: DType, b_type: DType, c_type: DType, kernel_type: Bool]() -> MicroKernelShape` **Returns:** `MicroKernelShape`
--- ## get_matmul_kernel_shape_ARM
`get_matmul_kernel_shape_ARM[a_type: DType, b_type: DType, c_type: DType, kernel_type: Bool]() -> MicroKernelShape` **Returns:** `MicroKernelShape`
--- ## get_matmul_kernel_shape_x86
`get_matmul_kernel_shape_x86[kernel_type: Bool]() -> MicroKernelShape` **Returns:** `MicroKernelShape`
--- ## get_matmul_num_tasks
`get_matmul_num_tasks[a_type: DType, b_type: DType, c_type: DType, simd_size: Int, kernel_type: Bool](m: Int, n: Int, k: Int, max_num_tasks: Int) -> Int` Compute the number of tasks for parallel matmul. The max number of tasks is typically the number of threads/cores. **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## get_matmul_prefetch_b_distance_k
`get_matmul_prefetch_b_distance_k() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## get_min_task_size
`get_min_task_size() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## get_packB_unroll_factor
`get_packB_unroll_factor() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## get_pack_data_size
`get_pack_data_size[dtype: DType]() -> Int` Utility to compute the number of elements to pack in each tile. **Returns:** [`Int`](/mojo/std/builtin/int/Int): The number of elements to pack.
--- ## get_partitioned_matmul
`get_partitioned_matmul[a_type: DType, b_type: DType, c_type: DType, kernel_rows: Int, kernel_cols: Int](m: Int, n: Int, k: Int, task_id: Int, num_tasks: Int) -> SubMatmulConfig` **Returns:** `SubMatmulConfig`
--- ## get_partitioned_matmul_mojo
`get_partitioned_matmul_mojo[b_type: DType, kernel_rows: Int, kernel_cols: Int, use_i8mm: Bool = False](m: Int, n: Int, k: Int, task_id: Int, num_tasks: Int) -> SubMatmulConfig` **Returns:** `SubMatmulConfig`
--- ## get_partitioned_matmul_mojo_shape
`get_partitioned_matmul_mojo_shape[b_type: DType, kernel_rows: Int, kernel_cols: Int, use_i8mm: Bool](m: Int, n: Int, k: Int, num_tasks: Int) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## utils
## `comptime` values ### `elementwise_compute_lambda_type` `comptime elementwise_compute_lambda_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]` ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`GemmShape`](./GemmShape): Helper class to unpack gemm dimension and layout. * [​`InnerKernelID`](./InnerKernelID): * [​`KernelConfig`](./KernelConfig): Static configuration of the matmul inner kernel. * [​`MicroKernelShape`](./MicroKernelShape): Record describing the inner kernel shape. * [​`SubMatmulConfig`](./SubMatmulConfig): Static configuration of sub-matrices in parallel matmul. ## Functions * [​`apply_epilogue`](./apply_epilogue): * [​`calculate_tile_n_k`](./calculate_tile_n_k): Helper heuristic function to decide on tile size to partition the matmul given the cache size and desired data layout. * [​`dispatch_get_kernel_type`](./dispatch_get_kernel_type): * [​`get_kernel_config`](./get_kernel_config): Utility function to extract matmul configuration parameters for exported Functions. TODO: Add target dependent configuration parameters. * [​`get_kernel_type`](./get_kernel_type): * [​`get_matmul_arch_factor`](./get_matmul_arch_factor): * [​`get_matmul_kernel_shape`](./get_matmul_kernel_shape): * [​`get_matmul_kernel_shape_ARM`](./get_matmul_kernel_shape_ARM): * [​`get_matmul_kernel_shape_x86`](./get_matmul_kernel_shape_x86): * [​`get_matmul_num_tasks`](./get_matmul_num_tasks): Compute the number of tasks for parallel matmul. The max number of tasks is typically the number of threads/cores. 
* [​`get_matmul_prefetch_b_distance_k`](./get_matmul_prefetch_b_distance_k): * [​`get_min_task_size`](./get_min_task_size): * [​`get_pack_data_size`](./get_pack_data_size): Utility to compute the number of elements to pack in each tile. Returns: The number of elements to pack. * [​`get_packB_unroll_factor`](./get_packB_unroll_factor): * [​`get_partitioned_matmul`](./get_partitioned_matmul): * [​`get_partitioned_matmul_mojo`](./get_partitioned_matmul_mojo): * [​`get_partitioned_matmul_mojo_shape`](./get_partitioned_matmul_mojo_shape): * [​`packA_i8mm`](./packA_i8mm): * [​`partition_work`](./partition_work): * [​`select_inner_kernel`](./select_inner_kernel): * [​`use_i8mm_fn`](./use_i8mm_fn): * [​`use_vnni_fn`](./use_vnni_fn):
--- ## packA_i8mm
`packA_i8mm[a_type: DType](t0: Int, t1: Int, k: Int, a_ptr: LegacyUnsafePointer[Scalar[a_type]], a_packed_ptr: LegacyUnsafePointer[Scalar[a_type]])`
--- ## partition_work
`partition_work(task_id: Int, num_tasks: Int, work: Int, work_block_size: Int) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## select_inner_kernel
`select_inner_kernel[a_type: DType, b_type: DType, c_type: DType]() -> InnerKernelID` **Returns:** `InnerKernelID`
--- ## use_i8mm_fn
`use_i8mm_fn[a_type: DType, b_type: DType, c_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## use_vnni_fn
`use_vnni_fn[a_type: DType, b_type: DType, c_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool)
--- ## MatmulConfig (Utils_gpu)
`@register_passable(trivial)` `struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]` Static configuration of GPU matmul. ## Fields * ​block\_tile\_shape (`IndexList[3]`): * ​warp\_tile\_shape (`IndexList[3]`): * ​mma\_shape (`IndexList[3]`): * ​num\_pipeline\_stages (`UInt`): * ​num\_k\_partitions (`UInt`): * ​k\_group\_size (`UInt`): * ​num\_warp\_k\_partitions (`UInt`): * ​cluster\_shape (`IndexList[3]`): * ​num\_consumer (`UInt`): * ​partitioned\_multicast (`Bool`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`Stringable`](/mojo/std/builtin/str/Stringable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable), [`Writable`](/mojo/std/format/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ACCUM_PRECISION` `comptime ACCUM_PRECISION = 1` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `OUTPUT_PRECISION` `comptime OUTPUT_PRECISION = 2` ### `split_k_reduction_scheme` `comptime split_k_reduction_scheme = env_get_int["SPLITK_REDUCTION_SCHEME", 2]()` ### `split_k_reduction_type` `comptime split_k_reduction_type = c_type if (2 == MatmulConfig[a_type, b_type, c_type, transpose_b].split_k_reduction_scheme)._mlir_value else MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type` ## Methods ### `__init__` `__init__(*, block_tile_shape: IndexList[3] = Index(128, 128, 32), warp_tile_shape: IndexList[3] = Index(64, 64, 32), mma_shape: IndexList[3] = get_mma_shape[a_type, 
MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index(1, 1, 1), num_pipeline_stages: Scalar[DType.uint] = 4, num_k_partitions: Scalar[DType.uint] = 1, k_group_size: Scalar[DType.uint] = 1, num_warp_k_partitions: Scalar[DType.uint] = 1, num_consumer: Scalar[DType.uint] = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel()) -> Self` ### `__eq__` `__eq__(self, rhs: MatmulConfig[a_type, b_type, c_type, transpose_b]) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `copy_field` `copy_field(mut self, other: MatmulConfig[a_type, b_type, c_type, transpose_b])` ### `swapAB` `swapAB(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig) ### `num_warps_m` `num_warps_m(self) -> UInt` **Returns:** `UInt` ### `num_warps_n` `num_warps_n(self) -> UInt` **Returns:** `UInt` ### `num_threads` `num_threads(self) -> UInt` **Returns:** `UInt` ### `shared_mem_usage` `shared_mem_usage(self) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `grid_dim` `grid_dim(self, m: Scalar[DType.uint], n: Scalar[DType.uint]) -> IndexList[3]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList) ### `block_dim` `block_dim(self) -> IndexList[3]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList) ### `work_space_size` `work_space_size(self, M: Scalar[DType.uint], N: Scalar[DType.uint]) -> UInt` **Returns:** `UInt` ### `pdl_level` `pdl_level(self) -> PDLLevel` **Returns:** [`PDLLevel`](/mojo/std/gpu/primitives/grid_controls/PDLLevel) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/std/collections/string/string/String) ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. 
**Parameters:** * ​H ([`Hasher`](/mojo/std/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance.
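A hedged sketch of querying launch dimensions from a config (dtypes and problem sizes are illustrative; the constructor defaults are listed above):

```mojo
comptime config = MatmulConfig[
    DType.bfloat16, DType.bfloat16, DType.float32, True
]()
var grid = config.grid_dim(4096, 4096)       # IndexList[3] of thread blocks for an M x N output
var block = config.block_dim()               # IndexList[3] of threads per block
var smem_bytes = config.shared_mem_usage()   # dynamic shared memory to request
```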
--- ## MatmulKernels
`@register_passable(trivial)` `struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]` Supported matmul kernels. The configurations are named as `<arch>_<block tile shape>_<num_pipeline_stages>`; BK, the mma shape, and the warp tile shape are decided internally. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable), [`TrivialRegisterPassable`](/mojo/std/builtin/value/TrivialRegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ampere_128x128_4` `comptime ampere_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 4, 1, 1, 1, 1, False, PDLLevel())` ### `ampere_256x128_3` `comptime ampere_256x128_3 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 256, (2 * _bk_base[a_type]())), Index(64, 64, (2 * _bk_base[a_type]())), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 3, 1, 1, 1, 1, False, PDLLevel())` ### `ampere_256x64_4` `comptime ampere_256x64_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(64, 256, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 4, 1, 1, 1, 1, False, PDLLevel())` ### `hopper_128x128_4` `comptime hopper_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, 
_bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 4, 1, 1, 1, 1, False, PDLLevel())` ### `tuning_config` `comptime tuning_config = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(env_get_int["TUNE_BM", 128](), env_get_int["TUNE_BN", 128](), env_get_int["TUNE_BK", 32]()), Index(env_get_int["TUNE_WM", 64](), env_get_int["TUNE_WN", 64](), env_get_int["TUNE_BK", 32]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), SIMD[DType.uint, 1](env_get_int["TUNE_NUM_STAGES", 4]()), SIMD[DType.uint, 1](env_get_int["TUNE_NUM_K_PARTITIONS", 1]()), 1, SIMD[DType.uint, 1](env_get_int["TUNE_NUM_WARP_K_PARTITIONS", 1]()), 1, False, PDLLevel())`
--- ## block_swizzle
`block_swizzle(block_idx: IndexList[2, element_type=element_type], grid_dim: IndexList[2, element_type=element_type]) -> IndexList[2, element_type=element_type]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## create_hilbert_lut
`create_hilbert_lut(ctx: DeviceContext, grid_x: Int, grid_y: Int) -> DeviceBuffer[DType.uint32]` Precompute a Hilbert-curve block swizzle lookup table for a rectangular grid. The returned device buffer holds a 1-D `UInt32` array of length grid\_x \* grid\_y. For linear (row-major) block id `id`, the packed value at `lut[id]` encodes the swizzled coordinates: the upper 16 bits hold y, the lower 16 bits hold x. **Returns:** `DeviceBuffer`
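Given the packing described above (upper 16 bits = y, lower 16 bits = x), an entry can be unpacked with plain bit operations. A small sketch; the helper name is ours, not part of the library:

```mojo
fn unpack_swizzled(packed: UInt32) -> IndexList[2]:
    # Lower 16 bits hold the swizzled x coordinate, upper 16 bits hold y.
    return Index(Int(packed & 0xFFFF), Int(packed >> 16))
```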
--- ## get_hilbert_lut_with_cache
`get_hilbert_lut_with_cache(ctx: DeviceContext, grid_x: Int, grid_y: Int) -> DeviceBuffer[DType.uint32]` Get Hilbert lookup table using global cache (no struct needed). **Returns:** `DeviceBuffer`
--- ## utils_gpu
## `comptime` values ### `OpaquePointer` `comptime OpaquePointer = LegacyUnsafePointer[NoneType]` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`MatmulConfig`](./MatmulConfig): Static configuration of GPU matmul. * [​`MatmulKernels`](./MatmulKernels): Supported matmul kernels. ## Functions * [​`block_swizzle`](./block_swizzle): * [​`create_hilbert_lut`](./create_hilbert_lut): Precompute Hilbert-curve block swizzle lookup-table for a rectangular grid. * [​`get_hilbert_lut_with_cache`](./get_hilbert_lut_with_cache): Get Hilbert lookup table using global cache (no struct needed). * [​`select_config`](./select_config):
--- ## select_config
`select_config[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False](M: Int, N: Int, K: Int, ctx: DeviceContext) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## elu
`elu[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the Elu Op using the equation $z \text{ if } z \ge 0 \text{ else } \alpha (e^z - 1)$. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The value to compute the ELU operation on. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): The result of the ELU operation.
--- ## activations
This module contains implementations of activation functions. ## Functions * [​`elu`](./elu): Compute the ELU op using the equation $z \text{ if } z \geq 0 \text{ else } \alpha(e^z - 1)$. * [​`leaky_relu`](./leaky_relu): Compute the Leaky ReLU using the equation $\max(x, 0) + \text{negative\_slope} \cdot \min(x, 0)$. * [​`relu`](./relu): Compute the ReLU op using the equation $\max(x, 0)$. * [​`relu_n1`](./relu_n1): Compute the ReLU N1 op using the equation $\max(\min(x, 1), -1)$. * [​`sign`](./sign): Compute the sign (0, 1) of the input value.
--- ## leaky_relu
`leaky_relu[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width], negative_slope: Scalar[dtype]) -> SIMD[dtype, simd_width]` Compute the Leaky ReLU using the equation $\max(x, 0) + \text{negative\_slope} \cdot \min(x, 0)$. **Constraints:** The type must be a floating-point DType. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The value to compute the Leaky ReLU operation on. * ​negative\_slope ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The slope for negative values. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): The result of the Leaky ReLU operation.
--- ## relu
`relu[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the ReLU op using the equation $\max(x, 0)$. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The value to compute the ReLU operation on. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): The result of the ReLU operation.
--- ## relu_n1
`relu_n1[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the ReLU N1 op using the equation $\max(\min(x, 1), -1)$. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The value to compute the ReLU N1 operation on. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): The result of the ReLU N1 operation.
--- ## sign
`sign[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the sign (0, 1) of the input value. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/std/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/std/builtin/simd/SIMD)): The value to compute the sign operation on. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): The result of the sign operation.
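The activation equations above operate element-wise on SIMD vectors in the actual kernels; as a hedged sketch of the math only, here they are as scalar Python functions:

```python
import math

# Scalar Python models of the activation equations documented above.
# The real Mojo kernels apply these element-wise to SIMD[dtype, simd_width].
def elu(z: float, alpha: float = 1.0) -> float:
    """z if z >= 0 else alpha * (e^z - 1)."""
    return z if z >= 0 else alpha * (math.exp(z) - 1.0)

def leaky_relu(x: float, negative_slope: float) -> float:
    """max(x, 0) + negative_slope * min(x, 0)."""
    return max(x, 0.0) + negative_slope * min(x, 0.0)

def relu(x: float) -> float:
    """max(x, 0)."""
    return max(x, 0.0)

def relu_n1(x: float) -> float:
    """max(min(x, 1), -1): ReLU clamped to the range [-1, 1]."""
    return max(min(x, 1.0), -1.0)
```

Note that leaky_relu's formulation via `max` and `min` is equivalent to the usual piecewise definition (`x` for positive inputs, `negative_slope * x` for negative ones) but branch-free, which matters for SIMD execution.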
--- ## arange
`arange[dtype: DType, simd_width: Int](start: Scalar[dtype], stop: Scalar[dtype], step: Scalar[dtype], index: IndexList[1]) -> SIMD[dtype, simd_width]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## arange_shape
`arange_shape[dtype: DType, single_thread_blocking_override: Bool](start: Scalar[dtype], stop: Scalar[dtype], step: Scalar[dtype]) -> IndexList[1]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## arange (Arange)
## Functions * [​`arange`](./arange): * [​`arange_shape`](./arange_shape):
--- ## arg_nonzero
`arg_nonzero[dtype: DType, output_type: DType](input_buffer: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output_buffer: TileTensor[output_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Gather the indices of all non-zero elements in the input buffer, storing the indices in the output\_buffer. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The element dtype. * ​output\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The integer dtype to store the indices in. **Args:** * ​input\_buffer ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The tensor to count the non-zeros in. * ​output\_buffer ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The indices of all non-zero elements.
--- ## arg_nonzero_shape
`arg_nonzero_shape[dtype: DType, single_thread_blocking_override: Bool](input_buffer: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]) -> IndexList[2]` Return \[NumNonZeros, InputRank] where NumNonZeros is the number of non-zero elements in the input. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The element dtype. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/std/builtin/bool/Bool)): This op can block. **Args:** * ​input\_buffer ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The tensor to count the non-zeros in. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): Shape of the arg\_nonzero kernel for this input \[NumNonZeros, InputRank].
--- ## arg_nonzero (Arg_nonzero)
## Functions * [​`arg_nonzero`](./arg_nonzero): Gather the indices of all non-zero elements in input buffer storing the indices in the output\_buffer. * [​`arg_nonzero_shape`](./arg_nonzero_shape): Return \[NumNonZeros, InputRank] where NumNonZeros are the number of non-zero elements in the input.
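The \[NumNonZeros, InputRank] output shape described above can be illustrated with a plain-Python model (not the Mojo kernel itself):

```python
# Pure-Python sketch of the arg_nonzero semantics: collect the [row, col]
# index of every non-zero element of a 2-D input, in row-major order.
def arg_nonzero(matrix):
    return [[i, j]
            for i, row in enumerate(matrix)
            for j, v in enumerate(row)
            if v != 0]

# A 2-D input with two non-zeros yields a [2, 2] index list:
arg_nonzero([[0, 3], [7, 0]])  # → [[0, 1], [1, 0]]
```

Because NumNonZeros depends on the data, the output shape is only known at run time, which is why the separate `arg_nonzero_shape` entry point exists.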
--- ## argmax
`argmax(input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], axis: Int, output: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Finds the indices of the maximum element along the specified axis. **Args:** * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The input tensor. * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis. * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The output tensor. `argmax(input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], axis_buf: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Finds the indices of the maximum element along the specified axis. **Args:** * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The input tensor. * ​axis\_buf ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The axis tensor. * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The output tensor.
--- ## argmin
`argmin(input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], axis: Int, output: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Finds the indices of the minimum element along the specified axis. **Args:** * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The input tensor. * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis. * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The output tensor. `argmin(input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], axis_buf: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Finds the indices of the minimum element along the specified axis. **Args:** * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The input tensor. * ​axis\_buf ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The axis tensor. * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The output tensor.
--- ## argmaxmin
## Functions * [​`argmax`](./argmax): Finds the indices of the maximum element along the specified axis. * [​`argmin`](./argmin): Finds the indices of the minimum element along the specified axis.
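The along-an-axis semantics of argmax/argmin can be sketched in plain Python (a model of the behavior, not the kernel):

```python
# For a 2-D input reduced along the innermost axis, each row produces
# one index: the position of its maximum (or minimum) element.
def argmax_rows(matrix):
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]

def argmin_rows(matrix):
    return [min(range(len(row)), key=row.__getitem__) for row in matrix]

x = [[1, 9, 4],
     [8, 2, 6]]
argmax_rows(x)  # → [1, 0]
argmin_rows(x)  # → [0, 1]
```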
--- ## argmax_gpu
`argmax_gpu[dtype: DType, output_type: DType](ctx: DeviceContext, input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[output_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])`
--- ## argmaxmin_gpu
`argmaxmin_gpu[dtype: DType, output_type: DType, largest: Bool](ctx: DeviceContext, input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[output_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Wraps the Top-K GPU kernel with K=1 to perform argmax on the inner-most dimension. **Parameters:** * ​dtype ([`DType`](/mojo/std/builtin/dtype/DType)): The data dtype of the input tensor. * ​output\_type ([`DType`](/mojo/std/builtin/dtype/DType)): The data dtype of the output tensor. * ​largest ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to perform argmax or argmin.
--- ## argmin_gpu
`argmin_gpu[dtype: DType, output_type: DType](ctx: DeviceContext, input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[output_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])`
--- ## argmaxmin_gpu (Argmaxmin_gpu)
## Functions * [​`argmax_gpu`](./argmax_gpu): * [​`argmaxmin_gpu`](./argmaxmin_gpu): Wraps the Top-K GPU kernel with K=1 to perform argmax on the inner-most dimension. * [​`argmin_gpu`](./argmin_gpu):
--- ## argsort
`argsort[*, ascending: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)` Performs argsort on input buffer, storing indices in output buffer. **Parameters:** * ​ascending ([`Bool`](/mojo/std/builtin/bool/Bool)): Sort direction (True for ascending, False for descending). * ​target (`StringSlice`): Target device ("cpu" or "gpu"). **Args:** * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Buffer to store sorted indices. * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Buffer containing values to sort. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for execution. `argsort[ascending: Bool = True](output: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` CPU-only version of argsort. **Parameters:** * ​ascending ([`Bool`](/mojo/std/builtin/bool/Bool)): Sort direction (True for ascending, False for descending). **Args:** * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Buffer to store sorted indices. * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Buffer containing values to sort.
--- ## argsort (Argsort)
## Functions * [​`argsort`](./argsort): Performs argsort on input buffer, storing indices in output buffer.
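The argsort contract above (the output buffer receives indices that would sort the input) can be modeled in a few lines of Python, with the same `ascending` switch:

```python
# Plain-Python model of the argsort semantics documented above:
# return the indices that would sort `values`, not the sorted values.
def argsort(values, ascending=True):
    return sorted(range(len(values)), key=values.__getitem__,
                  reverse=not ascending)

vals = [30, 10, 20]
idx = argsort(vals)                 # → [1, 2, 0]
[vals[i] for i in idx]              # → [10, 20, 30]
argsort(vals, ascending=False)      # → [0, 2, 1]
```

Gathering the input through the returned indices yields the sorted sequence, which is the defining property of argsort.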
--- ## Attention (3)
`struct Attention[attention_config_t: AttentionConfig, output_type: DType, q_type: DType, k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, dtype: DType, //, config: MHAConfig[dtype], group: Int, token_gen: Bool, sink: Bool, q_depth: Int = Int.__init__[UInt](config.depth), cache_depth: Int = Int.__init__[UInt](config.depth), output_depth: Int = Int.__init__[UInt](config.depth)]` ## Fields * ​out\_reg\_buffer (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].OutputRegisterBufferType`): * ​p\_reg\_buffer (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].PRegisterBufferType`): * ​gmem\_manager (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].GlobalMemoryManagerType`): * ​smem\_manager (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].SharedMemoryManagerType`): * ​q\_buffer (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].QRegisterBufferType`): * ​output\_ptr (`UnsafePointer[Scalar[output_type], MutAnyOrigin]`): * ​batch\_idx (`Int`): * ​k (`k_t`): * ​v (`v_t`): * ​mask (`mask_t`): * ​mask\_block\_row (`UInt32`): * ​mask\_warp\_row (`UInt32`): * ​mask\_warp\_col (`UInt32`): * ​kv\_start\_row (`UInt32`): * ​scale (`Scalar[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type]`): * ​seq\_len (`Int`): * ​num\_keys (`Int`): * ​start\_pos (`Int`): * ​cache\_start\_pos (`Int`): * ​softmax (`Softmax[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, Layout.row_major(Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_m_mmas), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_n_mmas)), Layout.row_major(Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_warps_m), Int.__init__[UInt](Attention[config, group, 
token_gen, sink, q_depth, cache_depth, output_depth].num_warps_n)), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].warp_layout, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].fragment_layout, True]`): * ​warp\_scratch\_tensor (`LayoutTensor[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, Layout.row_major((2 * Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_warps_n)), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM)), MutAnyOrigin, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = <compiler-generated conditional over k_t.__del__is_trivial, v_t.__del__is_trivial, and mask_t.__del__is_trivial; full expression elided>`
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial 
else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True 
if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial 
if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial 
if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial 
else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial 
if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial 
else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial 
if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial 
else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial 
if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if 
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
mask_t.__del__is_trivial and v_t.__del__is_trivial and k_t.__del__is_trivial` ### `accum_type` `comptime accum_type = get_accum_type[q_type]()` ### `BK` `comptime BK = config.block_k[dtype]()` ### `BM` `comptime BM = config.block_m[dtype]()` ### `BN` `comptime BN = config.block_n[dtype]()` ### `depth` `comptime depth = config.depth` ### `fragment_layout` `comptime fragment_layout = get_fragment_layout[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape]()` ### `fragment_layout_nested` `comptime fragment_layout_nested = get_nested_fragment_layout[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape]()` ### `GlobalMemoryManagerType` `comptime GlobalMemoryManagerType = GlobalMemoryManager[q_type, SIMD.__init__[DType.uint32, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM), SIMD.__init__[DType.uint32, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), SIMD.__init__[DType.uint32, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), SIMD.__init__[DType.uint32, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].depth), SIMD.__init__[DType.uint32, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_heads), group, token_gen, q_depth, output_depth]` ### `k_group_size` `comptime k_group_size = (16 // (num_matrix_reg[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](0), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](2)]() * size_of[q_type]()))` ### `mma_shape` `comptime mma_shape = attention_config_t.get_mma_shape()` ### `num_heads` `comptime num_heads = config.num_heads` ### `num_k_mmas2` 
`comptime num_k_mmas2 = ceildiv(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK, SIMD[DType.uint, 1]((Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](2) * Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size)))` ### `num_m_mmas` `comptime num_m_mmas = ceildiv(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM, SIMD[DType.uint, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](0)))` ### `num_n_mmas` `comptime num_n_mmas = ceildiv(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN, SIMD[DType.uint, 1](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](1)))` ### `num_n_mmas_output` `comptime num_n_mmas_output = ceildiv((output_depth // Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_warps_n)), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](1))` ### `num_stages` `comptime num_stages = 2` ### `num_threads` `comptime num_threads = config.num_threads[dtype]()` ### `num_warps_m` `comptime num_warps_m = (Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM // Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM)` ### `num_warps_n` `comptime num_warps_n = (Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN // Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN)` ### `output_frag_size` `comptime output_frag_size = Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].fragment_layout.size()` ### `OutputRegisterBufferType` `comptime OutputRegisterBufferType = 
OutputRegisterBuffer[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_m_mmas), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_n_mmas_output, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].output_frag_size]` ### `PRegisterBufferType` `comptime PRegisterBufferType = PRegisterBuffer[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, q_type, Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_m_mmas), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_n_mmas), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].output_frag_size, (Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN != Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size, attention_config_t.double_buffer, 2 if attention_config_t.double_buffer else 1]` ### `QRegisterBufferType` `comptime QRegisterBufferType = QRegisterBuffer[q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, 
output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size, Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), q_depth, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].warp_layout]` ### `SharedMemoryManagerType` `comptime SharedMemoryManagerType = SharedMemoryManager[attention_config_t.shared_kv, attention_config_t.full_kv, attention_config_t.depth_padded, attention_config_t.double_buffer, q_type, Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), Int.__init__[UInt](Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].depth), token_gen]` ### `swap_a_b` `comptime swap_a_b = True` ### `use_exp2` `comptime use_exp2 = True` ### `warp_layout` `comptime warp_layout = get_warp_layout[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape]()` ### `WM` `comptime WM = config.warp_m[dtype]()` ### `WN` `comptime WN = config.warp_n[dtype]()` ## Methods ### `__init__` `__init__(out self, attention_config: attention_config_t, output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], q: UnsafePointer[Scalar[q_type], MutAnyOrigin], k: k_t, v: v_t, mask: mask_t, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]], batch_idx: Int, scale: Float32, seq_len: Int, num_keys: Int, start_pos: Int, 
cache_start_pos: Int = 0)` ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** `UInt` ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** `UInt` ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** `UInt` ### `zero_p_buffer` `zero_p_buffer[stage: Int = 0](self)` ### `get_batch_idx` `get_batch_idx(self) -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `scale_p_reg` `scale_p_reg[stage: Int = 0](self)` ### `get_tensor_core_mma_qk` `static get_tensor_core_mma_qk(out result: TiledTensorCore[get_accum_type[q_type](), q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size, True])` **Returns:** [`TiledTensorCore`](/mojo/kernels/layout/tensor_core/TiledTensorCore) ### `get_tensor_core_mma_pv` `static get_tensor_core_mma_pv(out result: TiledTensorCore[get_accum_type[q_type](), q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size])` **Returns:** [`TiledTensorCore`](/mojo/kernels/layout/tensor_core/TiledTensorCore) ### `mma_qk` `mma_qk[k_buffer_type: KVBuffer, //, prefetch_function: OptionalReg[fn() capturing -> None] = None, beg_iter: Int = 0, num_iters: Int = Int.__init__[UInt]((Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].depth // Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK)), prefetched_b_tile: Bool = False](mut self, mut k_buffer: k_buffer_type)` ### `mma_pv` `mma_pv[v_buffer_type: KVBuffer, //, prefetch_function: OptionalReg[fn() capturing -> None] = None, prefetched_b_tile: Bool = True](mut self, mut v_buffer: v_buffer_type)` ### `mask_status` `mask_status(self, kv_tile_start_row: UInt32) -> TileMaskStatus` **Returns:** [`TileMaskStatus`](/mojo/kernels/nn/mha_mask/TileMaskStatus) ### `mask_advance` 
`mask_advance(mut self)` ### `mask_skip_tile` `mask_skip_tile(self, status: TileMaskStatus) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `mask_skip_and_advance` `mask_skip_and_advance(mut self, kv_tile_start_row: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `mask_apply` `mask_apply[stage: Int = 0](mut self, kv_tile_start_row: UInt32, kv_tile_num_rows: UInt32, not_last_iter: Bool)` ### `online_softmax` `online_softmax[stage: Int = 0](mut self)` ### `store_output` `store_output(self)` ### `copy_fragment_to_smem` `copy_fragment_to_smem(self)` ### `store_partition_info` `store_partition_info(self, num_partitions: Int, exp_sum_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()], MutAnyOrigin], qk_max_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()], MutAnyOrigin])`
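The method list above follows the standard online-softmax ("flash attention") loop: score a K tile (`mma_qk`), apply the mask (`mask_apply`), update running softmax statistics (`online_softmax`), accumulate against the V tile (`mma_pv`), and normalize at the end (`store_output`). The NumPy sketch below illustrates that loop in scalar form; it is not the Mojo implementation, and the function and variable names are illustrative only.

```python
import numpy as np

def online_attention(q, k, v, block_n=2):
    """Tile K/V along the sequence axis, keeping running softmax stats."""
    seq_q, d = q.shape
    seq_k = k.shape[0]
    scale = 1.0 / np.sqrt(d)
    m = np.full(seq_q, -np.inf)        # running row max (cf. qk_max_ptr)
    l = np.zeros(seq_q)                # running exp sum (cf. exp_sum_ptr)
    acc = np.zeros((seq_q, d))         # unnormalized output accumulator
    for start in range(0, seq_k, block_n):
        k_tile = k[start:start + block_n]
        v_tile = v[start:start + block_n]
        s = (q @ k_tile.T) * scale            # "mma_qk": score tile
        m_new = np.maximum(m, s.max(axis=1))
        p = np.exp(s - m_new[:, None])        # softmax numerator for tile
        correction = np.exp(m - m_new)        # rescale earlier partial sums
        l = l * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v_tile  # "mma_pv"
        m = m_new
    return acc / l[:, None]                   # "store_output" normalizes
```

The `store_partition_info` method above exposes the same running statistics (`exp_sum`, `qk_max`) so that partial results from split-KV partitions can be combined with the identical correction step.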
--- ## AttentionConfig
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `depth_padded` `comptime depth_padded` ### `double_buffer` `comptime double_buffer` ### `full_kv` `comptime full_kv` ### `shared_kv` `comptime shared_kv` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. 
**Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** `UInt` ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** `UInt` ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** `UInt` ### `get_mma_shape` `static get_mma_shape() -> IndexList[3]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList) ### `get_q_offset` `static get_q_offset[q_depth: Scalar[DType.uint]]() -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `get_output_offset` `static get_output_offset[output_depth: Scalar[DType.uint]]() -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## attention (4)
## Structs * [​`Attention`](./Attention): ## Traits * [​`AttentionConfig`](./AttentionConfig):
--- ## KBufferConfig
`struct KBufferConfig[BN: Int, BK: Int, WN: Int]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVBufferConfig`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBufferConfig) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `btile_dim0` `comptime btile_dim0 = BN` ### `btile_dim1` `comptime btile_dim1 = BK` ### `iterator_axis` `comptime iterator_axis = 1` ### `wsize` `comptime wsize = KBufferConfig[BN, BK, WN].wtile_dim0` ### `wtile_dim0` `comptime wtile_dim0 = WN` ### `wtile_dim1` `comptime wtile_dim1 = BK` ## Methods ### `get_wtile_coord` `static get_wtile_coord() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
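The config above encodes a block-tile / warp-tile split: a `BN x BK` key tile (`btile_dim0 x btile_dim1`) is divided along its first dimension into `WN x BK` warp tiles (`wtile_dim0 x wtile_dim1`). The coordinate rule sketched below is an assumption for illustration only, not the actual `get_wtile_coord` implementation.

```python
def warp_tile_coord(warp_idx, BN, WN):
    """Hypothetical mapping from a warp index to its tile offset."""
    num_warps_n = BN // WN                # warps covering the BN dimension
    row = (warp_idx % num_warps_n) * WN   # warp-tile row offset in the btile
    col = 0                               # each warp tile spans the full BK
    return (row, col)
```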
--- ## KVBuffer
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `mma_tile_layout` `comptime mma_tile_layout` ## Required methods ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `load_from_dram` `load_from_dram(mut self: _Self)` ### `get_mma_tile` `get_mma_tile(self: _Self) -> LayoutTensor[_Self._dtype, _Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_to_shared` `copy_to_shared[tile_id: Int = 0](self: _Self)` ### `load_from_shared` `load_from_shared[k_mma: Int](self: _Self)`
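The required methods above describe a staged data path: `load_from_dram` fetches a tile from global memory into registers, `copy_to_shared` spills it to shared memory, `load_from_shared` reloads fragments for the MMA, and `get_mma_tile` hands them to the tensor-core op. The toy Python model below mirrors that lifecycle with plain lists standing in for DRAM, shared memory, and registers; the single-buffer policy here is an assumption, not the trait's contract.

```python
class ToyKVBuffer:
    """Illustrative model of the KVBuffer staging contract."""

    def __init__(self, dram, tile_size):
        self.dram = dram
        self.tile_size = tile_size
        self.cursor = 0        # next DRAM offset to fetch
        self.load_regs = None  # registers filled by load_from_dram
        self.shared = None     # shared-memory tile
        self.mma_tile = None   # fragments consumed by the MMA

    def load_from_dram(self):
        tile = self.dram[self.cursor:self.cursor + self.tile_size]
        self.cursor += self.tile_size
        self.load_regs = list(tile)

    def copy_to_shared(self):
        self.shared = list(self.load_regs)

    def load_from_shared(self):
        self.mma_tile = list(self.shared)

    def get_mma_tile(self):
        return self.mma_tile
```

Splitting the path into these stages is what lets an implementation such as `KVBufferImpl` overlap the next tile's DRAM fetch with consumption of the current shared-memory tile (`num_stages` buffering).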
--- ## KVBufferConfig
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `btile_dim0` `comptime btile_dim0` ### `btile_dim1` `comptime btile_dim1` ### `iterator_axis` `comptime iterator_axis` ### `wsize` `comptime wsize` ### `wtile_dim0` `comptime wtile_dim0` ### `wtile_dim1` `comptime wtile_dim1` ## Required methods ### `get_wtile_coord` `static get_wtile_coord() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## KVBufferImpl
`struct KVBufferImpl[dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, mut: Bool, origin: Origin[mut=mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]` ## Fields * ​load\_tile (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].LoadTileType`): * ​mma\_tile (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMATileType`): * ​smem\_iter (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SharedIterType`): * ​bounds (`Int`): * ​load\_tile\_id (`Int`): * ​global\_iterator (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].GlobalTiledIteratorType`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBuffer) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `base_layout` `comptime base_layout = Layout.row_major(config.btile_dim0, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)` ### `GlobalTensorType` `comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` ### `GlobalTiledIteratorType` `comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, 
address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, config.btile_dim0, config.btile_dim1]()[0], origin, address_space=address_space, axis=config.iterator_axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, config.btile_dim0, config.btile_dim1]()]` ### `LoadTileType` `comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((num_stages * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_mmas) * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_k_tiles), KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `MMA_K` `comptime MMA_K = shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_N` `comptime MMA_N = shape.__getitem__[3, DType.int64, Int](1)` ### `mma_tile_layout` `comptime mma_tile_layout = Layout.row_major(KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_mmas, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)` ### `MMATileType` `comptime MMATileType = LayoutTensor[dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(BK, (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_K * group_size))` ### `num_mmas` `comptime num_mmas = ceildiv(config.wsize, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_N)` ### `num_repeats` `comptime num_repeats = (config.btile_dim1 // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, 
BK, depth, num_threads, num_stages, token_gen].simd_width)` ### `num_warps_n` `comptime num_warps_n = (BN // WN)` ### `SharedIterType` `comptime SharedIterType = LayoutTensorIter[dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]` ### `SharedTileType` `comptime SharedTileType = KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SharedIterType.LayoutTensorType` ### `SharedWarpTileType` `comptime SharedWarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace.SHARED), _get_index_type(AddressSpace.SHARED), False, align_of[dtype](), KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim0, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim0, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim1]()]` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `smem_layout` `comptime smem_layout = blocked_product(KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].base_layout, KVBufferImpl[config, tensor_core_mma, 
swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].tiler_layout, True) if token_gen.__invert__()._mlir_value else Layout.row_major(config.btile_dim0, config.btile_dim1)` ### `thread_layout` `comptime thread_layout = Layout.row_major(((min(num_threads, ((config.btile_dim0 * config.btile_dim1) // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout.stride[0].value()), (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout.stride[0].value() // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) if token_gen else Layout.row_major((num_threads // 4), 4)` ### `tiler_layout` `comptime tiler_layout = Layout.row_major(1, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_repeats)` ### `wtile_dim0` `comptime wtile_dim0 = config.wtile_dim0` ### `wtile_dim1` `comptime wtile_dim1 = config.wtile_dim1` ## Methods ### `__init__` `__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_b_rows: Optional[Int], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `load_from_dram` `load_from_dram(mut self)` ### `get_mma_tile` `get_mma_tile(self) -> KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMATileType` **Returns:** `KVBufferImpl` ### `copy_to_shared` 
`copy_to_shared[tile_id: Int = 0](self)` ### `load_from_shared` `load_from_shared[k_mma: Int](self)`
--- ## OutputRegisterBuffer
`struct OutputRegisterBuffer[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int]` ## Fields * ​reg\_tile (`OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].RegisterTileType`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `reg_dtype` `comptime reg_dtype = dtype` ### `reg_tile_layout` `comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), output_frag_size)` ### `RegisterTileType` `comptime RegisterTileType = LayoutTensor[dtype, OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ## Methods ### `__init__` `__init__(out self)` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `vectorize` `vectorize(self) -> LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[True, dtype, OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), 1, output_frag_size]()[1], True), MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=LayoutTensor._divide_tiles[True, dtype, OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), 
_get_index_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), 1, output_frag_size]()[0], layout_int_type=_get_layout_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL)]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `apply_softmax_denominator` `apply_softmax_denominator(self, rowsum: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `zero` `zero(self)` ### `get_reg_tile` `get_reg_tile[stage: Int = 0](self) -> OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].RegisterTileType` **Returns:** `OutputRegisterBuffer`
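`apply_softmax_denominator` performs the final FlashAttention-style normalization: each accumulator fragment in the register tile is scaled by the reciprocal of its row's softmax denominator. A minimal Python sketch of the idea (the real fragment-to-rowsum index mapping depends on the MMA layout; here we assume one `rowsum` entry per fragment row, which is a simplification):

```python
def apply_softmax_denominator(reg_tile, rowsum):
    """Scale accumulator fragments by 1/rowsum (illustrative sketch).

    reg_tile: list of fragment rows, each a list of accumulator values
              (shape num_m_mmas * num_n_mmas x output_frag_size).
    rowsum:   per-fragment-row softmax denominators (assumed mapping).
    """
    for i, frag in enumerate(reg_tile):
        inv = 1.0 / rowsum[i]
        reg_tile[i] = [x * inv for x in frag]
    return reg_tile
```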
--- ## PRegisterBuffer
`struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], k_group_size: Int, tr_load_enabled: Bool = False, num_stages: Int = 1]` ## Fields * ​reg\_tile (`PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].RegisterTileType_`): * ​shared\_memory\_tile (`PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].SharedMemoryTileType`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer), [`RegisterMMABuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterMMABuffer) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `mma_dtype` `comptime mma_dtype = dtype` ### `mma_tile_layout` `comptime mma_tile_layout = Layout.row_major(num_m_mmas, simd_width_of[dtype]())` ### `MMATileType` `comptime MMATileType = LayoutTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].mma_dtype, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `reg_dtype` `comptime reg_dtype = accum_type_` ### `reg_tile_layout` `comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), output_frag_size)` ### `reg_tile_layout_` `comptime reg_tile_layout_ = Layout.row_major(((num_stages * 
num_n_mmas) * num_m_mmas), output_frag_size)` ### `RegisterTileType` `comptime RegisterTileType = LayoutTensor[accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `RegisterTileType_` `comptime RegisterTileType_ = LayoutTensor[accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout_, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `shared_memory_layout` `comptime shared_memory_layout = blocked_product(Layout.row_major(BM, BK), Layout.row_major(1, (BN // BK)), False)` ### `SharedMemoryTileType` `comptime SharedMemoryTileType = LayoutTensor[dtype, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]` ## Methods ### `__init__` `__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])` ### `get_mma_tile_reg` `get_mma_tile_reg[tile_idx: Int, k_idx: Int, stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MMATileType` **Returns:** `PRegisterBuffer` ### `get_mma_tile_shared` `get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MMATileType` **Returns:** `PRegisterBuffer` ### `get_mma_tile` `get_mma_tile[tile_idx: Int, k_idx: Int](self) -> 
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MMATileType` **Returns:** `PRegisterBuffer` `get_mma_tile[tile_idx: Int, k_idx: Int, stage: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MMATileType` **Returns:** `PRegisterBuffer` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `vectorize` `vectorize[stage: Int = 0](self) -> LayoutTensor[accum_type_, coalesce(LayoutTensor._compute_tile_layout[True, accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, AddressSpace.LOCAL), False, align_of[accum_type_](), 1, output_frag_size]()[1], True), MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=LayoutTensor._divide_tiles[True, accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, 
output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, AddressSpace.LOCAL), False, align_of[accum_type_](), 1, output_frag_size]()[0], layout_int_type=_get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].reg_tile_layout, AddressSpace.LOCAL)]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `zero` `zero[stage: Int](self)` `zero(self)` ### `get_reg_tile` `get_reg_tile[stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].RegisterTileType` **Returns:** `PRegisterBuffer` ### `get_shared_memory_tile` `get_shared_memory_tile(self, tile_idx: Int) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, AddressSpace.SHARED), 
_get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, AddressSpace.SHARED), False, align_of[dtype](), BM, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, AddressSpace.SHARED), masked=_tile_is_masked[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].shared_memory_layout, BM, BK]()]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_to_shared` `copy_to_shared(self)`
--- ## QRegisterBuffer
`struct QRegisterBuffer[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int, thread_layout: Layout]` ## Fields * ​reg\_tile (`QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].RegisterTileType`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer), [`RegisterMMABuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterMMABuffer) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `mma_dtype` `comptime mma_dtype = dtype` ### `MMA_K` `comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `mma_tile_layout` `comptime mma_tile_layout = LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, 
AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), 
tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), axis=0]()[0]` ### `MMATileType` `comptime MMATileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // 
QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), tile_size=(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), axis=0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), axis=0]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(BK, 
(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMA_K * k_group_size))` ### `num_mmas` `comptime num_mmas = ceildiv(WM, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMA_M)` ### `num_tiles` `comptime num_tiles = (depth // BK)` ### `reg_dtype` `comptime reg_dtype = dtype` ### `reg_tile_layout` `comptime reg_tile_layout = Layout.row_major(((QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles) * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width)` ### `RegisterTileType` `comptime RegisterTileType = LayoutTensor[dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `TiledIteratorType` `comptime TiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, 
thread_layout].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, axis=0, layout_int_type=_get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), masked=_tile_is_masked[QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()]` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `get_iter` `get_iter(self) -> QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].TiledIteratorType` **Returns:** `QRegisterBuffer` ### `get_mma_tile` `get_mma_tile[tile_idx: Int, k_idx: Int](self) -> QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMATileType` **Returns:** `QRegisterBuffer` ### `get_reg_tile` `get_reg_tile[stage: Int = 0](self) -> QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].RegisterTileType` **Returns:** `QRegisterBuffer` ### `zero` `zero(self)`
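The derived `comptime` counts of `QRegisterBuffer` (`num_mmas`, `num_k_tiles`, `num_tiles`, and the resulting `reg_tile_layout` shape) follow directly from the formulas in the listing. An illustrative Python re-derivation, using a hypothetical 32x32x8 MMA shape:

```python
def ceildiv(a, b):
    return -(-a // b)

def q_register_counts(BK, WM, depth, mma_shape, k_group_size, simd_width):
    """Re-derives QRegisterBuffer's comptime tile counts (illustrative)."""
    MMA_M, _, MMA_K = mma_shape
    num_k_tiles = ceildiv(BK, MMA_K * k_group_size)
    num_mmas = ceildiv(WM, MMA_M)
    num_tiles = depth // BK
    # reg_tile_layout = row_major(num_mmas * num_k_tiles * num_tiles, simd_width)
    reg_shape = (num_mmas * num_k_tiles * num_tiles, simd_width)
    return num_mmas, num_k_tiles, num_tiles, reg_shape
```

With `BK=32`, `WM=32`, `depth=128`, `mma_shape=(32, 32, 8)`, `k_group_size=2`, and an 8-wide SIMD, this gives one MMA per warp tile, two K tiles, four depth tiles, and an 8 x 8 register tile.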
--- ## RegisterBuffer
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType) ## `comptime` members ### `reg_dtype` `comptime reg_dtype` ### `reg_tile_layout` `comptime reg_tile_layout` ## Required methods ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `zero` `zero(self: _Self)` ### `get_reg_tile` `get_reg_tile[stage: Int = 0](self: _Self) -> LayoutTensor[_Self.reg_dtype, _Self.reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## RegisterMMABuffer
## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer) ## `comptime` members ### `mma_dtype` `comptime mma_dtype` ### `mma_tile_layout` `comptime mma_tile_layout` ### `reg_dtype` `comptime reg_dtype` ### `reg_tile_layout` `comptime reg_tile_layout` ## Required methods ### `get_mma_tile` `get_mma_tile[tile_idx: Int, k_idx: Int](self: _Self) -> LayoutTensor[_Self.mma_dtype, _Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `zero` `zero(self: _Self)` ### `get_reg_tile` `get_reg_tile[stage: Int = 0](self: _Self) -> LayoutTensor[_Self.reg_dtype, _Self.reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## VBufferConfig
`struct VBufferConfig[BN: Int, BK: Int, WN: Int, depth: Int]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVBufferConfig`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBufferConfig) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `btile_dim0` `comptime btile_dim0 = BK` ### `btile_dim1` `comptime btile_dim1 = depth` ### `iterator_axis` `comptime iterator_axis = 0` ### `wsize` `comptime wsize = VBufferConfig[BN, BK, WN, depth].wtile_dim1` ### `wtile_dim0` `comptime wtile_dim0 = BK` ### `wtile_dim1` `comptime wtile_dim1 = (depth // (BN // WN))` ## Methods ### `get_wtile_coord` `static get_wtile_coord() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
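The `VBufferConfig` tile dimensions are simple functions of the block and warp parameters: the block tile covers `BK x depth`, while each warp tile covers `BK x (depth // (BN // WN))`. An illustrative Python check with hypothetical values:

```python
def v_buffer_tiles(BN, BK, WN, depth):
    """Re-derives VBufferConfig's comptime tile shapes (illustrative)."""
    btile = (BK, depth)                    # btile_dim0, btile_dim1
    wtile = (BK, depth // (BN // WN))      # wtile_dim0, wtile_dim1
    wsize = wtile[1]                       # wsize aliases wtile_dim1
    return btile, wtile, wsize
```

For example, with `BN=64`, `WN=32`, `BK=32`, `depth=128`, two warps split the depth dimension, so each warp tile is 32 x 64.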
--- ## VBufferTransposeLoads
`struct VBufferTransposeLoads[dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, mut: Bool, origin: Origin[mut=mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]` ## Fields * ​load\_tile (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].LoadTileType`): * ​mma\_tile (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMATileType`): * ​smem\_iter (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].SharedIterType`): * ​global\_iterator (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].GlobalTiledIteratorType`): * ​global\_base\_tile (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].GlobalTensorType`): * ​current\_stage (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`KVBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBuffer) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `base_layout` `comptime base_layout = Layout.row_major(VBufferTransposeLoads.pad[dtype, layout, address_space, alignment, mut, origin, masked, layout_int_type, linear_idx_type, out_type, in_type, shape, group_size, transpose_b, tensor_core_mma, BN, BK, depth, num_threads, num_stages, depth](), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)` ### `depth_tile_size` `comptime depth_tile_size = min(depth, 128)` ### `GlobalTensorType` `comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, 
linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` ### `GlobalTiledIteratorType` `comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, BK, depth]()[0], origin, address_space=address_space, axis=0, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, BK, depth]()]` ### `load_width` `comptime load_width = 4 if (depth == 64)._mlir_value else VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width` ### `loads_per_thread_per_depth_tile` `comptime loads_per_thread_per_depth_tile = ((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size * BK) // (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width * num_threads))` ### `LoadTileType` `comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].loads_per_thread_per_depth_tile * (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size)) * num_stages), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `MMA_K` `comptime MMA_K = shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = shape.__getitem__[3, DType.int64, Int](0)` ### `mma_tile_layout` `comptime mma_tile_layout = Layout.row_major((depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)` ### `MMATileType` `comptime MMATileType = LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, 
num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_depth_tiles` `comptime num_depth_tiles = (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M)` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(BK, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_K * group_size))` ### `num_repeats` `comptime num_repeats = (BK // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)` ### `SharedIterType` `comptime SharedIterType = LayoutTensorIter[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]` ### `SharedTileType` `comptime SharedTileType = VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].SharedIterType.LayoutTensorType` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `smem_layout` `comptime smem_layout = blocked_product(VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].base_layout, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].tiler_layout, True)` ### `tiler_layout` `comptime tiler_layout = Layout.row_major(1, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].num_repeats)` ## Methods ### `__init__` `__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/std/builtin/dtype/DType) ### `pad` `static pad[dim: Int]() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int) ### `load_from_dram` `load_from_dram(mut self)` ### `get_mma_tile` `get_mma_tile(self) -> 
VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMATileType` **Returns:** `VBufferTransposeLoads` ### `copy_to_shared` `copy_to_shared[tile_id: Int = 0](self)` ### `load_from_shared` `load_from_shared[k_mma: Int](self)`
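The `load_width` and `loads_per_thread_per_depth_tile` members above are pure tiling arithmetic: the elements in one depth tile, divided by how many elements all threads move per vector load. A minimal Python sketch of that formula (the concrete numbers — `depth_tile_size=64`, `BK=32`, `num_threads=256`, a SIMD width of 8 — are illustrative assumptions, not kernel defaults):

```python
def loads_per_thread_per_depth_tile(depth_tile_size, BK, load_width, num_threads):
    # Elements in one depth tile (depth_tile_size * BK) divided by the
    # elements moved per step across all threads (load_width * num_threads).
    return (depth_tile_size * BK) // (load_width * num_threads)

# With 64x32 depth tiles, 8-wide loads, and 256 threads, each thread
# issues exactly one vector load per depth tile:
print(loads_per_thread_per_depth_tile(64, 32, 8, 256))  # 1
# Halving the load width (as when depth == 64 forces load_width = 4)
# doubles the loads per thread:
print(loads_per_thread_per_depth_tile(64, 32, 4, 256))  # 2
```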
--- ## buffers
## `comptime` values ### `KBuffer` `comptime KBuffer[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = KVBufferImpl[KBufferConfig[BN, BK, WN], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]` #### Parameters * ​out\_type ([`DType`](/std/builtin/dtype/DType)): * ​in\_type ([`DType`](/std/builtin/dtype/DType)): * ​shape ([`IndexList`](/std/utils/index_/IndexList)): * ​group\_size ([`Int`](/std/builtin/int/Int)): * ​transpose\_b ([`Bool`](/std/builtin/bool/Bool)): * ​tensor\_core\_mma ([`TiledTensorCore`](/kernels/layout/tensor_core/TiledTensorCore)): * ​swizzle ([`Optional`](/std/collections/optional/Optional)): * ​BN ([`Int`](/std/builtin/int/Int)): * ​WN ([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): * ​depth ([`Int`](/std/builtin/int/Int)): * ​num\_threads ([`Int`](/std/builtin/int/Int)): * ​num\_stages ([`Int`](/std/builtin/int/Int)): * ​token\_gen ([`Bool`](/std/builtin/bool/Bool)): ### `VBuffer` `comptime VBuffer[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = KVBufferImpl[VBufferConfig[BN, BK, WN, depth], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]` #### Parameters * ​out\_type ([`DType`](/std/builtin/dtype/DType)): * ​in\_type ([`DType`](/std/builtin/dtype/DType)): * ​shape ([`IndexList`](/std/utils/index_/IndexList)): * ​group\_size ([`Int`](/std/builtin/int/Int)): * ​transpose\_b ([`Bool`](/std/builtin/bool/Bool)): * ​tensor\_core\_mma 
([`TiledTensorCore`](/kernels/layout/tensor_core/TiledTensorCore)): * ​swizzle ([`Optional`](/std/collections/optional/Optional)): * ​BN ([`Int`](/std/builtin/int/Int)): * ​WN ([`Int`](/std/builtin/int/Int)): * ​BK ([`Int`](/std/builtin/int/Int)): * ​depth ([`Int`](/std/builtin/int/Int)): * ​num\_threads ([`Int`](/std/builtin/int/Int)): * ​num\_stages ([`Int`](/std/builtin/int/Int)): * ​token\_gen ([`Bool`](/std/builtin/bool/Bool)): ## Structs * [​`KBufferConfig`](./KBufferConfig): * [​`KVBufferImpl`](./KVBufferImpl): * [​`OutputRegisterBuffer`](./OutputRegisterBuffer): * [​`PRegisterBuffer`](./PRegisterBuffer): * [​`QRegisterBuffer`](./QRegisterBuffer): * [​`VBufferConfig`](./VBufferConfig): * [​`VBufferTransposeLoads`](./VBufferTransposeLoads): ## Traits * [​`KVBuffer`](./KVBuffer): * [​`KVBufferConfig`](./KVBufferConfig): * [​`RegisterBuffer`](./RegisterBuffer): * [​`RegisterMMABuffer`](./RegisterMMABuffer):
--- ## amd (Amd)
AMD GPU attention operations. ## Modules * [​`attention`](./attention/): * [​`buffers`](./buffers/): * [​`mha_gfx942`](./mha_gfx942/): * [​`mha_gfx950`](./mha_gfx950/): * [​`mla`](./mla/): * [​`mma`](./mma/): * [​`softmax`](./softmax/): * [​`utils`](./utils/):
--- ## MHAAttentionConfig
`struct MHAAttentionConfig[dtype: DType, //, token_gen: Bool, config: MHAConfig[dtype], group: Int]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`AttentionConfig`](/mojo/kernels/nn/attention/gpu/amd/attention/AttentionConfig), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `depth_padded` `comptime depth_padded = not MHAAttentionConfig[token_gen, config, group].USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` ### `double_buffer` `comptime double_buffer = MHAAttentionConfig[token_gen, config, group].USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` ### `full_kv` `comptime full_kv = MHAAttentionConfig[token_gen, config, group].USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` ### `shared_kv` `comptime shared_kv = not MHAAttentionConfig[token_gen, config, group].USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` ### `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` `comptime USE_EXPERIMENTAL_CDNA4_MHA_KERNEL = _cdna_4_or_newer() and env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() and not token_gen` ## Methods ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** `UInt` ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** `UInt` ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** `UInt` ### `get_mma_shape` `static get_mma_shape() -> IndexList[3]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList) ### `get_q_offset` `static get_q_offset[q_depth: 
Scalar[DType.uint]]() -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `get_output_offset` `static get_output_offset[output_depth: Scalar[DType.uint]]() -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32)
--- ## mha_gfx942
## Structs * [​`MHAAttentionConfig`](./MHAAttentionConfig):
--- ## KVBuffer (Mha_gfx950)
`struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], k_group_size: Int, swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]` ## Fields * ​mma\_tile (`KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMATileType`): * ​smem\_iter (`KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].SharedIterType`): * ​kv\_cache\_iter (`KVCacheIterator[kv_t, BN, kv_num_heads, depth]`): * ​lds\_base\_ptrs (`InlineArray[UInt32, 2]`): * ​warp\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = kv_t.__del__is_trivial` ### `base_layout` `comptime base_layout = Layout.row_major(BN, BK)` ### `MMA_K` `comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MMATileType` `comptime MMATileType = LayoutTensor[kv_t.dtype, Layout.row_major(((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2) * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_tiles), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_k_mmas2` `comptime num_k_mmas2 = 
ceildiv(BK, (KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K * k_group_size))` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(depth if transpose else WN, BK)` ### `num_mmas` `comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N)` ### `num_repeats` `comptime num_repeats = (depth // BK)` ### `SharedIterType` `comptime SharedIterType = LayoutTensorIter[kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]` ### `SharedTileType` `comptime SharedTileType = KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].SharedIterType.LayoutTensorType` ### `SharedWarpTileType` `comptime SharedWarpTileType = LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace.SHARED), _get_index_type(AddressSpace.SHARED), False, align_of[kv_t.dtype](), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim0, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim0, KVBuffer[mma_shape, k_group_size, swizzle, BN, 
WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim1]()]` ### `simd_width` `comptime simd_width = simd_width_of[kv_t.dtype]()` ### `smem_layout` `comptime smem_layout = blocked_product(KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].base_layout, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].tiler_layout, False)` ### `tiler_layout` `comptime tiler_layout = Layout.row_major(1, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_repeats)` ### `wtile_dim0` `comptime wtile_dim0 = WN` ### `wtile_dim1` `comptime wtile_dim1 = BK` ## Methods ### `__init__` `__init__(out self, k_cache: kv_t, batch_idx: Scalar[DType.uint], head_idx: Scalar[DType.uint], shared_ptr: UnsafePointer[Scalar[kv_t.dtype], origin, address_space=AddressSpace.SHARED], end: Scalar[DType.uint], warp_id: UInt32)` ### `load_from_dram` `load_from_dram[buffer_idx: Int](mut self)` ### `get_mma_tile` `get_mma_tile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> LayoutTensor` Returns the register-resident sub-tile of `MMATileType` (shape `(num_mmas, simd_width)`) selected by `k_mma_tile_idx` and `bk_tile_idx`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_to_shared` `copy_to_shared(self)` ### `load_from_shared` `load_from_shared(self, buffer: Scalar[DType.uint])` `load_from_shared[bk_tile: Int](self, buffer: Scalar[DType.uint])`
--- ## KVCacheIterator
`struct KVCacheIterator[cache_t: MHAOperand, tile_size: Int, kv_num_heads: Int, depth: Int]` ## Fields * ​cache (`cache_t`): * ​end (`Int`): * ​tile\_start\_row (`Int`): * ​batch\_idx (`Int`): * ​kv\_head\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = cache_t.__del__is_trivial` ### `kv_gmem_layout` `comptime kv_gmem_layout = Layout(IntTuple(tile_size, depth), IntTuple((kv_num_heads * depth), 1))` ## Methods ### `__init__` `__init__(out self, cache: cache_t, batch_idx: Int, kv_head_idx: Int, end: Int)` ### `next_unsafe` `next_unsafe(mut self) -> LayoutTensor[cache_t.dtype, KVCacheIterator[cache_t, tile_size, kv_num_heads, depth].kv_gmem_layout, MutAnyOrigin, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `increment` `increment(mut self)`
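The iterator's stepping logic can be modeled in plain Python. This is an illustrative sketch only: the real `next_unsafe` returns a masked `LayoutTensor` view into the KV cache, which is reduced here to a range of row indices.

```python
class KVCacheIteratorModel:
    """Python model of KVCacheIterator's bookkeeping (names mirror the struct)."""

    def __init__(self, batch_idx, kv_head_idx, end, tile_size):
        self.batch_idx = batch_idx
        self.kv_head_idx = kv_head_idx
        self.end = end                  # number of valid KV rows
        self.tile_size = tile_size      # BN rows consumed per step
        self.tile_start_row = 0

    def next_unsafe(self):
        # Current tile's row range; rows past `end` are masked out
        # (the real tensor is created with masked=True for this reason).
        return range(self.tile_start_row,
                     min(self.tile_start_row + self.tile_size, self.end))

    def increment(self):
        self.tile_start_row += self.tile_size

# Walk a 70-row cache in 32-row tiles: two full tiles plus a 6-row remainder.
it = KVCacheIteratorModel(batch_idx=0, kv_head_idx=0, end=70, tile_size=32)
tiles = []
while it.tile_start_row < it.end:
    tiles.append(list(it.next_unsafe()))
    it.increment()
print([len(t) for t in tiles])  # [32, 32, 6]
```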
--- ## barrier
`barrier[*, schedule_barrier_before: Bool = True, schedule_barrier_after: Bool = True]()`
--- ## block_sync_lds
`block_sync_lds[*, lgkmcnt: UInt32 = 0]()` Synchronize LDS (local data share) with waitcnt barrier.
--- ## block_sync_lds_direct_load
`block_sync_lds_direct_load[*, vmcnt: UInt32 = 0]()` Synchronize LDS for direct load with waitcnt barrier.
--- ## mha_gfx950
## Structs * [​`KVBuffer`](./KVBuffer): * [​`KVCacheIterator`](./KVCacheIterator): ## Functions * [​`barrier`](./barrier): * [​`block_sync_lds`](./block_sync_lds): Synchronize LDS (local data share) with waitcnt barrier. * [​`block_sync_lds_direct_load`](./block_sync_lds_direct_load): Synchronize LDS for direct load with waitcnt barrier. * [​`scheduling_hints_pv`](./scheduling_hints_pv): * [​`scheduling_hints_qk`](./scheduling_hints_qk): * [​`set_priority`](./set_priority):
--- ## scheduling_hints_pv
`scheduling_hints_pv[group: Int]()`
--- ## scheduling_hints_qk
`scheduling_hints_qk[group: Int]()`
--- ## set_priority
`set_priority[priority: Int]()`
--- ## MLAAttentionConfig
`struct MLAAttentionConfig[dtype: DType, //, token_gen: Bool, config: MHAConfig[dtype]]` ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`AttentionConfig`](/mojo/kernels/nn/attention/gpu/amd/attention/AttentionConfig), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `depth_padded` `comptime depth_padded = True` ### `double_buffer` `comptime double_buffer = False` ### `full_kv` `comptime full_kv = False` ### `shared_kv` `comptime shared_kv = True` ## Methods ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** `UInt` ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** `UInt` ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** `UInt` ### `get_mma_shape` `static get_mma_shape() -> IndexList[3]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList) ### `get_q_offset` `static get_q_offset[q_depth: Scalar[DType.uint]]() -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32) ### `get_output_offset` `static get_output_offset[output_depth: Scalar[DType.uint]]() -> UInt32` **Returns:** [`UInt32`](/mojo/std/builtin/simd/#uint32)
--- ## mla
## Structs * [​`MLAAttentionConfig`](./MLAAttentionConfig):
--- ## mma (Mma)
## Functions * [​`mma`](./mma):
--- ## mma (3)
`mma[c_register_buffer_type: RegisterBuffer, a_register_buffer_type: RegisterMMABuffer, b_buffer_type: KVBuffer, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], BK: Int, prefetch_function: OptionalReg[fn() capturing -> None], swap_a_b: Bool = False, beg_iter: Int = 0, num_iters: Int = 1, prefetched_b_tile: Bool = False](c: c_register_buffer_type, mut a_tile: a_register_buffer_type, mut b_tile: b_buffer_type)`
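The `prefetch_function` and `prefetched_b_tile` parameters implement software pipelining: the next tile's loads are issued before the current tile's MMAs so memory latency overlaps with compute. A minimal Python model of that interleaving (function names here are illustrative, not the kernel's API):

```python
def pipelined_mma(num_iters, compute, prefetch=None):
    """Model of a software-pipelined MMA loop: issue the next tile's
    loads before running the current tile's matrix multiplies."""
    log = []
    for i in range(num_iters):
        if prefetch is not None and i + 1 < num_iters:
            log.append(prefetch(i + 1))   # kick off next tile's loads first
        log.append(compute(i))            # then run this tile's MMAs
    return log

log = pipelined_mma(3,
                    compute=lambda i: f"mma[{i}]",
                    prefetch=lambda i: f"load[{i}]")
print(log)  # ['load[1]', 'mma[0]', 'load[2]', 'mma[1]', 'mma[2]']
```

The last iteration issues no prefetch, mirroring how a kernel avoids loading past the final tile.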
--- ## Softmax
`struct Softmax[dtype: DType, score_layout_by_mma_unit: Layout, block_layout_by_warp: Layout, warp_layout: Layout, fragment_layout: Layout, use_exp2: Bool = False]` ## Fields * ​rowmax\_tensor (`Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowMaxTensorType`): * ​rowsum\_tensor (`Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowSumTensorType`): * ​score\_frag\_rowmax (`LayoutTensor[dtype, Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows), MutAnyOrigin, address_space=AddressSpace.LOCAL]`): * ​score\_frag\_rowsum (`LayoutTensor[dtype, Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows), MutAnyOrigin, address_space=AddressSpace.LOCAL]`): * ​correction (`LayoutTensor[dtype, Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows), MutAnyOrigin, address_space=AddressSpace.LOCAL]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `exp_function` `comptime exp_function = _exp2_concrete if use_exp2 else _exp_concrete` ### `frag_is_row_vector` `comptime frag_is_row_vector = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, 
use_exp2].frag_num_rows == 1)` ### `frag_num_cols` `comptime frag_num_cols = fragment_layout.shape[1].value()` ### `frag_num_rows` `comptime frag_num_rows = fragment_layout.shape[0].value()` ### `num_colwise_lanes` `comptime num_colwise_lanes = SIMD[DType.uint32, 1](warp_layout.shape[0].value())` ### `num_colwise_tiles` `comptime num_colwise_tiles = score_layout_by_mma_unit.shape[0].value()` ### `num_colwise_warps` `comptime num_colwise_warps = block_layout_by_warp.shape[0].value()` ### `num_m_mmas` `comptime num_m_mmas = score_layout_by_mma_unit.shape[0].value()` ### `num_rows_per_thread` `comptime num_rows_per_thread = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles * Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows)` ### `num_rowwise_lanes` `comptime num_rowwise_lanes = SIMD[DType.uint32, 1](warp_layout.shape[1].value())` ### `num_rowwise_tiles` `comptime num_rowwise_tiles = score_layout_by_mma_unit.shape[1].value()` ### `num_rowwise_warps` `comptime num_rowwise_warps = block_layout_by_warp.shape[1].value()` ### `num_shuffles_per_row` `comptime num_shuffles_per_row = log2_floor(warp_layout.shape[1].value())` ### `row_layout` `comptime row_layout = Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_m_mmas, fragment_layout.shape[0].value())` ### `RowMaxTensorType` `comptime RowMaxTensorType = LayoutTensor[dtype, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].row_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `RowSumTensorType` `comptime RowSumTensorType = Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowMaxTensorType` ### `rowwise_lanes_stride` `comptime rowwise_lanes_stride = SIMD[DType.uint32, 
1](warp_layout.stride[1].value())` ## Methods ### `__init__` `__init__(out self)` ### `calculate_qk_max` `calculate_qk_max(self, score_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_scratch: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `calculate_qk_sum` `calculate_qk_sum(self, score_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_scratch: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `exp` `exp[start: Int = 0, stride: Int = 1](self, score_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `calculate_correction` `calculate_correction(self)` ### `update_output` `update_output(self, output_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `update_sum` `update_sum(self)` ### `update_max` `update_max(self)` ### `full` `full(self, output_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], score_reg_tile: LayoutTensor[dtype, layout, origin, 
address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_scratch: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## softmax (Softmax)
## Structs * [​`Softmax`](./Softmax):
--- ## GlobalMemoryManager
`struct GlobalMemoryManager[dtype: DType, BM: UInt32, BN: UInt32, BK: UInt32, depth: UInt32, num_heads: UInt32, group: UInt32, token_gen: Bool, q_depth: UInt32 = depth, output_depth: UInt32 = depth]` ## Fields * ​q\_offset (`UInt32`): * ​q\_runtime\_layout (`RuntimeLayout[GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].q_gmem_layout, element_type=DType.int32, linear_idx_type=DType.int32]`): * ​output\_offset (`UInt32`): * ​output\_runtime\_layout (`RuntimeLayout[GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].output_gmem_layout, element_type=DType.int32, linear_idx_type=DType.int32]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `kv_gmem_layout` `comptime kv_gmem_layout = Layout(IntTuple(Int.__init__[UInt32](BN), Int.__init__[UInt32](depth)), IntTuple(Int.__init__[UInt32]((GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].kv_num_heads * depth)), 1))` ### `kv_num_heads` `comptime kv_num_heads = (num_heads // group)` ### `output_gmem_layout` `comptime output_gmem_layout = Layout(IntTuple(Int.__init__[UInt32](BM), Int.__init__[UInt32](output_depth)), IntTuple(Int.__init__[UInt32]((num_heads * output_depth)), 1)) if token_gen.__invert__()._mlir_value else Layout.row_major(Int.__init__[UInt32](BM), Int.__init__[UInt32](output_depth))` ### `q_gmem_layout` `comptime q_gmem_layout = Layout(IntTuple(Int.__init__[UInt32](BM), Int.__init__[UInt32](q_depth)), IntTuple(Int.__init__[UInt32]((num_heads * q_depth)), 1)) if token_gen.__invert__()._mlir_value else Layout.row_major(Int.__init__[UInt32](BM), Int.__init__[UInt32](q_depth))` ## Methods ### `__init__` `__init__(out self, q_tile_idx: UInt32, kv_head_idx: UInt32, seq_len: Int, q_offset: 
UInt32, output_offset: UInt32)` ### `get_q_tensor` `get_q_tensor[qtype: DType](self, ptr: UnsafePointer[Scalar[qtype], MutAnyOrigin]) -> LayoutTensor[qtype, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].q_gmem_layout, MutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `get_output_tensor` `get_output_tensor[out_type: DType](self, ptr: UnsafePointer[Scalar[out_type], MutAnyOrigin]) -> LayoutTensor[out_type, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].output_gmem_layout, MutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `get_kv_tensor` `get_kv_tensor[kvtype: DType, //](self, ptr: UnsafePointer[Scalar[kvtype], MutAnyOrigin], kv_tile_num_rows: UInt32) -> LayoutTensor[kvtype, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].kv_gmem_layout, MutAnyOrigin, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## SharedMemoryManager (Utils)
`struct SharedMemoryManager[shared_kv: Bool, full_kv: Bool, depth_padded: Bool, double_buffer: Bool, dtype: DType, BM: Int, BN: Int, BK: Int, depth: Int, token_gen: Bool]` ## Fields * ​p\_smem (`UnsafePointer[Scalar[dtype], MutExternalOrigin, address_space=AddressSpace.SHARED]`): * ​k\_smem (`UnsafePointer[Scalar[dtype], MutExternalOrigin, address_space=AddressSpace.SHARED]`): * ​v\_smem (`UnsafePointer[Scalar[dtype], MutExternalOrigin, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[dtype]()` ### `alignment` `comptime alignment = align_of[SIMD[dtype, simd_width_of[dtype]()]]()` ### `k_smem_size` `comptime k_smem_size = ((BN * depth if full_kv else BK) * 2 if double_buffer else 1)` ### `p_smem_size` `comptime p_smem_size = (BM * BN) if token_gen else 0` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `v_smem_size` `comptime v_smem_size = ((BN if full_kv else BK * pad[dtype, depth, depth]() if depth_padded else depth) * 2 if double_buffer else 1)` ## Methods ### `__init__` `__init__(out self)` ### `get_k_ptr` `get_k_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]` **Returns:** [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) ### `get_v_ptr` `get_v_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]` **Returns:** [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) ### `get_p_ptr` `get_p_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]` **Returns:** [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer) ### `get_warp_scratch_ptr` `get_warp_scratch_ptr[_dtype: DType](self) 
-> UnsafePointer[Scalar[_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]` **Returns:** [`UnsafePointer`](/mojo/std/memory/unsafe_pointer/UnsafePointer)
--- ## copy_dram_to_sram_lds
`copy_dram_to_sram_lds[swizzle: Optional[Swizzle] = Optional[Swizzle]()](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lds_base_ptr: UInt32 = 0)`
--- ## copy_local_to_dram2
`copy_local_to_dram2[dst_thread_layout: Layout, thread_scope: ThreadScope = ThreadScope.BLOCK](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst_base: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## get_fragment_layout
`get_fragment_layout[mma_shape: IndexList[3]]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## get_nested_fragment_layout
`get_nested_fragment_layout[mma_shape: IndexList[3]]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## get_warp_coords
`get_warp_coords[BN: Int, WN: Int]() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList)
--- ## get_warp_layout
`get_warp_layout[mma_shape: IndexList[3]]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## utils (Utils)
## `comptime` values ### `LocalLayoutTensor` `comptime LocalLayoutTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ### `SharedLayoutTensor` `comptime SharedLayoutTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED]` #### Parameters * ​dtype ([`DType`](/std/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Structs * [​`GlobalMemoryManager`](./GlobalMemoryManager): * [​`SharedMemoryManager`](./SharedMemoryManager): ## Functions * [​`copy_dram_to_sram_lds`](./copy_dram_to_sram_lds): * [​`copy_local_to_dram2`](./copy_local_to_dram2): * [​`get_fragment_layout`](./get_fragment_layout): * [​`get_nested_fragment_layout`](./get_nested_fragment_layout): * [​`get_warp_coords`](./get_warp_coords): * [​`get_warp_layout`](./get_warp_layout): * [​`load_b`](./load_b): * [​`load_b_`](./load_b_): * [​`load_b_tr`](./load_b_tr): Loads the b operand tile for AMD tensor core MFMA instructions using transposed memory access. * [​`pad`](./pad):
--- ## load_b
`load_b[mma_shape: IndexList[3], swizzle: Optional[Swizzle]](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, Layout.row_major((layout.size() // (WARP_SIZE * 8)), 8), MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## load_b_
`load_b_[mma_shape: IndexList[3], swizzle: Optional[Swizzle], k_tile_idx: Int](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, simd_width_of[dtype]()]` **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD)
--- ## load_b_tr (Utils)
`load_b_tr[mma_shape: IndexList[3]](tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, 8]` Loads the b operand tile for AMD tensor core MFMA instructions using transposed memory access. This function supports double-rate MFMA shapes (32x32x16, 16x16x32) with bfloat16 input. The input tile (shape = (mma\_shape\[2], mma\_shape\[1])) is split along the K dimension into two halves of shape (MMA\_K//2, MMA\_N). Each half is loaded using `_load_tr16_b64_warp`, which performs a transposed (column-major) load from shared memory. The resulting two 4-element SIMD vectors are concatenated into a single `SIMD[tile.dtype, 8]` vector. **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/std/utils/index_/IndexList)): The MMA instruction tile shape (only 32x32x16 or 16x16x32 supported). **Args:** * ​tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A `LayoutTensor`, residing in shared memory, with shape (mma\_shape\[2], mma\_shape\[1]) and dtype `DType.bfloat16`. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): SIMD\[tile.dtype, 8]: Concatenated transposed SIMD loads from both halves of the tile.
--- ## pad
`pad[dtype: DType, depth: Int, size: Int]() -> Int` **Returns:** [`Int`](/mojo/std/builtin/int/Int)
--- ## gpu (3)
GPU attention operations. ## Packages * [​`amd`](./amd/): AMD GPU attention operations.
--- ## attention (5)
Attention operations. ## Packages * [​`gpu`](./gpu/): GPU attention operations.
--- ## cpu_bicubic_kernel
`cpu_bicubic_kernel(output_host: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_host: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` Perform bicubic interpolation on a TileTensor of form NCHW. **Args:** * ​output\_host ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Output tensor with desired dimensions. * ​input\_host ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Input tensor of shape \[B, C, H, W].
--- ## cubic_kernel
`cubic_kernel(x: Float32) -> Float32` Cubic interpolation kernel matching PyTorch/torchvision's BICUBIC filter. This uses the Catmull-Rom variant (Robidoux cubic) with a = -0.75, which is what PyTorch uses in get\_cubic\_upsample\_coefficients. ([Source](https://github.com/pytorch/pytorch/blob/59eb61b2d1e4b64debbefa036acd0d8c7d55f0a3/aten/src/ATen/native/UpSample.h#L410-L423)). This also matches OpenCV's [interpolateCubic](https://github.com/opencv/opencv/blob/cf2a3c8e7430cc92569dd7f114609f9377b12d9e/modules/imgproc/src/resize.cpp#L907-L915). **Args:** * ​x ([`Float32`](/mojo/std/builtin/simd/#float32)): Distance from the center point. **Returns:** [`Float32`](/mojo/std/builtin/simd/#float32): Weight contribution based on the distance. `cubic_kernel(x: SIMD[dtype, size]) -> SIMD[dtype, size]` Cubic interpolation kernel matching PyTorch/torchvision's BICUBIC filter. This uses the Catmull-Rom variant (Robidoux cubic) with a = -0.75, which is what PyTorch uses in get\_cubic\_upsample\_coefficients. ([Source](https://github.com/pytorch/pytorch/blob/59eb61b2d1e4b64debbefa036acd0d8c7d55f0a3/aten/src/ATen/native/UpSample.h#L410-L423)). This also matches OpenCV's [interpolateCubic](https://github.com/opencv/opencv/blob/cf2a3c8e7430cc92569dd7f114609f9377b12d9e/modules/imgproc/src/resize.cpp#L907-L915). **Args:** * ​x ([`SIMD`](/mojo/std/builtin/simd/SIMD)): Distance from the center point. **Returns:** [`SIMD`](/mojo/std/builtin/simd/SIMD): Weight contribution based on the distance.
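The piecewise weight function described above can be sketched in Python. This is an illustrative re-derivation of the standard cubic convolution kernel with a = -0.75, not the Mojo implementation; the helper name is hypothetical:

```python
def cubic_kernel(x: float, a: float = -0.75) -> float:
    """Cubic convolution weight for a sample at distance x from the center.
    a = -0.75 matches the PyTorch/torchvision and OpenCV BICUBIC filters."""
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

# For any fractional offset t in [0, 1), the four taps used by bicubic
# resampling sit at distances 1+t, t, 1-t, 2-t; their weights sum to 1,
# so the filter reproduces constant signals exactly.
t = 0.3
weights = [cubic_kernel(t + 1.0), cubic_kernel(t),
           cubic_kernel(1.0 - t), cubic_kernel(2.0 - t)]
```
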
--- ## gpu_bicubic_kernel
`gpu_bicubic_kernel[dtype: DType, OutputLayoutType: TensorLayout, output_origin: MutOrigin, InputLayoutType: TensorLayout, input_origin: ImmutOrigin](output: TileTensor[dtype, OutputLayoutType, output_origin], input: TileTensor[dtype, InputLayoutType, input_origin])` Perform bicubic interpolation using GPU. **Args:** * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Output tensor with desired dimensions on the device. * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Input tensor of shape \[B, C, H, W] on the device.
--- ## bicubic
This module provides CPU and GPU implementations for bicubic interpolation. Bicubic interpolation is a 2D extension of cubic interpolation for resampling digital images. It uses the weighted average of the 4x4 neighborhood of pixels around the target location to compute the interpolated value. ## Functions * [​`cpu_bicubic_kernel`](./cpu_bicubic_kernel): Perform bicubic interpolation on a TileTensor of form NCHW. * [​`cubic_kernel`](./cubic_kernel): Cubic interpolation kernel matching PyTorch/torchvision's BICUBIC filter. * [​`gpu_bicubic_kernel`](./gpu_bicubic_kernel): Perform bicubic interpolation using GPU. * [​`map_output_to_input_coord`](./map_output_to_input_coord): Map output pixel coordinate to input coordinate using center alignment. * [​`resize_bicubic`](./resize_bicubic): Perform bicubic interpolation.
--- ## map_output_to_input_coord
`map_output_to_input_coord(output_coord: Int, scale: Float32) -> Float32` Map output pixel coordinate to input coordinate using center alignment. This implements the standard coordinate mapping for image resizing: input\_coord = (output\_coord + 0.5) \* scale - 0.5. The +0.5 and -0.5 terms ensure pixel centers are aligned properly. **Args:** * ​output\_coord ([`Int`](/mojo/std/builtin/int/Int)): Output pixel coordinate. * ​scale ([`Float32`](/mojo/std/builtin/simd/#float32)): Scale factor (input\_size / output\_size). **Returns:** [`Float32`](/mojo/std/builtin/simd/#float32): Corresponding input coordinate as a float.
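The mapping is small enough to sketch directly. A minimal Python mirror of the formula above (illustrative only):

```python
def map_output_to_input_coord(output_coord: int, scale: float) -> float:
    # Center alignment: output pixel i covers [i, i+1), so its center
    # is i + 0.5; map that center into input space, then shift back.
    return (output_coord + 0.5) * scale - 0.5

# Downscaling 4 -> 2 pixels (scale = input_size / output_size = 2.0):
# each output center lands halfway between two input pixel centers.
coords = [map_output_to_input_coord(i, 4 / 2) for i in range(2)]
```
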
--- ## resize_bicubic
`resize_bicubic[dtype: DType, //, target: StringSlice[StaticConstantOrigin]](output: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContextPtr)` Perform bicubic interpolation. **Args:** * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Output tensor with desired dimensions on host or device. * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): Input tensor of shape \[B, C, H, W] on host or device. * ​ctx ([`DeviceContextPtr`](/mojo/std/runtime/asyncrt/DeviceContextPtr)): Device context to enqueue GPU kernels on.
--- ## broadcast (3)
`broadcast[dtype: DType](output: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types])` For each axis of `input`, if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`. **Args:** * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The output buffer. * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The input buffer.
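The size-1 duplication rule can be sketched for the rank-2 case. This is a hypothetical Python helper using nested lists, not the TileTensor API:

```python
def broadcast_2d(input_rows, out_shape):
    """Fill an output of shape out_shape from input_rows: any axis of
    size 1 in the input is duplicated; matching axes are copied over."""
    in_h, in_w = len(input_rows), len(input_rows[0])
    out_h, out_w = out_shape
    # Each input dimension must be 1 (broadcast) or equal to the output's.
    assert in_h in (1, out_h) and in_w in (1, out_w)
    out = []
    for i in range(out_h):
        src_row = input_rows[0 if in_h == 1 else i]   # duplicate rows if in_h == 1
        out.append([src_row[0 if in_w == 1 else j]    # duplicate cols if in_w == 1
                    for j in range(out_w)])
    return out
```
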
--- ## broadcast_impl
`broadcast_impl[dtype: DType](axis: Int, output: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_prev_axis_stride: Int, output_prev_axis_stride: Int, input_offset: Int, output_offset: Int, rightmost_broadcast_axis: Int)` For each axis of `input` ∈ \[axis, rank), if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`. **Args:** * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The axis value. * ​output ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The output buffer. * ​input ([`TileTensor`](/mojo/kernels/layout/_tile_tensor/TileTensor)): The input buffer. * ​input\_prev\_axis\_stride ([`Int`](/mojo/std/builtin/int/Int)): The stride at axis `axis - 1` for input. * ​output\_prev\_axis\_stride ([`Int`](/mojo/std/builtin/int/Int)): The stride at axis `axis - 1` for output. * ​input\_offset ([`Int`](/mojo/std/builtin/int/Int)): The offset at which we start copying data from. * ​output\_offset ([`Int`](/mojo/std/builtin/int/Int)): The offset at which we start copying data to. * ​rightmost\_broadcast\_axis ([`Int`](/mojo/std/builtin/int/Int)): The largest axis at which we need to duplicate `input` data.
--- ## broadcast (4)
## Functions * [​`broadcast`](./broadcast): For each axis of `input`, if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`. * [​`broadcast_impl`](./broadcast_impl): For each axis of `input` ∈ \[axis, rank), if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`.
--- ## concat (Concat)
`concat[input_origin: ImmutOrigin, InputLayoutType: TensorLayout, //, dtype: DType, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu", epilogue_fn: Optional[elementwise_epilogue_type] = None](output: TileTensor[dtype, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], axis: Int, inputs: StaticTuple[TileTensor[dtype, InputLayoutType, input_origin], size], context: DeviceContextPtr = DeviceContextPtr())`
--- ## concat_shape
`concat_shape[input_origin: ImmutOrigin, InputLayoutType: TensorLayout, //, input_type: DType, single_thread_blocking_override: Bool](input_bufs: List[TileTensor[input_type, InputLayoutType, input_origin]], axis: Int) -> IndexList[InputLayoutType.rank]` Compute the output shape of a `concat` operation, and assert the inputs are compatible. **Parameters:** * ​input\_origin ([`ImmutOrigin`](/mojo/std/builtin/type_aliases/#immutorigin)): Origin of the input tensor. * ​InputLayoutType ([`TensorLayout`](/mojo/kernels/layout/_layout/TensorLayout)): Layout type of the input tensor. * ​input\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the input tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_bufs ([`List`](/mojo/std/collections/list/List)): The input tensors list. * ​axis ([`Int`](/mojo/std/builtin/int/Int)): The concatenation axis. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The output shape.
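The shape rule asserted here can be sketched in Python, using plain tuples rather than `IndexList` (illustrative only):

```python
def concat_shape(shapes, axis):
    """Output shape of a concat: every dimension must match across
    inputs except `axis`, whose sizes are summed."""
    out = list(shapes[0])
    for s in shapes[1:]:
        assert len(s) == len(out), "ranks must match"
        for d in range(len(out)):
            if d != axis:
                assert s[d] == out[d], "non-concat dims must match"
        out[axis] += s[axis]
    return tuple(out)
```
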
--- ## fused_concat
`fused_concat[dtype: DType, rank: Int, single_thread_blocking_override: Bool, input_fn: fn[input_index: Int, width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, width], output_0_fn: elementwise_epilogue_type, target: StringSlice[StaticConstantOrigin] = "cpu"](axis: Int, input_shapes: StaticTuple[IndexList[rank], size], output: TileTensor[dtype, LayoutType, origin], ctx: DeviceContextPtr)`
--- ## concat (3)
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[c_type: DType, rank: Int, width: Int = 1, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None` ## Functions * [​`concat`](./concat): * [​`concat_shape`](./concat_shape): Compute the output shape of a `concat` operation, and assert the inputs are compatible. * [​`fused_concat`](./fused_concat): * [​`memcpy_or_fuse`](./memcpy_or_fuse):
--- ## memcpy_or_fuse
`memcpy_or_fuse[rank: Int, dtype: DType, epilogue_fn: Optional[elementwise_epilogue_type]](dest_data: UnsafePointer[Int8, origin], out_byte_offset: Int, src_data: UnsafePointer[Int8, origin], n: Int, out_shape: IndexList[rank, element_type=element_type])`
--- ## ConvDirectNHWC
`struct ConvDirectNHWC[input_mut: Bool, filter_mut: Bool, conv_attr_rank: Int, //, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_origin: Origin[mut=input_mut], filter_origin: Origin[mut=filter_mut], output_origin: MutOrigin, input_type: DType, filter_type: DType, output_type: DType, filter_packed: Bool, conv_attr: ConvInfoStatic[conv_attr_rank], elementwise_epilogue: Optional[elementwise_epilogue_type] = None]` Implements the outer loops for direct convolution. N, HO, and WO are collapsed into one dimension, n\_ho\_wo. n\_ho\_wo, C, and F are tiled; the tile factors for C and F are chosen by a heuristic that prioritizes C, and n\_ho\_wo is tiled by the micro kernel's height. If n\_ho\_wo is large enough to spill the LLC, it may need to be tiled as the outermost loop with a factor that fits in the LLC. F is assumed to be at least divisible by simd\_size. ## Fields * ​output (`LayoutTensor[output_type, output_layout, output_origin]`): * ​input (`LayoutTensor[input_type, input_layout, input_origin]`): * ​filter (`LayoutTensor[filter_type, filter_layout, filter_origin]`): * ​conv\_shape (`ConvShape[conv_attr_rank]`): * ​partition (`ConvPartition`): * ​cf\_tile\_size (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `packed_and_fully_static` `comptime packed_and_fully_static = filter_packed if filter_layout.shape.all_known() if output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else 
conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else filter_layout.shape.all_known() if output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]()` ## Methods ### `run` `static run(output: LayoutTensor[output_type, output_layout, output_origin], input: LayoutTensor[input_type, input_layout, input_origin], filter: LayoutTensor[filter_type, filter_layout, filter_origin], conv_shape: ConvShape[conv_attr_rank])` ### `is_new_c_accum` `is_new_c_accum(self, c_idx: Int) -> Bool` **Returns:** [`Bool`](/mojo/std/builtin/bool/Bool) ### `update_output_tile_no_padding` `update_output_tile_no_padding[micro_kernel_height: Int, micro_kernel_width: Int, c_fully_cached: Bool, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int, output_flat_coord: Int)` ### `output_space_flat_loop` 
`output_space_flat_loop[micro_kernel_f_size: Int, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int)` ### `output_space_loop` `output_space_loop[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int)` ### `output_space_loop_1d` `output_space_loop_1d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)` ### `output_space_loop_2d` `output_space_loop_2d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)` ### `output_space_loop_3d` `output_space_loop_3d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)`
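The loop structure described in the struct summary (collapsed n\_ho\_wo output space, tiled C and F, partial accumulation across C tiles) can be sketched in Python. This is a scalar illustration of the loop order only, assuming stride 1, no padding, no groups, NHWC input, and RSCF filter layout; it does not model the micro-kernel register tiling:

```python
def conv_direct_nhwc(inp, filt, N, H, W, C, R, S, F, c_tile=2, f_tile=2):
    """Direct convolution, NHWC input and RSCF filter as flat lists.
    C is the outermost tiled loop, so each C tile contributes a partial
    accumulation into the output; F is tiled next; the output space
    N x HO x WO is walked as one collapsed index n_ho_wo."""
    HO, WO = H - R + 1, W - S + 1
    out = [0.0] * (N * HO * WO * F)
    for c0 in range(0, C, c_tile):                # C tile (partial accumulation)
        for f0 in range(0, F, f_tile):            # F tile
            for p in range(N * HO * WO):          # collapsed n_ho_wo
                n, rem = divmod(p, HO * WO)
                ho, wo = divmod(rem, WO)
                for f in range(f0, min(f0 + f_tile, F)):
                    acc = 0.0
                    for r in range(R):
                        for s in range(S):
                            for c in range(c0, min(c0 + c_tile, C)):
                                acc += (inp[((n * H + ho + r) * W + wo + s) * C + c]
                                        * filt[((r * S + s) * C + c) * F + f])
                    out[p * F + f] += acc         # add this C tile's contribution
    return out
```
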
--- ## CuDNNConvMeta
`@register_passable` `struct CuDNNConvMeta` ## Fields * ​ptr\_handle (`LegacyUnsafePointer[cudnnContext]`): * ​ptr\_input\_desc (`LegacyUnsafePointer[cudnnTensorStruct]`): * ​ptr\_filter\_desc (`LegacyUnsafePointer[cudnnFilterStruct]`): * ​ptr\_conv\_desc (`LegacyUnsafePointer[cudnnConvolutionStruct]`): * ​ptr\_output\_desc (`LegacyUnsafePointer[cudnnTensorStruct]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable), [`RegisterPassable`](/mojo/std/builtin/value/RegisterPassable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` ### `__del__` `__del__(deinit self)`
--- ## Naive2dConvolution
`struct Naive2dConvolution[output_type: DType, input_type: DType, filter_type: DType]` Struct wrapper for naive 2d convolution implementation. ## Fields * ​output (`LegacyUnsafePointer[Scalar[output_type]]`): * ​input (`LegacyUnsafePointer[Scalar[input_type]]`): * ​filter (`LegacyUnsafePointer[Scalar[filter_type]]`): * ​pad\_d (`IndexList[2]`): * ​pad\_h (`IndexList[2]`): * ​pad\_w (`IndexList[2]`): * ​stride (`IndexList[3]`): * ​dilation (`IndexList[3]`): * ​num\_groups (`Int`): * ​output\_shape (`IndexList[5]`): * ​input\_shape (`IndexList[5]`): * ​filter\_shape (`IndexList[5]`): ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`Copyable`](/mojo/std/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/std/builtin/value/ImplicitlyCopyable), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible), [`Movable`](/mojo/std/builtin/value/Movable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, output: LegacyUnsafePointer[Scalar[output_type]], input: LegacyUnsafePointer[Scalar[input_type]], filter: LegacyUnsafePointer[Scalar[filter_type]], output_shape: IndexList[5], input_shape: IndexList[5], filter_shape: IndexList[5], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], stride: IndexList[3], dilation: IndexList[3], num_groups: Int)` ### `run` `static run(output: LegacyUnsafePointer[Scalar[output_type]], input: LegacyUnsafePointer[Scalar[input_type]], filter: LegacyUnsafePointer[Scalar[filter_type]], output_shape: IndexList[5], input_shape: IndexList[5], filter_shape: IndexList[5], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], stride: IndexList[3], dilation: IndexList[3], num_groups: Int)`
--- ## accumulate_wo_tile_1d
`accumulate_wo_tile_1d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load_filter: Bool, effected_by_padding: Bool, input_dt: DType, filter_dt: DType](c_tile_size: Int, S: Int, mut acc: _Accumulator[dtype, num_rows, num_cols, simd_width, row_start, row_stop], input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, input_stride_to_nbr: Int, filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, filter_stride_to_nbr: Int, partial_load_filter_size: Int, w: Int, W: Int, dilation: Int)` Update one row in the output for a given (c, f) tile. **Parameters:** * ​micro\_kernel\_height ([`Int`](/mojo/std/builtin/int/Int)): Number of input points in register tiling. * ​micro\_kernel\_width ([`Int`](/mojo/std/builtin/int/Int)): Number of SIMD registers assigned to F. * ​simd\_size ([`Int`](/mojo/std/builtin/int/Int)): Number of elements in a SIMD register. * ​partial\_load\_filter ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether using partial load for filter. * ​effected\_by\_padding ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether the tile is affected by padding. * ​input\_dt ([`DType`](/mojo/std/builtin/dtype/DType)): DType of input. * ​filter\_dt ([`DType`](/mojo/std/builtin/dtype/DType)): DType of filter. **Args:** * ​c\_tile\_size ([`Int`](/mojo/std/builtin/int/Int)): Tile size in input channel. * ​S ([`Int`](/mojo/std/builtin/int/Int)): Filter window width. * ​acc ([`_Accumulator`](/mojo/kernels/linalg/accumulate/_Accumulator)): Pointer to register tile accumulator. * ​input ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to the first input point in WO tile. * ​input\_stride ([`Int`](/mojo/std/builtin/int/Int)): Stride between two input points, i.e., C w/ NHWC layout. * ​input\_stride\_to\_nbr ([`Int`](/mojo/std/builtin/int/Int)): Stride between an input point and its neighbor. 
* ​filter ([`LegacyUnsafePointer`](/mojo/std/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to the first coef in the filter window. * ​filter\_stride ([`Int`](/mojo/std/builtin/int/Int)): Stride between two segments of size `micro_kernel_width * simd_size`. * ​filter\_stride\_to\_nbr ([`Int`](/mojo/std/builtin/int/Int)): Stride between two neighboring coefs, i.e., CF w/ RSCF layout. * ​partial\_load\_filter\_size ([`Int`](/mojo/std/builtin/int/Int)): Size of the partial load for the filter. * ​w ([`Int`](/mojo/std/builtin/int/Int)): Coordinate in an input row. * ​W ([`Int`](/mojo/std/builtin/int/Int)): Input width. * ​dilation ([`Int`](/mojo/std/builtin/int/Int)): Convolution dilation.
--- ## accumulate_wo_tile_2d
`accumulate_wo_tile_2d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load_filter: Bool, effected_by_padding: Bool, input_dt: DType, filter_dt: DType](c_tile_size: Int, RS: IndexList[2], mut acc: _Accumulator[dtype, num_rows, num_cols, simd_width, row_start, row_stop], input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, input_stride_to_nbr: IndexList[2], filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, filter_stride_to_nbr: IndexList[2], partial_load_filter_size: Int, hw: IndexList[2], HW: IndexList[2], dilation: IndexList[2])`
--- ## accumulate_wo_tile_3d
`accumulate_wo_tile_3d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load_filter: Bool, effected_by_padding: Bool, input_dt: DType, filter_dt: DType](c_tile_size: Int, QRS: IndexList[3], mut acc: _Accumulator[dtype, num_rows, num_cols, simd_width, row_start, row_stop], input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, input_stride_to_nbr: IndexList[3], filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, filter_stride_to_nbr: IndexList[3], partial_load_filter_size: Int, dhw: IndexList[3], DHW: IndexList[3], dilation: IndexList[3])`
--- ## check_cudnn_error
`check_cudnn_error(stat: cudnnStatus_t)`
--- ## conv1d_update_wo_tile
`conv1d_update_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, filter_packed: Bool, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType, elementwise_epilogue: Optional[elementwise_epilogue_type] = None](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], first_c_tile: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[rank], n: Int, wo: Int)`
--- ## conv2d_gpu_naive_nhwc_rscf
`conv2d_gpu_naive_nhwc_rscf[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, block_size: Int, maybe_epilogue_func: Optional[elementwise_simd_epilogue_type]](input: LayoutTensor[input_type, input_layout, MutAnyOrigin], filter: LayoutTensor[filter_type, filter_layout, MutAnyOrigin], output: LayoutTensor[output_type, output_layout, MutAnyOrigin], stride: IndexList[2], dilation: IndexList[2], padding: IndexList[2])`
--- ## conv2d_update_wo_tile
`conv2d_update_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, filter_packed: Bool, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType, elementwise_epilogue: Optional[elementwise_epilogue_type] = None](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], first_c_tile: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[2], n: Int, howo: IndexList[2])`
--- ## conv3d_gpu_naive_ndhwc_qrscf
`conv3d_gpu_naive_ndhwc_qrscf[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, block_size: Int, maybe_epilogue_func: Optional[elementwise_simd_epilogue_type]](input: LayoutTensor[input_type, input_layout, MutAnyOrigin], filter: LayoutTensor[filter_type, filter_layout, MutAnyOrigin], output: LayoutTensor[output_type, output_layout, MutAnyOrigin], stride: IndexList[3], dilation: IndexList[3], padding: IndexList[3])`
--- ## conv3d_update_wo_tile
`conv3d_update_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, filter_packed: Bool, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType, elementwise_epilogue: Optional[elementwise_epilogue_type] = None](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], first_c_tile: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[3], n: Int, dohowo: IndexList[3])`
--- ## conv_cudnn
`conv_cudnn[input_type: DType, filter_type: DType, output_type: DType](input: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[filter_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[2], dilation: IndexList[2], padding: IndexList[2], num_groups: Int, ctx: DeviceContext)`
--- ## conv_gpu
`conv_gpu[conv_rank: Int, //, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, maybe_epilogue_func: Optional[elementwise_simd_epilogue_type] = None, filter_is_fcrs: Bool = False](input: LayoutTensor[input_type, input_layout, MutAnyOrigin], filter: LayoutTensor[filter_type, filter_layout, MutAnyOrigin], output: LayoutTensor[output_type, output_layout, MutAnyOrigin], stride: IndexList[conv_rank], dilation: IndexList[conv_rank], padding: IndexList[(2 * conv_rank)], num_groups: Int, ctx: DeviceContext)`
--- ## conv_nhwc_direct
`conv_nhwc_direct[conv_info_rank: Int, //, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, filter_packed: Bool, conv_info_static: ConvInfoStatic[conv_info_rank], lambdas_have_fusion: Bool, elementwise_lambda: elementwise_simd_epilogue_type](input: LayoutTensor[input_type, input_layout, origin], filter: LayoutTensor[filter_type, filter_layout, origin], output: LayoutTensor[output_type, output_layout, origin], stride: IndexList[conv_info_rank], dilation: IndexList[conv_info_rank], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], num_groups: Int)`
--- ## conv_shape
`conv_shape[input_type: DType, filter_type: DType, strides_type: DType, dilations_type: DType, paddings_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter_buf: LayoutTensor[filter_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides_buf: LayoutTensor[strides_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations_buf: LayoutTensor[dilations_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings_buf: LayoutTensor[paddings_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups_scalar: Scalar[dtype]) -> IndexList[LayoutTensor[input_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `conv` operation, and assert the inputs are compatible. **Parameters:** * ​input\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the input tensor. * ​filter\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the filter tensor. * ​strides\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the strides tensor. * ​dilations\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the dilations tensor. * ​paddings\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Type of the paddings tensor. 
* ​single\_thread\_blocking\_override ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​filter\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The filter tensor. * ​strides\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The strides tensor. * ​dilations\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The dilations tensor. * ​paddings\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The paddings tensor. * ​num\_groups\_scalar ([`Scalar`](/mojo/std/builtin/simd/#scalar)): The num\_groups scalar. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The output shape.
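The per-dimension size that `conv_shape` returns follows standard convolution arithmetic. As an illustrative sketch (Python, with a hypothetical helper name rather than this API):

```python
def conv_out_dim(in_dim, kernel, stride, dilation, pad_lo, pad_hi):
    """Output extent of one spatial dimension of a convolution."""
    # Effective kernel extent after dilation.
    effective = dilation * (kernel - 1) + 1
    return (in_dim + pad_lo + pad_hi - effective) // stride + 1

# 32-wide input, 3-wide filter, stride 1, dilation 1, padding 1 each side.
print(conv_out_dim(32, 3, 1, 1, 1, 1))  # 32 ("same" padding)
# Stride 2, no padding.
print(conv_out_dim(32, 3, 2, 1, 0, 0))  # 15
```

The same formula applies independently to each spatial dimension (D, H, W), using the corresponding entries of the strides, dilations, and paddings tensors.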
--- ## get_cudnn_dtype
`get_cudnn_dtype[dtype: DType]() -> cudnnDataType_t` Map Mojo DType to cuDNN data type. Support only floating point dtypes for now. **Returns:** `cudnnDataType_t`
--- ## conv (Conv)
## `comptime` values ### `OpaquePointer` `comptime OpaquePointer = LegacyUnsafePointer[NoneType]` ### `UnsafePointer` `comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]` ## Structs * [​`ConvDirectNHWC`](./ConvDirectNHWC): Implement the outer loops for direct convolution. Collapse N, HO, WO into one dimension n\_ho\_wo. Tile n\_ho\_wo, C, and F. The tile factors for C and F are chosen by a heuristic prioritizing C. n\_ho\_wo is tiled by the micro kernel's height. * [​`CuDNNConvMeta`](./CuDNNConvMeta): * [​`Naive2dConvolution`](./Naive2dConvolution): Struct wrapper for naive 2d convolution implementation. ## Functions * [​`accumulate_wo_tile_1d`](./accumulate_wo_tile_1d): Update one row in the output for a given (c, f) tile. * [​`accumulate_wo_tile_2d`](./accumulate_wo_tile_2d): * [​`accumulate_wo_tile_3d`](./accumulate_wo_tile_3d): * [​`check_cudnn_error`](./check_cudnn_error): * [​`conv1d_update_wo_tile`](./conv1d_update_wo_tile): * [​`conv2d_gpu_naive_nhwc_rscf`](./conv2d_gpu_naive_nhwc_rscf): * [​`conv2d_update_wo_tile`](./conv2d_update_wo_tile): * [​`conv3d_gpu_naive_ndhwc_qrscf`](./conv3d_gpu_naive_ndhwc_qrscf): * [​`conv3d_update_wo_tile`](./conv3d_update_wo_tile): * [​`conv_cudnn`](./conv_cudnn): * [​`conv_gpu`](./conv_gpu): * [​`conv_nhwc_direct`](./conv_nhwc_direct): * [​`conv_shape`](./conv_shape): Compute the output shape of a `conv` operation, and assert the inputs are compatible. * [​`get_cudnn_dtype`](./get_cudnn_dtype): Map Mojo DType to cuDNN data type. * [​`pack_conv_filter_shape`](./pack_conv_filter_shape): Compute the output shape of convolution filter packing. * [​`pack_filter`](./pack_filter): This packs the filter from RSCF to FRSCf. Use the default micro kernel size for dynamic shapes. * [​`pack_filter_shape`](./pack_filter_shape): Compute the shape of packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel. 
* [​`pack_filter_shape_impl`](./pack_filter_shape_impl): Compute the shape of packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel.
--- ## pack_conv_filter_shape
`pack_conv_filter_shape[single_thread_blocking_override: Bool](filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int) -> IndexList[(LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank + 1)]` Compute the output shape of convolution filter packing. **Parameters:** * ​single\_thread\_blocking\_override ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The filter to be packed. * ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): The number of groups in the convolution. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The output shape.
--- ## pack_filter
`pack_filter(filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], packed_filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int)` This packs the filter from RSCF to FRSCf. Use the default micro kernel size for dynamic shapes. `pack_filter[simd_size: Int, micro_kernel_f_size: Int](filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], packed_filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int)` This packs the filter from RSCF to FRSCf. F is first broken down to segments of size micro\_kernel\_f\_size, then the remainder is further divided by simd\_size. The last residual elements, if any, are padded with zeros to fill simd\_size. **Parameters:** * ​simd\_size ([`Int`](/mojo/std/builtin/int/Int)): Can differ from the simd size of the input type. * ​micro\_kernel\_f\_size ([`Int`](/mojo/std/builtin/int/Int)): The size of the last dimension in FRSCf, which equals the size of the micro kernel's F dimension. **Args:** * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Filter in RSCF layout (if 2D). * ​packed\_filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Packed filter in FRSCf layout (if 2D). F - the index of continuous segments in the micro kernel. R, S, C - original R, S, C. f - the index within a continuous segment. 
* ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): The number of groups in the convolution.
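The F-dimension segmentation described above (full `micro_kernel_f_size` segments first, then `simd_size` segments, with the final residual zero-padded up to `simd_size`) can be sketched in Python. This is an illustration of the tiling arithmetic only, not the packing kernel itself:

```python
def segment_f(F, micro_kernel_f_size, simd_size):
    """Split F into full micro-kernel segments, then simd-size segments;
    a final partial segment is rounded up (zero-padded) to simd_size."""
    segments = []
    f = 0
    while f + micro_kernel_f_size <= F:
        segments.append(micro_kernel_f_size)
        f += micro_kernel_f_size
    while f + simd_size <= F:
        segments.append(simd_size)
        f += simd_size
    if f < F:  # residual elements, padded with zeros to fill simd_size
        segments.append(simd_size)
    return segments

print(segment_f(22, 8, 4))  # [8, 8, 4, 4]; the last segment holds 2 real + 2 zero elements
```

The packed buffer is therefore slightly larger than the original F dimension whenever F is not a multiple of `simd_size`.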
--- ## pack_filter_shape
`pack_filter_shape[filter_type: DType, input_shape: DimList, filter_shape: DimList, output_shape: DimList, strides: DimList, dilations: DimList, paddings: DimList, num_groups: Int, single_thread_blocking_override: Bool](filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[(LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank + 1)]` Compute the shape of packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The output shape.
--- ## pack_filter_shape_impl
`pack_filter_shape_impl[filter_type: DType](Q: Int, R: Int, S: Int, C: Int, F: Int, num_groups: Int) -> IndexList[6]` Compute the shape of packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel. **Args:** * ​Q ([`Int`](/mojo/std/builtin/int/Int)): Original Q filter dimension. * ​R ([`Int`](/mojo/std/builtin/int/Int)): Original R filter dimension. * ​S ([`Int`](/mojo/std/builtin/int/Int)): Original S filter dimension. * ​C ([`Int`](/mojo/std/builtin/int/Int)): Original C filter dimension. * ​F ([`Int`](/mojo/std/builtin/int/Int)): Original F filter dimension. * ​num\_groups ([`Int`](/mojo/std/builtin/int/Int)): Number of groups in the convolution. **Returns:** [`IndexList`](/mojo/std/utils/index_/IndexList): The output shape.
--- ## conv2d_fprop
`conv2d_fprop[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16[act_type, filter_type, out_type](), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True](output: NDBuffer[out_type, 4, origin], activation: NDBuffer[act_type, 4, origin], filter: NDBuffer[filter_type, 4, origin], problem: Conv2dProblemShape, ctx: DeviceContext)` Launch Conv2D forward propagation with 4D NHWC API and implicit im2col. This function provides a 4D tensor API for conv2d forward propagation using hardware TMA im2col transformation. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs coordinate transformation on-the-fly, eliminating the need for explicit im2col buffers. The convolution is implemented as implicit GEMM: * Activation matrix A\[M, K] where M = batch*H\_out*W\_out, K = C*R*S * Filter matrix B\[K, N] where N = out\_channels (transposed) * Output matrix C\[M, N] The TMA im2col descriptor handles the linear K iteration by decomposing k\_coord into (channel, filter\_r, filter\_s) using the corner parameters: * lower\_corner defines the starting filter offset (negative for padding) * upper\_corner defines the ending filter offset * channels\_per\_pixel is the number of input channels (C) * pixels\_per\_column is the output spatial tile size (BM) **Parameters:** * ​act\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the input activation tensor. * ​filter\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the filter weights tensor. * ​out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the output tensor. * ​config ([`Conv2dConfig`](/mojo/kernels/nn/conv_sm100/conv_config/Conv2dConfig)): Kernel configuration (tile sizes, pipeline stages, etc.). 
* ​elementwise\_compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional element-wise lambda function for epilogue fusion (bias add, activation, residual connection). Signature: `fn(coords: IndexList[2], val: SIMD) -> SIMD`. * ​register\_based\_epilogue ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, apply lambda in registers (faster). If False, apply lambda after SMEM write (more flexible). **Args:** * ​output ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output tensor \[N, H\_out, W\_out, C\_out] in NHWC layout. * ​activation ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Input activation \[N, H, W, C] in NHWC layout. * ​filter ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Filter weights \[K, R, S, C] in KRSC layout. * ​problem ([`Conv2dProblemShape`](/mojo/kernels/nn/conv_sm100/conv_config/Conv2dProblemShape)): Convolution problem shape specification. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for kernel launch. **Raises:** Error if kernel launch fails or constraints are violated.
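The k\_coord decomposition above can be pictured with a plain index calculation. The sketch below assumes a channel-innermost ordering of K (i.e., k = (r \* S + s) \* C + c); the actual ordering is encoded in the TMA im2col descriptor's corner parameters, so treat this as illustrative only:

```python
def decompose_k(k, C, S):
    """Map a linear K index back to (channel, filter_r, filter_s),
    assuming channels are innermost: k = (r * S + s) * C + c."""
    c = k % C
    rs = k // C
    s = rs % S
    r = rs // S
    return c, r, s

# With C=64, S=3: k = (1 * 3 + 2) * 64 + 5
print(decompose_k((1 * 3 + 2) * 64 + 5, 64, 3))  # (5, 1, 2)
```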
--- ## conv2d_fprop_with_residual
`conv2d_fprop_with_residual[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16[act_type, filter_type, out_type](), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, has_residual: Bool = False](output: NDBuffer[out_type, 4, origin], activation: NDBuffer[act_type, 4, origin], filter: NDBuffer[filter_type, 4, origin], source: NDBuffer[out_type, 4, origin], beta: Float32, problem: Conv2dProblemShape, ctx: DeviceContext)` Launch Conv2D fprop with residual add. Computes D = Conv(A,B) + beta\*C. This function extends conv2d\_fprop with residual add support. The epilogue load warp pre-fetches source tensor C via TMA, overlapping with MMA computation for better performance. The residual add is applied after the optional epilogue lambda: D = lambda(Conv(A,B)) + beta \* C This supports common patterns like: * Skip connections: D = Conv(A,B) + C (beta=1.0) * Residual scaling: D = Conv(A,B) + 0.5\*C (beta=0.5) * Fused residual+activation: D = ReLU(Conv(A,B)) + C Note: The epilogue load warp (warp ID 7) handles C loading when residual is enabled. When has\_residual is False or beta is 0, this warp exits early and the kernel behaves identically to conv2d\_fprop. **Parameters:** * ​act\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the input activation tensor. * ​filter\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the filter weights tensor. * ​out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Data type of the output tensor. * ​config ([`Conv2dConfig`](/mojo/kernels/nn/conv_sm100/conv_config/Conv2dConfig)): Kernel configuration (tile sizes, pipeline stages, etc.). * ​elementwise\_compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional element-wise lambda function for epilogue fusion (bias add, activation). Applied before residual. 
* ​register\_based\_epilogue ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, apply lambda in registers (faster). * ​has\_residual ([`Bool`](/mojo/std/builtin/bool/Bool)): If True, apply residual add. If False, source is ignored. **Args:** * ​output ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output tensor \[N, H\_out, W\_out, C\_out] in NHWC layout (D). * ​activation ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Input activation \[N, H, W, C] in NHWC layout (A). * ​filter ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Filter weights \[K, R, S, C] in KRSC layout (B). * ​source ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Source tensor \[N, H\_out, W\_out, C\_out] for residual (C). * ​beta ([`Float32`](/mojo/std/builtin/simd/#float32)): Residual scale factor. If 0.0, no residual is applied. * ​problem ([`Conv2dProblemShape`](/mojo/kernels/nn/conv_sm100/conv_config/Conv2dProblemShape)): Convolution problem shape specification. * ​ctx ([`DeviceContext`](/mojo/std/gpu/host/device_context/DeviceContext)): Device context for kernel launch. **Raises:** Error if kernel launch fails, constraints are violated, or source tensor shape doesn't match output shape.
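The epilogue formula above, D = lambda(Conv(A,B)) + beta \* C, reduces to simple elementwise arithmetic. A minimal Python sketch over flat lists (function and variable names are illustrative, not part of this API):

```python
def residual_epilogue(conv_out, source, beta, epilogue=None):
    """D = epilogue(conv_out) + beta * source, elementwise.
    The epilogue lambda is applied before the residual add."""
    out = []
    for x, c in zip(conv_out, source):
        y = epilogue(x) if epilogue is not None else x
        out.append(y + beta * c)
    return out

relu = lambda v: max(v, 0.0)
# Fused residual + activation: D = ReLU(Conv(A,B)) + C  (beta = 1.0)
print(residual_epilogue([-1.0, 2.0], [0.5, 0.5], 1.0, relu))  # [0.5, 2.5]
```

With `beta=0.0` (or `has_residual=False`) the residual term drops out and the result matches plain `conv2d_fprop` with the same epilogue.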
--- ## im2col
`im2col[dtype: DType](output: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], activation: NDBuffer[dtype, 4, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], problem: Conv2dProblemShape)` Explicit im2col transformation for convolution. Transforms a 4D activation tensor \[N, H, W, C] into a 2D matrix \[M, K] for GEMM-based convolution. M = batch \* out\_h \* out\_w K = in\_channels \* filter\_h \* filter\_w Note: This is a CPU reference implementation. For production use, the implicit im2col in the kernel is preferred. **Args:** * ​output ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output 2D buffer \[M, K]. * ​activation ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Input 4D buffer \[N, H, W, C]. * ​problem ([`Conv2dProblemShape`](/mojo/kernels/nn/conv_sm100/conv_config/Conv2dProblemShape)): Convolution problem shape.
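The transformation can be written out as a small Python reference (stride 1, no dilation, zero padding, matching the constraints noted for this module; nested lists stand in for the NDBuffer, and the K ordering chosen here is illustrative):

```python
def im2col_ref(act, R, S, pad_h, pad_w):
    """act: nested list [N][H][W][C]. Returns a 2D list [M][K] with
    M = N*out_h*out_w and K = R*S*C, zero-filling out-of-bounds taps."""
    N, H = len(act), len(act[0])
    W, C = len(act[0][0]), len(act[0][0][0])
    out_h = H + 2 * pad_h - R + 1
    out_w = W + 2 * pad_w - S + 1
    rows = []
    for n in range(N):
        for oh in range(out_h):
            for ow in range(out_w):
                row = []
                for r in range(R):
                    for s in range(S):
                        h, w = oh + r - pad_h, ow + s - pad_w
                        if 0 <= h < H and 0 <= w < W:
                            row.extend(act[n][h][w])
                        else:
                            row.extend([0.0] * C)  # padding region
                rows.append(row)
    return rows

# 1x2x2x1 input, 2x2 filter window, no padding -> M = 1, K = 4
act = [[[[1.0], [2.0]], [[3.0], [4.0]]]]
print(im2col_ref(act, 2, 2, 0, 0))  # [[1.0, 2.0, 3.0, 4.0]]
```

Each output row is one (n, oh, ow) position, so a plain GEMM of this matrix against the reshaped filter reproduces the convolution.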
--- ## conv2d
Public API for SM100 Conv2D forward propagation. This module provides the high-level API for launching Conv2D fprop kernels on NVIDIA Blackwell (SM100) GPUs. It handles: * TMA descriptor setup for activation (with im2col), filter, and output * Kernel configuration selection * Kernel launch with proper grid/block dimensions Usage (4D NHWC API with implicit im2col): from nn.conv\_sm100 import conv2d\_fprop ``` var problem = Conv2dProblemShape( batch=1, in_height=256, in_width=256, in_channels=64, out_channels=128, filter_h=3, filter_w=3, pad_h=1, pad_w=1, ) conv2d_fprop(output, input, filter, problem, ctx) ``` Note: This implementation currently supports: * stride=1, dilation=1 * NHWC layout for activation and output * KRSC layout for filter * BF16/FP16 data types ## Functions * [​`conv2d_fprop`](./conv2d_fprop): Launch Conv2D forward propagation with 4D NHWC API and implicit im2col. * [​`conv2d_fprop_with_residual`](./conv2d_fprop_with_residual): Launch Conv2D fprop with residual add. * [​`im2col`](./im2col): Explicit im2col transformation for convolution.
--- ## Conv2dFpropKernel
`struct Conv2dFpropKernel[act_type: DType, filter_type: DType, out_type: DType, act_layout: Layout, filter_layout: Layout, out_layout: Layout, act_desc_layout: Layout, filter_desc_layout: Layout, out_desc_layout: Layout, config: Conv2dConfig[act_type, filter_type, out_type], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, src_layout: Layout = out_layout, src_desc_layout: Layout = out_desc_layout]` SM100 Conv2D forward propagation kernel. This kernel implements conv2d fprop using implicit GEMM with warp specialization. It reuses the matmul kernel architecture but with convolution-specific address calculation. The kernel structure: * Scheduler warp: CLC-based tile scheduling * Load warp: TMA loads with im2col transformation * MMA warp: Tensor core operations * Epilogue warps: Output from TMEM to GMEM ## Parameters * ​act\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Activation data type. * ​filter\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Filter data type. * ​out\_type ([`DType`](/mojo/std/builtin/dtype/DType)): Output data type. * ​act\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Global memory activation layout. * ​filter\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Global memory filter layout. * ​out\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Global memory output layout. * ​act\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): TMA descriptor layout for activation. * ​filter\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): TMA descriptor layout for filter. * ​out\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): TMA descriptor layout for output. * ​config ([`Conv2dConfig`](/mojo/kernels/nn/conv_sm100/conv_config/Conv2dConfig)): Kernel configuration. 
* ​cluster\_shape ([`StaticTuple`](/mojo/std/utils/static_tuple/StaticTuple)): CUDA cluster dimensions. * ​elementwise\_compute\_lambda\_fn ([`Optional`](/mojo/std/collections/optional/Optional)): Optional epilogue lambda for fusion (bias add, activation functions, residual connections). * ​register\_based\_epilogue ([`Bool`](/mojo/std/builtin/bool/Bool)): Whether to apply the lambda in registers. * ​src\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Global memory layout for source C (residual input). * ​src\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): TMA descriptor layout for source C. ## Implemented traits [`AnyType`](/mojo/std/builtin/anytype/AnyType), [`ImplicitlyDestructible`](/mojo/std/builtin/anytype/ImplicitlyDestructible) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_layout` `comptime accum_layout = Layout.row_major(Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_M, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_N)` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group * Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)`

### `accum_pipeline_producer_arv_count`

`comptime accum_pipeline_producer_arv_count = 1`

### `accum_type`

`comptime accum_type = Conv2dConfig.accum_type[act_type, filter_type, out_type]()`

### `AccumTensor`

`comptime AccumTensor = TmemTensor[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].accum_type, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].accum_layout, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]`

### `act_expected_bytes`

`comptime act_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.act_smem_layout.size() * size_of[act_type]())`

### `act_tma_load_size`

`comptime act_tma_load_size = act_desc_layout.size()`

### `act_tma_rows`

`comptime act_tma_rows = act_desc_layout.shape[0].value()`

### `ActTileLoaderTypeIm2col`

`comptime ActTileLoaderTypeIm2col = TileLoaderTMAIm2col[?, ?, ?, ?, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]`

### `BK`

`comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)`

### `BM`

`comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)`

### `BN`

`comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)`

### `c_smem_layout`

`comptime c_smem_layout = Layout.row_major(Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.OutputM, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.OutputN)`

### `clc_consumer_arv_count`

`comptime clc_consumer_arv_count = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SCHEDULER_THREADS + (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_SIZE * (((Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TMA_LOAD_THREADS + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_LOAD_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)))`

### `clc_producer_arv_count`

`comptime clc_producer_arv_count = 1`

### `clc_throttle_consumer_arv_count`

`comptime clc_throttle_consumer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SCHEDULER_THREADS`

### `clc_throttle_producer_arv_count`

`comptime clc_throttle_producer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TMA_LOAD_THREADS`

### `CLUSTER_M`

`comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)`

### `CLUSTER_N`

`comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)`

### `CLUSTER_SIZE`

`comptime CLUSTER_SIZE = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_M * Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_N)`

### `Context`

`comptime Context = KernelContext[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_clc_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_M, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_N]`

### `cta_group`

`comptime cta_group = config.cta_group`

### `epi_load_consumer_arv_count`

`comptime epi_load_consumer_arv_count = SIMD[DType.int32, 1](Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)`

### `epi_load_producer_arv_count`

`comptime epi_load_producer_arv_count = 1`

### `EpiLoadPipelineType`

`comptime EpiLoadPipelineType = EpiLoadPipeline[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_epi_load_stages]`

### `EPILOGUE_LOAD_THREADS`

`comptime EPILOGUE_LOAD_THREADS = WARP_SIZE`

### `EPILOGUE_THREADS`

`comptime EPILOGUE_THREADS = (4 * WARP_SIZE)`

### `EpilogueCtx`

`comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS]`

### `filter_expected_bytes`

`comptime filter_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.filter_smem_layout.size() * size_of[filter_type]())`

### `filter_tma_load_size`

`comptime filter_tma_load_size = filter_desc_layout.size()`

### `filter_tma_rows`

`comptime filter_tma_rows = filter_desc_layout.shape[0].value()`

### `FilterTileLoaderType`

`comptime FilterTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]`

### `input_expected_bytes`

`comptime input_expected_bytes = ((Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group * (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].act_expected_bytes + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].filter_expected_bytes)) * config)`

### `InputTilePipelineType`

`comptime InputTilePipelineType = InputTilePipeline[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TilePayload, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_group_pipeline_stages, config.k_group_size]`

### `MMA_K`

`comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)`

### `MMA_M`

`comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)`

### `MMA_N`

`comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)`

### `MMA_THREADS`

`comptime MMA_THREADS = WARP_SIZE`

### `MmaCtx`

`comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS]`

### `MmaEpilogueSync`

`comptime MmaEpilogueSync = WarpGroupBarrier[(Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout,