Data types (dtype)
Data types (dtypes) define how numbers are stored in tensors. A dtype specifies how each element in a tensor is represented in memory, and every tensor has exactly one dtype that applies to all its elements. Choosing the right dtype affects your model's memory usage, numerical precision, and compatibility with different hardware.
The DType enum in MAX provides all supported data types:
from max.dtype import DType
# DType is an enum that defines how numbers are stored in tensors
# Access dtypes as attributes of the DType class
print(DType.float32) # 32-bit floating point
print(DType.int32) # 32-bit integer
print(DType.bool) # Boolean values
Each dtype has three key characteristics:
- Precision: How accurately numbers are represented (more bits = more precision).
- Range: The minimum and maximum values that can be stored.
- Memory: How many bytes each element requires.
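To make these three characteristics concrete, the sketch below uses only Python's standard `struct` module rather than the MAX API. It shows the size and round-trip precision of IEEE floating-point formats, plus the range implied by a signed integer's bit width. The helper names `round_trip` and `signed_int_range` are illustrative and are not part of MAX.

```python
import struct


def round_trip(value: float, fmt: str) -> float:
    """Store a Python float in the given struct format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def signed_int_range(bits: int) -> tuple[int, int]:
    """Return the (min, max) values of a two's-complement signed integer."""
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1


# Memory: bytes per element for IEEE half, single, and double precision.
formats = [("float16", "<e"), ("float32", "<f"), ("float64", "<d")]
sizes = {name: struct.calcsize(fmt) for name, fmt in formats}
print(sizes)  # {'float16': 2, 'float32': 4, 'float64': 8}

# Precision: 0.1 is stored less accurately as the bit width shrinks.
print(round_trip(0.1, "<e"))  # 0.0999755859375
print(round_trip(0.1, "<f"))  # 0.10000000149011612

# Range: each extra bit doubles the number of representable integers.
print(signed_int_range(8))   # (-128, 127)
print(signed_int_range(32))  # (-2147483648, 2147483647)
```

The same trade-off applies to tensors: halving the element size halves memory, at the cost of fewer significant digits or a narrower range.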
Common dtypes
MAX supports the standard dtypes used by NumPy and PyTorch. The most common are:
| DType | Size | Description | Use case |
|---|---|---|---|
| DType.bfloat16 | 2 bytes | 16-bit brain float (8 exp, 7 mantissa) | ML training, better range than fp16 |
| DType.bool | 1 byte | Boolean true or false | Masks, conditional logic |
| DType.float16 | 2 bytes | 16-bit IEEE floating point | GPU inference, memory savings |
| DType.float32 | 4 bytes | 32-bit IEEE floating point | Default for training and development |
| DType.int32 | 4 bytes | 32-bit signed integer | Indices, counts, discrete values |
| DType.int64 | 8 bytes | 64-bit signed integer | Large indices, token IDs |
| DType.int8 | 1 byte | 8-bit signed integer | Quantized models, extreme compression |
For the complete list including float8 variants and all integer types, see the DType API reference.
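The table notes that bfloat16 has a better range than float16. One way to see why is that bfloat16 keeps the 8 exponent bits of float32 and shortens the mantissa. The sketch below emulates this in plain Python with the standard `struct` module by keeping only the top 16 bits of a float32. It is an illustration under that assumption, not the MAX implementation: it truncates the mantissa, whereas real hardware typically rounds, and the helper names are hypothetical.

```python
import struct


def to_bfloat16_truncated(value: float) -> float:
    """Keep only the top 16 bits of the float32 encoding of `value`.

    This truncates the mantissa; hardware conversions usually round instead.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]


def fits_in_float16(value: float) -> bool:
    """Return True if `value` can be packed as an IEEE half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except (OverflowError, struct.error):
        return False


big = 70000.0  # larger than the float16 maximum of 65504

# Range: float16 overflows, bfloat16 still holds an approximation.
print(fits_in_float16(big))        # False
print(to_bfloat16_truncated(big))  # 69632.0

# Precision: with only 7 mantissa bits, bfloat16 is coarser than float16.
print(to_bfloat16_truncated(0.1))  # 0.099609375
```

In short, bfloat16 trades significant digits for range, which is why it is often preferred when large values may appear, while float16 keeps more digits within a narrower range.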
Specify dtype when creating tensors
When you create a tensor, you can specify its dtype using the dtype parameter in the format of DType.{dtype_name}:
from max.driver import CPU
from max.dtype import DType
from max.experimental.tensor import Tensor
# Create a tensor with float32 (default for most operations)
float_tensor = Tensor.ones([2, 3], dtype=DType.float32, device=CPU())
print(f"Float tensor dtype: {float_tensor.dtype}")
# Create a tensor with int32 for indices or counts
int_tensor = Tensor.constant([1, 2, 3], dtype=DType.int32, device=CPU())
print(f"Int tensor dtype: {int_tensor.dtype}")
The expected output is:
Float tensor dtype: DType.float32
Int tensor dtype: DType.int32
In this example, the ones() function creates a tensor filled with ones, and
the constant() function creates a tensor from the given values. The
dtype parameter sets the element type of each tensor.
If you don't specify a dtype, MAX uses:
- float32 for CPU devices.
- bfloat16 for accelerator devices (GPUs).
Check tensor dtype
Every tensor has a dtype property that returns its data type:
from max.driver import CPU
from max.dtype import DType
from max.experimental.tensor import Tensor
# Create tensors of different types
weights = Tensor.ones([3, 3], dtype=DType.float32, device=CPU())
indices = Tensor.constant([0, 1, 2], dtype=DType.int64, device=CPU())
# Check the dtype of each tensor
print(f"Weights dtype: {weights.dtype}") # DType.float32
print(f"Indices dtype: {indices.dtype}") # DType.int64
# Compare dtypes directly
if weights.dtype == DType.float32:
    print("Weights are float32")
The expected output is:
Weights dtype: DType.float32
Indices dtype: DType.int64
Weights are float32
In this example, the weights tensor is a float32 tensor and the indices tensor
is an int64 tensor.
Convert between dtypes
The cast() method converts a tensor from one dtype to another and returns the
result as a new tensor. This is useful when you need to move from a
floating-point type to an integer type, or from a higher-precision type to a
lower-precision one. For example:
from max.driver import CPU
from max.dtype import DType
from max.experimental.tensor import Tensor
# Create a float32 tensor
x = Tensor.constant([1.7, 2.3, 3.9], dtype=DType.float32, device=CPU())
print(f"Original dtype: {x.dtype}") # DType.float32
# Cast to int32 (truncates decimal values)
y = x.cast(DType.int32)
print(f"After cast to int32: {y.dtype}") # DType.int32
# Cast to float64 for higher precision
z = x.cast(DType.float64)
print(f"After cast to float64: {z.dtype}") # DType.float64
The expected output is:
Original dtype: DType.float32
After cast to int32: DType.int32
After cast to float64: DType.float64
In this example, the original tensor is a float32 tensor. Casting it to
int32 produces an int32 tensor, and casting it to float64 produces a
float64 tensor.
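The example above prints only dtypes, so it does not show what happens to the values. The sketch below uses plain Python, not MAX, to illustrate two general points about such conversions: truncation differs from rounding, and widening a float32 value to 64 bits does not recover digits that were already lost. Whether cast() treats negative values exactly as Python's `math.trunc` does is an assumption here.

```python
import math
import struct

values = [1.7, 2.3, 3.9, -1.7]

# Truncation drops the fractional part, moving each value toward zero.
truncated = [math.trunc(v) for v in values]
print(truncated)  # [1, 2, 3, -1]

# Rounding to the nearest integer gives different results.
rounded = [round(v) for v in values]
print(rounded)  # [2, 2, 4, -2]

# Store 1.7 as a 32-bit float, then read it back as a 64-bit Python float.
# The float32 rounding error is preserved by the widening conversion.
as_float32 = struct.unpack("<f", struct.pack("<f", 1.7))[0]
print(as_float32)         # 1.7000000476837158
print(as_float32 == 1.7)  # False
```

If you need rounded integers rather than truncated ones, round the values before converting to an integer type.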
DType properties and methods
The DType enum provides useful properties and methods for inspecting types:
from max.dtype import DType
# Check memory size of different dtypes
print(f"float32 size: {DType.float32.size_in_bytes} bytes") # 4
print(f"float32.is_float(): {DType.float32.is_float()}") # True
print(f"int32.is_integral(): {DType.int32.is_integral()}") # True
print(f"float8_e4m3fn.is_float8(): {DType.float8_e4m3fn.is_float8()}") # True
The expected output is:
float32 size: 4 bytes
float32.is_float(): True
int32.is_integral(): True
float8_e4m3fn.is_float8(): True
For more information, see the DType API reference.
Interoperability with NumPy and PyTorch tensors
MAX provides seamless dtype conversion with NumPy and PyTorch for working with existing data pipelines.
Use DLPack for tensor conversion
DLPack is a standardized in-memory tensor format and protocol that lets array and tensor libraries share data across devices and frameworks with zero or minimal copies.
The recommended way to convert NumPy arrays to MAX tensors is through DLPack, which enables zero-copy conversion when possible:
import numpy as np
from max.experimental.tensor import Tensor
# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
# Convert to MAX tensor using DLPack (zero-copy when possible)
tensor = Tensor.from_dlpack(np_array)
print(f"NumPy dtype: {np_array.dtype}") # float32
print(f"MAX tensor dtype: {tensor.dtype}") # DType.float32
print(f"MAX tensor shape: {tensor.shape}") # [Dim(2), Dim(2)]
The expected output is:
NumPy dtype: float32
MAX tensor dtype: DType.float32
MAX tensor shape: [Dim(2), Dim(2)]
In this example, the from_dlpack() method converts the NumPy array to a MAX tensor.
You can use this method when converting data from other libraries to MAX.
MAX also provides dtype conversion for PyTorch and NumPy integration.
The from_torch() method converts a PyTorch dtype to a MAX dtype. For example:
import torch
from max.dtype import DType
# PyTorch tensor
pt_tensor = torch.randn(10, 10, dtype=torch.float16)
# Convert PyTorch dtype to MAX dtype
# API: DType.from_torch(dtype)
# dtype: PyTorch dtype
# Returns: Corresponding MAX DType
# Raises: ValueError if dtype not supported
# Raises: RuntimeError if torch not installed
max_dtype = DType.from_torch(pt_tensor.dtype)
print(f"PyTorch {pt_tensor.dtype} → MAX {max_dtype}") # torch.float16 → DType.float16
The expected output is:
PyTorch torch.float16 → MAX DType.float16
Other conversion functions you can use are:
- to_numpy(): Convert a MAX dtype to a NumPy dtype.
- to_torch(): Convert a MAX dtype to a PyTorch dtype.
Memory optimization
Understanding dtype memory usage is critical for deploying large models. The
size_in_bytes property lets you calculate exact memory requirements.
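As a rough worked example of the arithmetic, the footprint of a set of weights is the number of elements multiplied by the bytes per element. The sketch below is plain Python that hard-codes the sizes from the Common dtypes table instead of calling MAX. The parameter count and the `weights_gib` helper are hypothetical, and the estimate covers weights only, not activations or other runtime buffers.

```python
# Bytes per element, copied from the "Common dtypes" table.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def weights_gib(num_params: int, dtype_name: str) -> float:
    """Estimate weight storage in GiB for a given parameter count and dtype."""
    return num_params * BYTES_PER_ELEMENT[dtype_name] / 1024**3


num_params = 7_000_000_000  # hypothetical model size, for illustration only

for name in ("float32", "bfloat16", "int8"):
    print(f"{name}: {weights_gib(num_params, name):.1f} GiB")
# float32: 26.1 GiB
# bfloat16: 13.0 GiB
# int8: 6.5 GiB
```

Because the relationship is linear, moving from a 4-byte dtype to a 2-byte dtype halves the storage for the same number of elements.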
from max.dtype import DType
def calculate_memory(shape: list[int], dtype: DType) -> int:
    """Calculate memory usage in bytes for a tensor."""
    # API: dtype.size_in_bytes
    # Returns: Size of dtype in bytes (int)
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    bytes_used = num_elements * dtype.size_in_bytes
    return bytes_used
# Compare dtypes for same tensor
shape = [1024, 1024, 1024] # 1B elements
float32_mb = calculate_memory(shape, DType.float32) / (1024**2)
float16_mb = calculate_memory(shape, DType.float16) / (1024**2)
int8_mb = calculate_memory(shape, DType.int8) / (1024**2)
print(f"float32: {float32_mb:.1f} MB") # 4096.0 MB
print(f"float16: {float16_mb:.1f} MB") # 2048.0 MB (50% reduction)
print(f"int8: {int8_mb:.1f} MB") # 1024.0 MB (75% reduction)
Type validation
Use dtype checking methods to write code that validates inputs at runtime. For example:
from max.dtype import DType
def validate_weights_dtype(dtype: DType) -> None:
    """Ensure weights use a floating-point type."""
    # API: dtype.is_float()
    # Returns: True if dtype is any floating-point type
    if not dtype.is_float():
        raise TypeError(f"Weights must be float type, got {dtype}")
def validate_indices_dtype(dtype: DType) -> None:
    """Ensure indices use an integer type."""
    # API: dtype.is_integral()
    # Returns: True if dtype is any integer type (signed or unsigned)
    if not dtype.is_integral():
        raise TypeError(f"Indices must be integer type, got {dtype}")
# Usage
weights_dtype = DType.float16
indices_dtype = DType.int32
validate_weights_dtype(weights_dtype) # OK
validate_indices_dtype(indices_dtype) # OK
Next steps
Now that you understand dtypes, continue learning:
- Building graphs: Specify dtypes in computation graphs.
- Quantization: Quantize weights to reduce memory usage and improve performance.