Indexing and slicing
Model code reads and writes tensor values using four operations, and which one
to use depends on what's known when the code is written. A fixed position or range, like
the last token in a sequence or a window of attention heads, uses Python slice
syntax directly. Positions computed at runtime, like token IDs from a sampler or
expert assignments from a router, require
F.gather,
which takes a tensor of indices and returns the selected values. When a boolean
condition determines which values to keep,
F.where
selects element-wise from two tensors based on a mask. To write values back to
specific positions,
F.scatter
is the inverse of F.gather.
Extract values by position
Slice syntax works the same way it does in Python, NumPy, and PyTorch: x[i]
selects by index, x[a:b] selects a range, and negative indices count from the
end.
from max.experimental.tensor import Tensor
hidden = Tensor.ones((2, 4, 8)) # (batch=2, seq_len=4, features=8)
print(hidden[0].shape) # drop batch dim → [Dim(4), Dim(8)]
print(hidden[:, 1:3].shape) # window along seq_len → [Dim(2), Dim(2), Dim(8)]
print(hidden[:, -1].shape) # last token → [Dim(2), Dim(8)]
print(hidden[..., -1].shape) # last feature, any rank → [Dim(2), Dim(4)]

When the position is computed at runtime, like token IDs from a sampler or
expert assignments from a router, use F.gather.
Select elements by index
When positions are data rather than literals,
F.gather(input, indices, axis)
selects elements from input at the positions given by indices, along axis.
The following example uses it for a vocabulary lookup: each token ID selects one
row from an embedding table:
import max.experimental.functional as F
from max.dtype import DType
from max.experimental.tensor import Tensor
embeddings = Tensor.ones((32000, 64))
token_ids = Tensor(
    [[1, 42, 7, 100], [5, 13, 99, 200]],
    dtype=DType.int64,
)
output = F.gather(embeddings, token_ids, axis=0)
print(output.shape)

The expected output is:

[Dim(2), Dim(4), Dim(64)]

The output shape replaces the axis dimension of input with the full shape of
indices. Here, input is (vocab_size, embed_dim), indices is (batch, seq_len), and axis=0, so the output is (batch, seq_len, embed_dim): one
embedding vector per token. The same pattern applies to MoE expert assignments,
sampled positions, and dynamic routing indices. When a boolean mask determines
which values to keep, use F.where instead.
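The shape rule can be checked with a NumPy analogue: np.take selects along a single axis with the same semantics as F.gather here (this is an illustration, not the MAX API):

```python
import numpy as np

# Analogue of F.gather(embeddings, token_ids, axis=0): the axis-0
# dimension of the input is replaced by the full shape of the indices.
embeddings = np.ones((32000, 64))
token_ids = np.array([[1, 42, 7, 100], [5, 13, 99, 200]])

output = np.take(embeddings, token_ids, axis=0)
print(output.shape)  # (2, 4, 64): one embedding vector per token
```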
Select elements by condition
F.where(condition, x, y)
takes from x where condition is True and from y where it's False,
element-wise. The condition argument must be boolean dtype.
The most common use in model code is applying a causal attention mask.
condition is a lower-triangular boolean matrix that allows each position to
attend only to itself and earlier positions; y is a large negative constant.
After F.where, softmax drives the masked positions toward zero.
The following example applies a causal mask to attention scores of shape
(batch, seq_len, seq_len), setting masked positions to -1e9:
import max.experimental.functional as F
from max.dtype import DType
from max.experimental.tensor import Tensor
seq_len = 4
rows = Tensor.arange(seq_len, dtype=DType.int64).reshape((seq_len, 1))
cols = Tensor.arange(seq_len, dtype=DType.int64).reshape((1, seq_len))
mask = rows >= cols
scores = Tensor.ones((2, seq_len, seq_len))
masked = F.where(mask, scores, -1e9)
print(masked.shape)

The expected output is:

[Dim(2), Dim(4), Dim(4)]

In this example, mask has shape (4, 4) and scores has shape (2, 4, 4).
Because F.where follows the same
broadcasting rules as arithmetic operations, the
mask applies to every sequence in the batch without any reshaping. Positions
above the diagonal receive -1e9; after softmax, their contribution to the
attention weights is negligible. All three arguments to F.where must be
broadcastable to a common shape. When you need to write values back to specific
positions rather than reading from them, F.scatter is the inverse operation.
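The effect of the mask on the attention weights can be verified with a NumPy analogue, where np.where stands in for F.where (an illustration of the technique, not the MAX API):

```python
import numpy as np

# Build the same lower-triangular causal mask and apply it with np.where;
# the (4, 4) mask broadcasts over the batch dimension of (2, 4, 4) scores.
seq_len = 4
rows = np.arange(seq_len).reshape((seq_len, 1))
cols = np.arange(seq_len).reshape((1, seq_len))
mask = rows >= cols

scores = np.ones((2, seq_len, seq_len))
masked = np.where(mask, scores, -1e9)

# Softmax over the last axis: masked positions get negligible weight.
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights[0, 0])  # position 0 attends only to itself
print(weights[0, 1])  # position 1 splits weight over positions 0 and 1
```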
Place values by index
F.scatter
is the inverse of F.gather: where gather reads values at positions given by
an index tensor, scatter writes values back to those positions. The signature
is F.scatter(input, updates, indices, axis=-1). updates comes before
indices, the opposite of the argument order you might expect. The indices
tensor must have the same shape as updates.
F.scatter currently runs on CPU only.
The following example writes 4 expert outputs back into an 8-token sequence.
F.tile
repeats the route indices across all 64 feature positions so that indices has
the same shape as expert_out:
import max.experimental.functional as F
from max.dtype import DType
from max.experimental.tensor import Tensor
hidden = Tensor.zeros((8, 64))
expert_out = Tensor.ones((4, 64))
routes = Tensor([0, 2, 5, 7], dtype=DType.int64).reshape((4, 1))
indices = F.tile(routes, (1, 64))
result = F.scatter(hidden, expert_out, indices, axis=0)
print(result.shape)

The expected output is:

[Dim(8), Dim(64)]

expert_out is written to rows 0, 2, 5, and 7 of hidden; the remaining rows
keep their original values. The output has the same shape as hidden.
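The same write-by-index pattern can be sketched with NumPy's np.put_along_axis, which mirrors this scatter except that it modifies the array in place rather than returning a new tensor (an analogue for illustration, not the MAX API):

```python
import numpy as np

# Write 4 expert outputs into rows 0, 2, 5, and 7 of an (8, 64) buffer.
hidden = np.zeros((8, 64))
expert_out = np.ones((4, 64))
routes = np.array([0, 2, 5, 7]).reshape((4, 1))
indices = np.tile(routes, (1, 64))  # same shape as expert_out

np.put_along_axis(hidden, indices, expert_out, axis=0)
print(hidden[[0, 2, 5, 7]].sum())  # 256.0: the 4 * 64 written values
print(hidden[[1, 3, 4, 6]].sum())  # 0.0: untouched rows keep their values
```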
Next steps
Indexing and F.where sit on top of the same shape rules that govern all
elementwise operations. When masks need to align with logits, the same
broadcasting rules apply:
- Review broadcasting for F.where and elementwise alignment: Broadcasting
- Compare eager execution and explicit graph construction: Eager execution
- Follow an end-to-end model workflow: Model bring-up workflow