# Quantize your graph weights

Quantization is an optimization technique that reduces the numeric precision of weights in a model. For example, models are usually trained with float32 weights, but you can quantize the values to a lower-precision type such as int8 or int4. That is, instead of storing each scalar value with 32 bits, you can use just 8 or 4 bits. This reduces the computational and memory demands during inference, which makes the model faster and compatible with more systems.
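To make the idea concrete, here is a minimal sketch in plain Python (not the MAX Graph API) of one common scheme: symmetric 8-bit quantization with a single shared scale. The function names and sample weights are made up for illustration.

```python
def quantize_int8(values):
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs else 1.0
    return [round(v / scale) for v in values], scale


def dequantize_int8(quantized, scale):
    """Recover approximate floats from the integer codes and the scale."""
    return [q * scale for q in quantized]


weights = [0.12, -0.5, 0.33, 0.9, -0.07]  # illustrative values
quantized, scale = quantize_int8(weights)
restored = dequantize_int8(quantized, scale)

print(quantized)                       # small integers that fit in one byte each
print([round(r, 3) for r in restored]) # close to the original weights
```

Each restored value differs from the original by at most about half the scale, which is the precision you trade for the smaller storage size.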

To support quantization with MAX Graph, we’ve built an API for graph engineers who want to quantize specific weights in a model; it does not quantize an entire model. Like the rest of the MAX Graph API, it is a low-level API for building high-performance graphs in a systems programming language, specifically Mojo.

If you just want to read some code, check out the Quantize TinyStories pipeline, which quantizes a 15-million parameter version of Llama 2 with Q4_0 (4-bit) encoding.

This is post-training quantization. The Graph API does not support model training, so you must import your model weights, load them as `Tensor` values, and then quantize them.

## Overview

When used properly, quantization does not significantly affect the model accuracy. There are several different quantization encodings that provide different levels of precision and encoding formats, each with its own trade-offs that may work well for some models or graph operations ("ops") but not others. Some models also work great with a mixture of quantization types, so that only certain ops perform low-precision calculations while others retain high precision.

To support this mixed-precision strategy, the quantization API in MAX Graph is declarative. That means you can quantize the weights in your model explicitly as you see fit, rather than pick one quantization format for the whole model. You can quantize different weights with different encodings, write custom ops that understand your quantizations, and even implement your own quantization encodings.

The primary API is the `quantize()` function (from the `QuantizationEncoding` trait), which takes a float32 tensor and returns a quantized tensor as a uint8 bytes buffer (it’s a type-erased blob of bytes that can be in any quantization encoding). You can call `quantize()` using one of the existing quantization encodings, such as Q4_0, Q4_K, and Q6_K (from GGML). Then, add the quantized tensor as a node in your graph.

Because the quantized data is just a blob of bytes with a special encoding for the values and scaling factor, any op that you pass this data into must know how to dequantize that data in order to perform its calculation with the full-precision float32 value.
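To illustrate what such a blob can look like, the following plain-Python sketch packs 32 floats into 18 bytes in the style of GGML's Q4_0 block (one float16 scale followed by 32 packed 4-bit codes) and decodes them again. It is a simplified illustration, not the MAX implementation; the function names are hypothetical, and whether the bytes MAX produces match this layout exactly is an assumption.

```python
import struct

BLOCK_SIZE = 32  # weights per block, as in GGML's Q4_0


def quantize_block(values):
    """Pack 32 floats into 18 bytes: a float16 scale plus 32 4-bit codes."""
    assert len(values) == BLOCK_SIZE
    # Scale so the largest-magnitude element lands on the most negative code (-8).
    extreme = max(values, key=abs)
    scale = extreme / -8 if extreme else 0.0
    # Round the scale to float16 first so encoding and decoding agree.
    scale = struct.unpack("<e", struct.pack("<e", scale))[0]
    inv = 1.0 / scale if scale else 0.0
    codes = [min(15, max(0, int(v * inv + 8.5))) for v in values]
    # Two codes per byte: element i in the low nibble, element i + 16 in the high.
    half = BLOCK_SIZE // 2
    packed = bytes(codes[i] | (codes[i + half] << 4) for i in range(half))
    return struct.pack("<e", scale) + packed


def dequantize_block(blob):
    """Decode an 18-byte block back into 32 approximate floats."""
    scale = struct.unpack("<e", blob[:2])[0]
    packed = blob[2:]
    low = [(b & 0x0F) - 8 for b in packed]   # elements 0..15
    high = [(b >> 4) - 8 for b in packed]    # elements 16..31
    return [code * scale for code in low + high]


weights = [(i - 16) / 40 for i in range(BLOCK_SIZE)]  # illustrative values
blob = quantize_block(weights)
restored = dequantize_block(blob)
print(len(blob))  # 18
```

Nothing in the blob itself says how to interpret it, which is why the consuming op has to know the encoding.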

Currently, the only op included in MAX Graph that can operate on quantized data is `qmatmul()`. This takes a float32 tensor and a quantized tensor, and returns a float32 tensor. This op alone allows you to build a variety of quantized transformer models. However, using quantized weights with any other op in `max.graph.ops` doesn’t work as is, because they all expect float32 inputs. To make it work, you can create a custom op that accepts a quantized input, dequantizes it with the appropriate decoding, and then completes the operation.

Now let’s look at some simple code examples.

## Quantize some weights

When you build a graph with MAX Graph, each batch of weights begins as a `Tensor` that you set in the graph as a constant (a node created with `Graph.constant()`). When you want to quantize those weights, just pass the `Tensor` to the `quantize()` method from the encoding type you want to use before you add it to the graph.

For example, the following code quantizes a tensor with `Q4_0Encoding` (4-bit encoding), performs quantized matmul (using `qmatmul()`), and prints the results:

```mojo
from tensor import Tensor, TensorShape
from max.engine import InferenceSession
from max.graph import Graph, TensorType
from max.graph.quantization import Q4_0Encoding
from max.graph.ops.quantized_ops import qmatmul

def main():
    graph = Graph(TensorType(DType.float32, 32, 64))

    # Perform matmul with the full-precision constant
    # constant_value = Tensor[DType.float32](TensorShape(64, 32), 0.15)
    # constant = graph.constant(constant_value)
    # matmul = graph[0] @ constant

    # Perform matmul with the quantized constant (transposed)
    constant_value = Tensor[DType.float32](TensorShape(32, 64), 0.15)
    quantized_value = Q4_0Encoding.quantize(constant_value)
    quantized_constant = graph.constant(quantized_value)
    matmul = qmatmul[Q4_0Encoding](graph[0], quantized_constant)

    graph.output(matmul)

    session = InferenceSession()
    model = session.load(graph)

    input = Tensor[DType.float32](TensorShape(32, 64), 0.5)
    results = model.execute("input0", input^)
    output = results.get[DType.float32]("output0")
    print(output)
```

`qmatmul()` expects the "right-hand-side" argument to be transposed. For example, whereas the normal `matmul()` op takes `lhs` and `rhs` with shapes $m \times n$ and $n \times p$, respectively (to get an output shape $m \times p$), `qmatmul()` requires the `rhs` shape be $p \times n$.

In the above example, the input shape is `[32, 64]` and the quantized shape is also `[32, 64]`, making the output shape `[32, 32]`.
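The shape rule can be checked with a small plain-Python sketch that multiplies an $m \times n$ matrix by the transpose of a $p \times n$ matrix. The helper `matmul_rhs_transposed` is a hypothetical stand-in that works on float lists; it only mirrors the shape convention described here and does no quantization.

```python
def matmul_rhs_transposed(lhs, rhs_t):
    """Multiply lhs (m x n) by the transpose of rhs_t (p x n), giving m x p."""
    n = len(lhs[0])
    assert all(len(row) == n for row in rhs_t), "inner dimensions must match"
    return [
        [sum(a * b for a, b in zip(lhs_row, rhs_row)) for rhs_row in rhs_t]
        for lhs_row in lhs
    ]


lhs = [[0.5] * 64 for _ in range(32)]     # input, shape [32, 64]
rhs_t = [[0.15] * 64 for _ in range(32)]  # weights stored as [p, n] = [32, 64]

out = matmul_rhs_transposed(lhs, rhs_t)
print(len(out), len(out[0]))  # 32 32
```

Storing the weights as $p \times n$ means each output element is a dot product of one input row with one contiguous weight row.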

You probably noticed this code also includes the full-precision matmul as a commented-out option. If you uncomment the three full-precision lines, comment out the four quantized lines below them, and run the code again, you can see for yourself how close the results are, even though the quantized constant uses about 1/8th of the memory (4 bits versus 32 bits per value).
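As a rough check on that ratio, the arithmetic below compares the storage for the `[32, 64]` constant in float32 against 4-bit storage. The per-block overhead assumes the GGML Q4_0 layout (a 2-byte float16 scale for every 32 weights), which this page does not spell out, so treat the last figure as an estimate.

```python
ROWS, COLS = 32, 64
BLOCK_SIZE = 32        # weights per Q4_0 block (assumed GGML layout)
BLOCK_BYTES = 2 + 16   # float16 scale + 32 packed 4-bit codes

num_weights = ROWS * COLS
float32_bytes = num_weights * 4
codes_only_bytes = num_weights // 2                       # 4 bits per weight, no scales
q4_0_bytes = (num_weights // BLOCK_SIZE) * BLOCK_BYTES    # codes plus per-block scales

print(float32_bytes)     # 8192
print(codes_only_bytes)  # 1024
print(q4_0_bytes)        # 1152
```

Under that assumption the scales add a little overhead, so the saving is closer to 7x than a full 8x.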

No matter which quantization encoding you choose, the `quantize()` method works the same way: it takes in a full-precision value as a `Tensor` value and returns the quantized value as a `Tensor`.

Alternatively, you can use `Graph.quantize()` to combine these two lines:

```mojo
quantized_value = Q4_0Encoding.quantize(constant_value)
quantized_constant = graph.constant(quantized_value)
```

Into one line:

```mojo
quantized_constant = graph.quantize[Q4_0Encoding](constant_value)
```

To see how we quantized a real model with this API, check out the Quantize TinyStories pipeline, which is a 15-million parameter model quantized with 4-bit encoding down to about 10MB.

Because the Graph API builds a static computation graph, quantization happens at graph build time. That means you can’t use `quantize()` with runtime inputs: every tensor you want to quantize must already be fixed when you build the `Graph`, before it executes.

## Save and load tensors to disk

To avoid quantizing your weights every time you load a model, you can save and load them from disk using the `save()` and `load()` functions. For example:

```mojo
from max.graph.checkpoint import load, save, TensorDict
from tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.max")

def read_from_disk():
    tensors = load("/path/to/checkpoint.max")
    x = tensors.get[DType.int32]("x")
```

The `TensorDict` type is just a dictionary type for named tensors.

## Custom quantization encodings

If you want a quantization encoding that’s not provided in the `quantization` package already, you can implement your own by building a type that conforms to the `QuantizationEncoding` trait.

Your custom `QuantizationEncoding` type must implement the `quantize()` function, which takes a float32 tensor and returns a uint8 byte tensor in the corresponding quantized buffer shape. With that type defined, you can call the `quantize()` function to produce your quantized tensors as you build the graph, but this doesn’t take care of the runtime decoding.

To decode your quantization type during inference, you also need to build custom ops for each op in your graph that takes in any of these quantized values. For example, if your graph performs matrix-multiplication using quantized inputs, you’ll need to implement a custom version of `matmul` that knows how to decode your custom quantization encoding and then use that custom op instead of the traditional `matmul` op.
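The division of labor between an encoding and its custom op can be sketched in plain Python. Everything below is hypothetical and independent of the MAX API: `RowInt8Encoding` stores each row as a float32 scale followed by int8 codes, its `dequantize()` helper is an addition the trait does not require, and `matmul_row_int8` stands in for a custom op that decodes the blob before multiplying (using the same transposed right-hand-side convention as `qmatmul()`).

```python
import struct


class RowInt8Encoding:
    """Hypothetical encoding: per row, a float32 scale followed by int8 codes."""

    @staticmethod
    def quantize(matrix):
        """Encode a list of float rows into one opaque bytes blob."""
        blob = bytearray()
        for row in matrix:
            max_abs = max(abs(v) for v in row)
            scale = max_abs / 127 if max_abs else 1.0
            codes = [round(v / scale) for v in row]
            blob += struct.pack(f"<f{len(row)}b", scale, *codes)
        return bytes(blob)

    @staticmethod
    def dequantize(blob, cols):
        """Decode the blob back into float rows; the caller supplies the width."""
        row_bytes = 4 + cols
        rows = []
        for offset in range(0, len(blob), row_bytes):
            scale, *codes = struct.unpack_from(f"<f{cols}b", blob, offset)
            rows.append([c * scale for c in codes])
        return rows


def matmul_row_int8(lhs, blob, cols):
    """Stand-in for a custom op: decode the blob, then multiply by its transpose."""
    weights = RowInt8Encoding.dequantize(blob, cols)
    return [
        [sum(a * w for a, w in zip(lhs_row, w_row)) for w_row in weights]
        for lhs_row in lhs
    ]


weights = [[0.15, -0.3, 0.6], [0.9, 0.0, -0.45]]  # illustrative values
blob = RowInt8Encoding.quantize(weights)
out = matmul_row_int8([[1.0, 2.0, 3.0]], blob, cols=3)
print(len(blob))  # 14
```

The encoding only produces bytes; all knowledge of how to read them lives in the op, which is why each op that consumes a custom encoding needs its own decoding-aware version.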

For more information, see how to create a custom op in MAX Graph.