Quantize your graph weights
Quantization is an optimization technique that reduces the numeric precision of weights in a model. For example, models are usually trained with float32 weights, but you can quantize the values to a lower-precision type such as int8 or int4. That is, instead of storing each scalar value with 32 bits, you can use just 8 or 4 bits. This reduces the computational and memory demands during inference, which can make the model faster and lets it run on more systems.
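To make the storage savings concrete, here is a minimal Python sketch of the arithmetic. It is not part of the MAX Graph API; the `storage_bytes` helper and the weight count are hypothetical and chosen only for illustration, and the int4 figure ignores any per-block metadata a real encoding adds.

```python
def storage_bytes(num_values: int, bits_per_value: int) -> int:
    """Bytes needed to store num_values scalars, rounded up to whole bytes."""
    return (num_values * bits_per_value + 7) // 8


# Hypothetical weight count, chosen only for illustration.
num_weights = 1_000_000

for name, bits in [("float32", 32), ("int8", 8), ("int4", 4)]:
    size = storage_bytes(num_weights, bits)
    print(f"{name:>7}: {size:>9,} bytes")
```

Under these assumptions, int8 storage is a quarter of the float32 size and int4 storage is an eighth.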
To support quantization with MAX Graph, we’ve built an API designed for low-level graph engineers who want to quantize specific weights in a model. This API does not quantize an entire model. Like the MAX Graph API, this is a low-level API meant for engineers who want to build high-performance graphs in a systems programming language—specifically, in Mojo.
If you just want to read some code, check out the Quantize TinyStories pipeline, which quantizes a 15-million parameter version of Llama 2 with Q4_0 (4-bit) encoding.
This is post-training quantization. The Graph API does not support model training, so you must import your model weights, load them as Tensor values, and then quantize them.
Overview
When used properly, quantization does not significantly affect the model accuracy. There are several different quantization encodings that provide different levels of precision and encoding formats, each with its own trade-offs that may work well for some models or graph operations ("ops") but not others. Some models also work great with a mixture of quantization types, so that only certain ops perform low-precision calculations while others retain high precision.
To support this mixed-precision strategy, the quantization API in MAX Graph is declarative. That means you can quantize the weights in your model explicitly as you see fit, rather than pick one quantization format for the whole model. You can quantize different weights with different encodings, write custom ops that understand your quantizations, and even implement your own quantization encodings.
The primary API is the quantize() function (from the QuantizationEncoding trait), which takes a float32 tensor and returns a quantized tensor as a uint8 bytes buffer (it’s a type-erased blob of bytes that can be in any quantization encoding). You can call quantize() using one of the existing quantization encodings, such as Q4_0, Q4_K, and Q6_K (from GGML). Then, add the quantized tensor as a node in your graph.
Because the quantized data is just a blob of bytes with a special encoding for the values and scaling factor, any op that you pass this data into must know how to dequantize that data in order to perform its calculation with the full-precision float32 value.
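As a rough illustration of what "a blob of bytes with values and a scaling factor" means, the Python sketch below quantizes one block of floats to 4-bit codes plus a single scale, then dequantizes it. This is a simplified symmetric scheme in the spirit of block formats like Q4_0, not the exact GGML or MAX byte layout; the block size, the helper names, and the nibble order are all assumptions made for the example.

```python
BLOCK_SIZE = 32  # values per block (an assumption, borrowed from GGML-style formats)


def quantize_block(values):
    """Return (scale, packed_bytes) for one block of BLOCK_SIZE floats."""
    assert len(values) == BLOCK_SIZE
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 7.0 if max_abs > 0 else 1.0
    # Map each value to an integer in [-7, 7], then shift to [1, 15] to fit in 4 bits.
    codes = [max(-7, min(7, round(v / scale))) + 8 for v in values]
    # Pack two 4-bit codes per byte: even index in the low nibble, odd in the high.
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, BLOCK_SIZE, 2))
    return scale, packed


def dequantize_block(scale, packed):
    """Recover approximate floats from the packed 4-bit codes and the scale."""
    values = []
    for byte in packed:
        for code in (byte & 0x0F, byte >> 4):
            values.append((code - 8) * scale)
    return values


original = [0.15 * ((i % 5) - 2) for i in range(BLOCK_SIZE)]
scale, packed = quantize_block(original)
restored = dequantize_block(scale, packed)
max_error = max(abs(a - b) for a, b in zip(original, restored))
print(len(packed), "bytes for", BLOCK_SIZE, "values; max error:", max_error)
```

An op that consumes quantized data has to perform the equivalent of `dequantize_block` internally, which is why it must know the encoding.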
Currently, the only op included in MAX Graph that can operate on quantized data is qmatmul(). This takes a float32 tensor and a quantized tensor, and returns a float32 tensor. This op alone allows you to build a variety of quantized transformer models. However, using quantized weights with any other op in max.graph.ops doesn’t work as is, because they all expect float32 inputs.
Now let’s look at some simple code examples.
Quantize some weights
When you build a graph with MAX Graph, each batch of weights begins as a Tensor that you set in the graph as a constant (a node created with Graph.constant()). When you want to quantize those weights, just pass the Tensor to the quantize() method from the encoding type you want to use before you add it to the graph.
For example, the following code quantizes a tensor with Q4_0Encoding (4-bit encoding), performs quantized matmul (using qmatmul()), and prints the results:
from max.tensor import Tensor, TensorShape
from max.engine import InferenceSession
from max.graph import Graph, TensorType
from max.graph.quantization import Q4_0Encoding
from max.graph.ops.quantized_ops import qmatmul

def main():
    graph = Graph(TensorType(DType.float32, 32, 64))

    # Perform matmul with the full-precision constant
    # constant_value = Tensor[DType.float32](TensorShape(64, 32), 0.15)
    # constant = graph.constant(constant_value)
    # matmul = graph[0] @ constant

    # Perform matmul with the quantized constant (transposed)
    constant_value = Tensor[DType.float32](TensorShape(32, 64), 0.15)
    quantized_value = Q4_0Encoding.quantize(constant_value)
    quantized_constant = graph.constant(quantized_value)
    matmul = qmatmul[Q4_0Encoding](graph[0], quantized_constant)

    graph.output(matmul)

    session = InferenceSession()
    model = session.load(graph)
    input = Tensor[DType.float32](TensorShape(32, 64), 0.5)
    results = model.execute("input0", input^)
    output = results.get[DType.float32]("output0")
    print(output)
qmatmul() expects the "right-hand-side" argument to be transposed. For example, whereas the normal matmul() op takes lhs and rhs with shapes [M, K] and [K, N], respectively (to get an output shape [M, N]), qmatmul() requires the rhs shape be [N, K].
In the above example, the input shape is [32, 64] and the quantized shape is also [32, 64], making the output shape [32, 32].
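The shape rule can be checked with a small Python sketch that uses plain nested lists. It only illustrates the transposed right-hand-side convention; it does not quantize anything, and the function names are hypothetical rather than part of MAX Graph.

```python
def matmul(lhs, rhs):
    """Plain matmul: lhs is [M, K], rhs is [K, N], result is [M, N]."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*rhs)] for row in lhs]


def matmul_transposed_rhs(lhs, rhs_t):
    """Matmul with rhs stored transposed: lhs is [M, K], rhs_t is [N, K]."""
    return [[sum(a * b for a, b in zip(row, rhs_row)) for rhs_row in rhs_t] for row in lhs]


M, K, N = 2, 3, 4
lhs = [[float(m * K + k) for k in range(K)] for m in range(M)]
rhs = [[float(k * N + n) for n in range(N)] for k in range(K)]
rhs_t = [list(col) for col in zip(*rhs)]  # shape [N, K]

result = matmul_transposed_rhs(lhs, rhs_t)
assert result == matmul(lhs, rhs)
print("rhs_t shape:", (len(rhs_t), len(rhs_t[0])))
print("output shape:", (len(result), len(result[0])))
```

With the transposed layout, each output element is a dot product of two rows of length K, so both operands are read along the same dimension.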
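You probably noticed that this code also includes the full-precision matmul as a commented-out option. If you uncomment the three lines under "Perform matmul with the full-precision constant", comment out the four lines under "Perform matmul with the quantized constant", and run it again, you can see for yourself how close the results are, even though the quantized constant uses only about 1/8th of the memory (4 bits versus 32 bits per value, plus a small per-block scaling factor).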
No matter which quantization encoding you choose, the quantize() method works the same: it takes in a full-precision value as a Tensor value and returns the quantized value as a Tensor.
Alternatively, you can use Graph.quantize() to combine these two lines:
quantized_value = Q4_0Encoding.quantize(constant_value)
quantized_constant = graph.constant(quantized_value)
Into one line:
quantized_constant = graph.quantize[Q4_0Encoding](constant_value)
To see how we quantized a real model with this API, check out the Quantize TinyStories pipeline, which is a 15-million parameter model quantized with 4-bit encoding down to about 10MB.
Because the Graph API builds a static computation graph, quantization happens at graph build time. That means you can’t use quantize() with runtime inputs: every tensor you want to quantize must already have fixed values when you build the graph, before the graph is loaded and executed.
Save and load tensors to disk
To avoid quantizing your weights every time you load a model, you can save and load them from disk using the save() and load() functions. For example:
from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.max")

def read_from_disk():
    tensors = load("/path/to/checkpoint.max")
    x = tensors.get[DType.int32]("x")
The TensorDict type is just a dictionary type for named tensors.