Indexing and slicing

Model code reads and writes tensor values using four operations; which one to use depends on what is known when the code is written. A fixed position or range, such as the last token in a sequence or a window of attention heads, uses Python slice syntax directly. Positions computed at runtime, such as token IDs from a sampler or expert assignments from a router, require F.gather, which takes a tensor of indices and returns the selected values. When a boolean condition determines which values to keep, F.where selects element-wise from two tensors based on a mask. To write values back to specific positions, F.scatter is the inverse of F.gather.

Extract values by position

Slice syntax works the same way it does in Python, NumPy, and PyTorch: x[i] selects by index, x[a:b] selects a range, and negative indices count from the end.

from max.experimental.tensor import Tensor

hidden = Tensor.ones((2, 4, 8))  # (batch=2, seq_len=4, features=8)

print(hidden[0].shape)        # drop batch dim         → [Dim(4), Dim(8)]
print(hidden[:, 1:3].shape)   # window along seq_len   → [Dim(2), Dim(2), Dim(8)]
print(hidden[:, -1].shape)    # last token             → [Dim(2), Dim(8)]
print(hidden[..., -1].shape)  # last feature, any rank → [Dim(2), Dim(4)]

When the position is computed at runtime, like token IDs from a sampler or expert assignments from a router, use F.gather.

Select elements by index

When positions are data rather than literals, F.gather(input, indices, axis) selects elements from input at the positions given by indices, along axis. The following example uses it for a vocabulary lookup: each token ID selects one row from an embedding table:

import max.experimental.functional as F
from max.dtype import DType
from max.experimental.tensor import Tensor

embeddings = Tensor.ones((32000, 64))
token_ids = Tensor(
    [[1, 42, 7, 100], [5, 13, 99, 200]],
    dtype=DType.int64,
)
output = F.gather(embeddings, token_ids, axis=0)
print(output.shape)

The expected output is:

[Dim(2), Dim(4), Dim(64)]

The output shape replaces the axis dimension of input with the full shape of indices. Here, input is (vocab_size, embed_dim), indices is (batch, seq_len), and axis=0, so the output is (batch, seq_len, embed_dim): one embedding vector per token. The same pattern applies to MoE expert assignments, sampled positions, and dynamic routing indices. When a boolean mask determines which values to keep, use F.where instead.
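The same shape rule can be cross-checked in plain NumPy (a stand-in here, not the MAX API): np.take with an axis argument replaces the axis dimension of the input with the full shape of the indices, just as described above.

```python
import numpy as np

# NumPy stand-in for the embedding lookup above. np.take follows the
# same shape rule as F.gather: the axis-0 dimension of the input
# (vocab_size) is replaced by the full shape of the indices.
embeddings = np.ones((32000, 64))          # (vocab_size, embed_dim)
token_ids = np.array([[1, 42, 7, 100],
                      [5, 13, 99, 200]])   # (batch, seq_len)

output = np.take(embeddings, token_ids, axis=0)
print(output.shape)  # (2, 4, 64): one embedding vector per token
```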

Select elements by condition

F.where(condition, x, y) takes from x where condition is True and from y where it's False, element-wise. The condition argument must have boolean dtype.
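The element-wise selection rule is the same one NumPy's np.where follows, so a minimal NumPy sketch (a stand-in, not the MAX API) shows the semantics in isolation:

```python
import numpy as np

# Element-wise selection: take from x where cond is True,
# from y where it is False.
cond = np.array([True, False, True, False])
x = np.array([1, 2, 3, 4])
y = np.array([10, 20, 30, 40])

print(np.where(cond, x, y))  # [ 1 20  3 40]
```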

The most common use in model code is applying a causal attention mask. condition is a lower-triangular boolean matrix that allows each position to attend only to itself and earlier positions; x is the raw attention scores and y is a large negative constant. After F.where, softmax drives the masked positions' weights toward zero.

The following example applies a causal mask to attention scores of shape (batch, seq_len, seq_len), setting masked positions to -1e9:

import max.experimental.functional as F
from max.dtype import DType
from max.experimental.tensor import Tensor

seq_len = 4

rows = Tensor.arange(seq_len, dtype=DType.int64).reshape((seq_len, 1))
cols = Tensor.arange(seq_len, dtype=DType.int64).reshape((1, seq_len))
mask = rows >= cols
scores = Tensor.ones((2, seq_len, seq_len))

masked = F.where(mask, scores, -1e9)
print(masked.shape)

The expected output is:

[Dim(2), Dim(4), Dim(4)]

In this example, mask has shape (4, 4) and scores has shape (2, 4, 4). Because F.where follows the same broadcasting rules as arithmetic operations, the mask applies to every sequence in the batch without any reshaping. Positions above the diagonal receive -1e9; after softmax, their contribution to the attention weights is negligible. All three arguments to F.where must be broadcastable to a common shape. When you need to write values back to specific positions rather than reading from them, F.scatter is the inverse operation.
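To see why -1e9 works as a mask value, here is the softmax of a single masked row, sketched in NumPy (a stand-in for illustration): the exponential of a very large negative score underflows to zero, so the masked positions receive no attention weight.

```python
import numpy as np

# One row of masked attention scores: this position may attend to
# positions 0 and 1; positions 2 and 3 were set to -1e9 by the mask.
row = np.array([1.0, 1.0, -1e9, -1e9])

# Numerically stable softmax: exp(-1e9 - max) underflows to 0.0.
weights = np.exp(row - row.max())
weights /= weights.sum()
print(weights)  # ≈ [0.5, 0.5, 0.0, 0.0]
```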

Place values by index

F.scatter is the inverse of F.gather: where gather reads values at positions given by an index tensor, scatter writes values back to those positions. The signature is F.scatter(input, updates, indices, axis=-1). updates comes before indices, the opposite of the argument order you might expect. The indices tensor must have the same shape as updates.

F.scatter currently runs on CPU only.

The following example writes 4 expert outputs back into an 8-token sequence. F.tile repeats the route indices across all 64 feature positions so that indices has the same shape as expert_out:

import max.experimental.functional as F
from max.dtype import DType
from max.experimental.tensor import Tensor

hidden = Tensor.zeros((8, 64))
expert_out = Tensor.ones((4, 64))

routes = Tensor([0, 2, 5, 7], dtype=DType.int64).reshape((4, 1))
indices = F.tile(routes, (1, 64))

result = F.scatter(hidden, expert_out, indices, axis=0)
print(result.shape)

The expected output is:

[Dim(8), Dim(64)]

expert_out is written to rows 0, 2, 5, and 7 of hidden; the remaining rows keep their original values. The output has the same shape as hidden.
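The same write pattern can be cross-checked with NumPy's np.put_along_axis (a stand-in for illustration; unlike F.scatter, it modifies the array in place, and its indices broadcast against the updates, so no explicit tiling is needed):

```python
import numpy as np

# NumPy stand-in for the scatter above: write 4 expert outputs back
# into rows 0, 2, 5, and 7 of an 8-row hidden state, in place.
hidden = np.zeros((8, 64))
expert_out = np.ones((4, 64))
routes = np.array([0, 2, 5, 7]).reshape(4, 1)  # (4, 1) broadcasts to (4, 64)

np.put_along_axis(hidden, routes, expert_out, axis=0)
print(hidden[:, 0])  # rows 0, 2, 5, 7 are 1.0; the rest stay 0.0
```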

Next steps

Indexing and F.where sit on top of the same shape rules that govern all element-wise operations. When a mask needs to align with logits, the same broadcasting rules determine how their shapes match up.