
Data types (dtype)

A data type (dtype) specifies how each element of a tensor is represented in memory. Every tensor has exactly one dtype, which applies to all of its elements. Choosing the right dtype affects your model's memory usage, numerical precision, and compatibility with different hardware.

The DType enum in MAX provides all supported data types:

from max.dtype import DType

# DType is an enum that defines how numbers are stored in tensors
# Access dtypes as attributes of the DType class
print(DType.float32)  # 32-bit floating point
print(DType.int32)  # 32-bit integer
print(DType.bool)  # Boolean values

Each dtype has three key characteristics:

  • Precision: How finely values can be represented. For floating-point types, more mantissa bits means more precision.
  • Range: The minimum and maximum values that can be stored.
  • Memory: How many bytes each element requires.
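To make the three characteristics concrete without depending on MAX, here is a minimal sketch that uses Python's standard struct module. struct can pack values as IEEE float16 ("e"), float32 ("f"), and float64 ("d"). The sketch measures the bytes per element, the error introduced when storing 0.1, and whether 70000.0 is in range. struct has no bfloat16 format, and this isn't how MAX stores tensors; the helper names round_trip and fits are hypothetical and only illustrate the trade-offs.

```python
import struct

# Standard-size struct format codes for IEEE 754 types:
# "<e" = float16, "<f" = float32, "<d" = float64.
FORMATS = {"float16": "<e", "float32": "<f", "float64": "<d"}


def round_trip(value: float, fmt: str) -> float:
    """Store value in the given format, then read it back as a Python float."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def fits(value: float, fmt: str) -> bool:
    """Return True if value is within the finite range of the format."""
    try:
        struct.pack(fmt, value)
    except OverflowError:
        return False
    return True


for name, fmt in FORMATS.items():
    size = struct.calcsize(fmt)              # memory per element
    error = abs(round_trip(0.1, fmt) - 0.1)  # precision
    in_range = fits(70000.0, fmt)            # range
    print(f"{name}: {size} bytes, error storing 0.1 = {error:.1e}, "
          f"70000.0 in range: {in_range}")
```

float16 reports the largest error and fails the range check, because its largest finite value is 65504.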

Common dtypes

MAX supports the standard dtypes used in NumPy and PyTorch. The most common are:

| DType | Size | Description | Use case |
|---|---|---|---|
| DType.bfloat16 | 2 bytes | 16-bit brain float (8 exponent bits, 7 mantissa bits) | ML training; wider range than float16 |
| DType.bool | 1 byte | Boolean true or false | Masks, conditional logic |
| DType.float16 | 2 bytes | 16-bit IEEE floating point | GPU inference, memory savings |
| DType.float32 | 4 bytes | 32-bit IEEE floating point | Default for training and development |
| DType.int32 | 4 bytes | 32-bit signed integer | Indices, counts, discrete values |
| DType.int64 | 8 bytes | 64-bit signed integer | Large indices, token IDs |
| DType.int8 | 1 byte | 8-bit signed integer | Quantized models, extreme compression |

For the complete list, including float8 variants and all integer types, see the DType API reference.
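The bfloat16 row trades precision for range. The following sketch emulates bfloat16 rounding in plain Python by rounding a float32 bit pattern to its upper 16 bits with round-half-to-even. This is a common way to convert to bfloat16, but it is an assumption here, not necessarily what MAX does internally, and the helper names to_float16 and to_bfloat16 are hypothetical.

```python
import struct


def to_float16(value: float) -> float:
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def to_bfloat16(value: float) -> float:
    """Round a Python float to the nearest bfloat16 value (round half to even).

    bfloat16 keeps the sign bit, all 8 exponent bits, and the top 7 mantissa
    bits of a float32, so this emulates it by rounding the float32 bit
    pattern to its upper 16 bits. NaN inputs are not handled specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Add just under half of the dropped range, plus 1 if the kept part is odd,
    # so that exact ties round to an even result.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bf16_bits = ((bits + rounding_bias) >> 16) & 0xFFFF
    # Widen back to float32 by zero-filling the low 16 bits.
    return struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]


# Range: 70000.0 exceeds float16's largest finite value (65504) but fits in bfloat16.
print(to_bfloat16(70000.0))  # 70144.0
# Precision: bfloat16 has fewer mantissa bits, so it stores 0.1 less accurately.
print(abs(to_bfloat16(0.1) - 0.1) > abs(to_float16(0.1) - 0.1))  # True
```

The 70144.0 result also shows the cost of the wider range: with only 7 mantissa bits, neighboring bfloat16 values near 70000 are 512 apart.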

Specify dtype when creating tensors

When you create a tensor, you can set its dtype by passing a DType member, such as DType.float32, to the dtype parameter:

from max.driver import CPU
from max.dtype import DType
from max.experimental.tensor import Tensor

# Create a tensor with float32 (default for most operations)
float_tensor = Tensor.ones([2, 3], dtype=DType.float32, device=CPU())
print(f"Float tensor dtype: {float_tensor.dtype}")

# Create a tensor with int32 for indices or counts
int_tensor = Tensor.constant([1, 2, 3], dtype=DType.int32, device=CPU())
print(f"Int tensor dtype: {int_tensor.dtype}")

The expected output is:

Float tensor dtype: DType.float32
Int tensor dtype: DType.int32

In this example, Tensor.ones() creates a tensor filled with ones, and Tensor.constant() creates a tensor from the given values. In both calls, the dtype parameter sets the tensor's dtype.

If you don't specify a dtype, MAX uses:

  • float32 for CPU devices.
  • bfloat16 for accelerator devices (GPUs).

Check tensor dtype

Every tensor has a dtype property that returns its data type:

from max.driver import CPU
from max.dtype import DType
from max.experimental.tensor import Tensor

# Create tensors of different types
weights = Tensor.ones([3, 3], dtype=DType.float32, device=CPU())
indices = Tensor.constant([0, 1, 2], dtype=DType.int64, device=CPU())

# Check the dtype of each tensor
print(f"Weights dtype: {weights.dtype}")  # DType.float32
print(f"Indices dtype: {indices.dtype}")  # DType.int64

# Compare dtypes directly
if weights.dtype == DType.float32:
    print("Weights are float32")

The expected output is:

Weights dtype: DType.float32
Indices dtype: DType.int64
Weights are float32

In this example, weights is a float32 tensor and indices is an int64 tensor.

Convert between dtypes

The cast() method converts a tensor to another dtype. Use it to move between floating-point and integer types, or to change precision, such as narrowing float32 to float16 or widening it to float64. For example:

from max.driver import CPU
from max.dtype import DType
from max.experimental.tensor import Tensor

# Create a float32 tensor
x = Tensor.constant([1.7, 2.3, 3.9], dtype=DType.float32, device=CPU())
print(f"Original dtype: {x.dtype}")  # DType.float32

# Cast to int32 (drops the fractional part)
y = x.cast(DType.int32)
print(f"After cast to int32: {y.dtype}")  # DType.int32

# Cast to float64 (widens storage; doesn't restore precision lost in float32)
z = x.cast(DType.float64)
print(f"After cast to float64: {z.dtype}")  # DType.float64

The expected output is:

Original dtype: DType.float32
After cast to int32: DType.int32
After cast to float64: DType.float64

In this example, x is a float32 tensor. Casting it to int32 produces y, an int32 tensor, and casting it to float64 produces z, a float64 tensor.
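The following sketch mimics these conversions with plain Python lists and the struct module instead of MAX tensors, and adds a narrowing cast to float16. It assumes, as the comment in the cast example states, that a float-to-integer cast truncates the fractional part; the helpers to_float32 and to_float16 are hypothetical and not part of MAX.

```python
import struct


def to_float32(value: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def to_float16(value: float) -> float:
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


x = [to_float32(v) for v in [1.7, 2.3, 3.9]]  # stand-in for a float32 tensor

# float -> int: drop the fractional part (truncation toward zero).
y = [int(v) for v in x]
print(y)  # [1, 2, 3]

# float32 -> float64: widening is exact, so float32's rounding error remains.
z = [float(v) for v in x]
print(z[0] == 1.7)  # False

# float32 -> float16: narrowing rounds to fewer mantissa bits.
print(to_float16(2049.0))  # 2048.0
```

The last line shows why casting to a narrower type can lose information: float16 cannot represent 2049 exactly, so it rounds to the nearest representable value, 2048.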

DType properties and methods

The DType enum provides useful properties and methods for inspecting types:

from max.dtype import DType

# Check the size and category of dtypes
print(f"float32 size: {DType.float32.size_in_bytes} bytes")  # 4
print(f"float32.is_float(): {DType.float32.is_float()}")  # True
print(f"int32.is_integral(): {DType.int32.is_integral()}")  # True
print(f"float8_e4m3fn.is_float8(): {DType.float8_e4m3fn.is_float8()}")  # True

The expected output is:

float32 size: 4 bytes
float32.is_float(): True
int32.is_integral(): True
float8_e4m3fn.is_float8(): True

For more information, see the DType API reference.

Interoperability with NumPy and PyTorch tensors

MAX can convert tensors and dtypes to and from NumPy and PyTorch, so you can use it with existing data pipelines.

Use DLPack for tensor conversion

DLPack is a standardized in-memory tensor format and protocol that lets array and tensor libraries share data across devices and frameworks with zero or minimal copies.

The recommended way to convert NumPy arrays to MAX tensors is through DLPack, which enables zero-copy conversion when possible:

import numpy as np
from max.experimental.tensor import Tensor

# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Convert to MAX tensor using DLPack (zero-copy when possible)
tensor = Tensor.from_dlpack(np_array)

print(f"NumPy dtype: {np_array.dtype}")  # float32
print(f"MAX tensor dtype: {tensor.dtype}")  # DType.float32
print(f"MAX tensor shape: {tensor.shape}")  # [Dim(2), Dim(2)]

The expected output is:

NumPy dtype: float32
MAX tensor dtype: DType.float32
MAX tensor shape: [Dim(2), Dim(2)]

In this example, Tensor.from_dlpack() converts the NumPy array to a MAX tensor and preserves its dtype. Because DLPack is a shared protocol, you can use the same method with data from other libraries that support DLPack.
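To see the zero-copy behavior of the protocol itself, the following sketch uses NumPy (1.22 or later) as both producer and consumer, with np.from_dlpack() standing in for Tensor.from_dlpack(). It checks that the result keeps the dtype and shape and shares the original buffer. This demonstrates DLPack in general, not MAX's implementation.

```python
import numpy as np

# The source array is the DLPack *producer*: NumPy arrays implement __dlpack__().
src = np.arange(6, dtype=np.float32).reshape(2, 3)

# np.from_dlpack() is a DLPack *consumer*, the role Tensor.from_dlpack() plays
# in MAX. The protocol carries dtype, shape, strides, and device information.
view = np.from_dlpack(src)

print(view.dtype, view.shape)       # float32 (2, 3)
print(np.shares_memory(src, view))  # True: the buffer was shared, not copied
```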

To convert a dtype on its own, without a tensor, use the conversion methods on DType. For example, DType.from_torch() converts a PyTorch dtype to the corresponding MAX dtype:

import torch
from max.dtype import DType

# PyTorch tensor
pt_tensor = torch.randn(10, 10, dtype=torch.float16)

# Convert PyTorch dtype to MAX dtype
# API: DType.from_torch(dtype)
#   dtype: PyTorch dtype
#   Returns: Corresponding MAX DType
#   Raises: ValueError if dtype not supported
#   Raises: RuntimeError if torch not installed
max_dtype = DType.from_torch(pt_tensor.dtype)
print(f"PyTorch {pt_tensor.dtype} → MAX {max_dtype}")  # torch.float16 → DType.float16

The expected output is:

PyTorch torch.float16 → MAX DType.float16

DType also provides these conversion methods:

  • to_numpy(): Convert a MAX dtype to a NumPy dtype.
  • to_torch(): Convert a MAX dtype to a PyTorch dtype.

Memory optimization

Understanding dtype memory usage is critical for deploying large models. The size_in_bytes property returns the number of bytes per element, so you can calculate how much memory a tensor's elements require:

from max.dtype import DType


def calculate_memory(shape: list[int], dtype: DType) -> int:
    """Calculate memory usage in bytes for a tensor."""
    # API: dtype.size_in_bytes
    #   Returns: Size of dtype in bytes (int)
    num_elements = 1
    for dim in shape:
        num_elements *= dim

    bytes_used = num_elements * dtype.size_in_bytes
    return bytes_used


# Compare dtypes for same tensor
shape = [1024, 1024, 1024]  # 1B elements

float32_mb = calculate_memory(shape, DType.float32) / (1024**2)
float16_mb = calculate_memory(shape, DType.float16) / (1024**2)
int8_mb = calculate_memory(shape, DType.int8) / (1024**2)

print(f"float32: {float32_mb:.1f} MB")  # 4096.0 MB
print(f"float16: {float16_mb:.1f} MB")  # 2048.0 MB (50% reduction)
print(f"int8: {int8_mb:.1f} MB")  # 1024.0 MB (75% reduction)
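The same arithmetic scales to model weights. This sketch estimates raw weight storage for a hypothetical 8-billion-parameter model. The BYTES_PER_ELEMENT dict is an illustrative stand-in for DType.size_in_bytes, with sizes taken from the dtype table, and the estimate covers weights only, not activations, caches, or runtime overhead.

```python
# Bytes per element for a few dtypes, matching the sizes in the dtype table.
# This dict is an illustration, not a MAX API; in MAX, use DType.size_in_bytes.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def weights_gib(num_params: int, dtype_name: str) -> float:
    """Return the raw storage for num_params elements in GiB (2**30 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype_name] / 2**30


num_params = 8 * 10**9  # hypothetical 8-billion-parameter model
for name in BYTES_PER_ELEMENT:
    print(f"{name}: {weights_gib(num_params, name):.1f} GiB")
```

For this example, float32 needs about 29.8 GiB, bfloat16 and float16 about 14.9 GiB, and int8 about 7.5 GiB.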

Type validation

Use dtype checking methods to write code that validates inputs at runtime. For example:

from max.dtype import DType


def validate_weights_dtype(dtype: DType) -> None:
    """Ensure weights use a floating-point type."""
    # API: dtype.is_float()
    #   Returns: True if dtype is any floating-point type
    if not dtype.is_float():
        raise TypeError(f"Weights must be float type, got {dtype}")


def validate_indices_dtype(dtype: DType) -> None:
    """Ensure indices use an integer type."""
    # API: dtype.is_integral()
    #   Returns: True if dtype is any integer type (signed or unsigned)
    if not dtype.is_integral():
        raise TypeError(f"Indices must be integer type, got {dtype}")


# Usage
weights_dtype = DType.float16
indices_dtype = DType.int32

validate_weights_dtype(weights_dtype)  # OK
validate_indices_dtype(indices_dtype)  # OK
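If your inputs arrive as NumPy arrays, you can apply the same checks before converting them with Tensor.from_dlpack(). This sketch uses NumPy's np.issubdtype() with the abstract np.floating and np.integer types; the validate_weights_array and validate_indices_array names are hypothetical. Note that NumPy does not treat bool as an integer type, and this example doesn't assume anything about how DType.is_integral() handles DType.bool.

```python
import numpy as np


def validate_weights_array(arr: np.ndarray) -> None:
    """Ensure a NumPy weights array uses a floating-point dtype."""
    # np.issubdtype checks membership in NumPy's abstract dtype hierarchy.
    if not np.issubdtype(arr.dtype, np.floating):
        raise TypeError(f"Weights must be float type, got {arr.dtype}")


def validate_indices_array(arr: np.ndarray) -> None:
    """Ensure a NumPy indices array uses an integer dtype (signed or unsigned)."""
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"Indices must be integer type, got {arr.dtype}")


validate_weights_array(np.zeros((2, 2), dtype=np.float16))  # OK
validate_indices_array(np.arange(4, dtype=np.int32))        # OK

try:
    validate_indices_array(np.array([True, False]))  # bool is not in np.integer
except TypeError as err:
    print(err)  # Indices must be integer type, got bool
```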

Next steps

Now that you understand dtypes, continue learning:

  • Building graphs: Specify dtypes in computation graphs.
  • Quantization: Quantize weights to reduce memory usage and improve performance.