Broadcasting
Broadcasting is how MAX combines tensors of mismatched shapes in elementwise
operations. A (features,) bias added to a (batch, features) activation
matrix, a (seq_len, seq_len) mask applied across a (batch, seq_len, seq_len)
score tensor: these shapes differ in rank, yet both operations succeed. MAX
doesn't require exact shape matches; it requires shapes to satisfy a
compatibility rule that determines when dimensions can expand. Understanding
that rule tells you which combinations work, which raise an error, and how to
fix the ones that don't.
Understand broadcasting rules
Two tensor shapes are compatible for broadcasting when, aligned from the right, every pair of corresponding dimensions either matches or at least one of the two is a 1. Shapes with different numbers of dimensions are padded on the left with 1s before alignment.
Three cases cover all broadcasting behavior:
- Align from the right: (3, 4) and (4,) become (3, 4) and (1, 4) after padding. The trailing dimensions match (both 4), and the padded dimension is 1, so MAX expands it across 3 rows. Output shape: (3, 4).

  ```
  3 4
    4   ← padded to (1, 4)
  ------
  3 4 ✓
  ```

- Both can expand: (3, 1) and (1, 4) are already the same rank. Each has a 1 where the other has a non-1 value. MAX expands each along its 1-dimension. Output shape: (3, 4).

  ```
  3 1
  1 4
  ------
  3 4 ✓
  ```

- Incompatible shapes: (3, 4) and (3, 5) align to 4 vs. 5 at the trailing position. Neither is 1 and they don't match, so broadcasting fails.

  ```
  3 4
  3 5
  ------
  X (4 != 5, neither is 1)
  ```
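The check itself is mechanical enough to state in a few lines of plain Python. The following sketch implements only the rule; it is illustrative, not a MAX API:

```python
def broadcast_compatible(a: tuple[int, ...], b: tuple[int, ...]) -> bool:
    """Return True when two shapes satisfy the broadcasting rule."""
    rank = max(len(a), len(b))
    # Pad the shorter shape on the left with 1s, then compare from the right.
    a = (1,) * (rank - len(a)) + a
    b = (1,) * (rank - len(b)) + b
    return all(x == y or x == 1 or y == 1 for x, y in zip(a, b))

print(broadcast_compatible((3, 4), (4,)))    # True:  output shape (3, 4)
print(broadcast_compatible((3, 1), (1, 4)))  # True:  output shape (3, 4)
print(broadcast_compatible((3, 4), (3, 5)))  # False: 4 vs. 5, neither is 1
```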
The rule reduces to: after right-alignment, every column must have a 1 in at least one tensor or equal values in both. When that condition holds, MAX handles the expansion automatically. The next section shows what that looks like with concrete shapes.
Broadcast tensors automatically
MAX applies broadcasting automatically in elementwise operations: addition,
subtraction, multiplication, and division. You write the operation the same way
you would for matching shapes: activations + bias. MAX detects the shape
difference and expands the smaller tensor before operating, with no separate API
call needed.
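Automatic broadcasting also covers the both-can-expand case from the rules above, where each operand contributes a 1 somewhere. The following sketch mirrors the pattern of the examples below; the names col and row are chosen for illustration:

```python
from max.experimental.tensor import Tensor

col = Tensor.zeros([3, 1])  # has a 1 where `row` has 4
row = Tensor.ones([1, 4])   # has a 1 where `col` has 3

result = col + row   # each tensor expands along its 1-dimension
print(result.shape)  # expected: [Dim(3), Dim(4)]
```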
Bias vector broadcast
A bias vector has shape (features,), but the activations it needs to shift
have shape (batch, features). The ranks differ, yet the addition works without
any reshape: MAX pads bias on the left to (1, features) and broadcasts it
across every row in the batch. The following example adds an (8,) bias to
(4, 8) activations:
```python
from max.experimental.tensor import Tensor

activations = Tensor.zeros([4, 8])  # (batch=4, features=8)
bias = Tensor.ones([8])             # (features=8,)

result = activations + bias
print(result.shape)
```

The expected output is:

```
[Dim(4), Dim(8)]
```

MAX pads bias on the left to (1, 8), then broadcasts it across 4 rows. The output shape is (4, 8).
Attention mask broadcast
An attention mask has shape (seq_len, seq_len), but the score tensor it
applies to has shape (batch, seq_len, seq_len). The mask is missing the batch
dimension, yet the operation works without any reshape: MAX pads the leading
dimension automatically, applying the same mask to every sequence in the batch.
The following example adds an (8, 8) mask to (2, 8, 8) scores:
```python
from max.experimental.tensor import Tensor

mask = Tensor.zeros([8, 8])       # (seq_len=8, seq_len=8)
scores = Tensor.zeros([2, 8, 8])  # (batch=2, seq_len, seq_len)

result = scores + mask
print(result.shape)
```

The expected output is:

```
[Dim(2), Dim(8), Dim(8)]
```

MAX pads mask on the left to (1, 8, 8) before expanding across 2 batches, with no reshape needed. When you need to control the target shape explicitly rather than let the operation infer it, F.broadcast_to provides that control.
Broadcast tensors explicitly
When you need the target shape explicit in code rather than inferred from an operation, F.broadcast_to(x, shape) expands a tensor directly to a target shape without running an elementwise operation.
The following example uses F.broadcast_to to expand a (1, features) weight
vector to (batch, features):
```python
from max.experimental.tensor import Tensor
from max.experimental import functional as F

weight = Tensor.ones([1, 8])               # (1, features=8)
expanded = F.broadcast_to(weight, [4, 8])  # (batch=4, features=8)
print(expanded.shape)
```

The expected output is:

```
[Dim(4), Dim(8)]
```

In this example, F.broadcast_to produces a (4, 8) tensor with weight values repeated across 4 rows. The target shape must satisfy the compatibility rule: every dimension in weight must be either 1 or equal to the corresponding target dimension.
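For instance, a target of [4, 9] violates the rule at the trailing axis: 8 neither equals 9 nor is 1. A minimal sketch, assuming the violation surfaces as the ValueError described in the next section:

```python
from max.experimental.tensor import Tensor
from max.experimental import functional as F

weight = Tensor.ones([1, 8])

try:
    expanded = F.broadcast_to(weight, [4, 9])  # 8 vs. 9 at axis 1
except ValueError as e:
    print(e)
```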
When the shapes don't satisfy the rule, MAX raises an error as soon as the expression is evaluated. The next section covers how to read and fix those errors.
Debug shape mismatch errors
Shape mismatch errors in MAX fire when the graph operation is created, as soon
as the incompatible expression is evaluated. The ValueError includes the
shapes of both inputs and names the axis where they conflict. Two causes
account for most of these errors: misaligned ranks and mismatched non-1
dimensions.
Misaligned ranks
Adding a (batch,) vector to a (seq_len, batch, features) tensor fails because right-alignment places the vector's batch dimension against features: the two sizes neither match nor include a 1 (and even if they happened to match, the values would land on the wrong axis). The solution is to position the axis explicitly with Tensor.unsqueeze(axis) before the operation. The following example reshapes a (batch,) vector to (1, batch, 1) so it aligns with a (seq_len, batch, features) tensor:
```python
from max.experimental.tensor import Tensor

v = Tensor.ones([4])         # (batch=4,)
v2d = v.unsqueeze(0)         # (1, batch=4)
v3d = v2d.unsqueeze(2)       # (1, batch=4, 1)

x = Tensor.zeros([2, 4, 8])  # (seq_len=2, batch=4, features=8)
result = x + v3d
print(result.shape)
```

The expected output is:

```
[Dim(2), Dim(4), Dim(8)]
```

In this example, two .unsqueeze() calls place the batch dimension in the middle position so right-alignment produces three compatible pairs.
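Because MAX pads missing leading dimensions automatically, the leading unsqueeze is optional here: placing a single 1 after the batch axis is enough, and left-padding supplies the rest. A shorter sketch under the same assumptions:

```python
from max.experimental.tensor import Tensor

v = Tensor.ones([4])         # (batch=4,)
v2 = v.unsqueeze(1)          # (batch=4, 1); left-padding supplies the leading 1
x = Tensor.zeros([2, 4, 8])  # (seq_len=2, batch=4, features=8)

print((x + v2).shape)        # expected: [Dim(2), Dim(4), Dim(8)]
```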
Mismatched non-1 dimensions
Adding a (3, 4) tensor to a (3, 5) tensor fails because 4 and 5 are
neither equal nor 1. The error names the axis:
```
ValueError: Failed to create op 'add':
...
error: Input lhs (shape [3, 4]) dimension at axis 1 (value 4) and input rhs
(shape [3, 5]) dimension at axis 1 (value 5) are neither equivalent nor
broadcastable.
```

There's no broadcast fix: the shapes must match. Print each tensor's shape before the operation to confirm where they diverged:
```python
print(a.shape, b.shape)  # verify before operating
```

Shape errors that appear late in a pipeline often trace to a wrong hidden size or sequence length set earlier. Print shapes at each step to find where they diverged. When both shapes and ranks look correct but the error persists, the mismatch may be in a dimension added by a previous reshape: confirm the full shape at the point of failure.
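To see the message above for yourself, reproduce the mismatch directly. A minimal sketch, assuming the add raises ValueError at op creation as described earlier:

```python
from max.experimental.tensor import Tensor

a = Tensor.zeros([3, 4])
b = Tensor.zeros([3, 5])

try:
    result = a + b  # 4 vs. 5 at axis 1: neither equal nor 1
except ValueError as e:
    print(e)  # the message names both shapes and the conflicting axis
```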
With the broadcasting rule and these debugging techniques, you can predict compatibility for any two shapes and fix mismatches without guessing.
Next steps
Once shapes are predictable for elementwise ops, the next Fundamentals page
covers how to extract and combine values by position, runtime indices, or
boolean masks (F.where relies on the broadcasting rules from this page):
- Slice, gather, mask, and scatter tensor values: Indexing and slicing
- Compile ops and see how shapes carry into graphs: Get started with MAX graphs
- Review elementwise APIs: Basic operations