IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

advanced_indexing_setitem_inplace

def advanced_indexing_setitem_inplace[index_rank: Int, updates_rank: Int, input_type: DType, index_type: DType, //, start_axis: Int, num_index_tensors: Int, target: StringSlice[StaticConstantOrigin], trace_description: StringSlice[StaticConstantOrigin], updates_tensor_fn: def[width: Int](IndexList[updates_rank]) capturing -> SIMD[input_type, width], indices_fn: def[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]](input_tensor: TileTensor[input_type, Storage=input_tensor.Storage, address_space=input_tensor.address_space, linear_idx_type=input_tensor.linear_idx_type, element_size=input_tensor.element_size], index_tensor_shape: IndexList[index_rank], updates_tensor_strides: IndexList[updates_rank], ctx: DeviceContext)

Implement basic numpy-style advanced indexing with assignment.

This is designed to be fused with other view-producing operations to implement full numpy-indexing semantics.

This assumes the dimensions in input_tensor not indexed by index tensors are ":", i.e. selecting all indices along that dimension. For example, in numpy:

# rank(indices1) == 2
# rank(indices2) == 2
# rank(updates) == 2
input_tensor[:, :, :, indices1, indices2, :, :] = updates

We calculate the following for all valid values of the indexing variables:

input_tensor[
    a, b, c,
    indices1[i, j],
    indices2[i, j],
    d, e
] = updates[i, j]

In this example start_axis = 3 and num_index_tensors = 2.

In terms of implementation details, our strategy is to iterate over a single common iteration range. Each index in this range maps both to the write location in input_tensor and to the data location in updates. An example best illustrates how this is possible:

Imagine the input_tensor shape is [A, B, C, D] and we have indexing tensors I1 and I2 with shape [M, N, K]. Assume I1 and I2 are applied to dimensions 1 and 2.

An appropriate common iteration range is then (A, M, N, K, D). Note that updates is expected to have shape [A, M, N, K, D]. The mappings into updates and input_tensor show why this range works:

Consider an arbitrary set of indices in this range (a, m, n, k, d):

  • The index into updates is (a, m, n, k, d).
  • The index into input_tensor is (a, I1[m, n, k], I2[m, n, k], d).
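The mapping above can be sketched in plain Python. This is only an illustration of the index arithmetic, not the Mojo implementation: the helper name setitem_sketch, the dict-of-tuples tensor layout, and the concrete shapes and values are all assumptions made for the example.

```python
from itertools import product


def setitem_sketch(input_tensor, input_shape, index_tensors, index_shape,
                   updates, start_axis):
    """Hypothetical reference loop; tensors are dicts keyed by index tuples."""
    num_index_tensors = len(index_tensors)
    index_rank = len(index_shape)
    leading = tuple(input_shape[:start_axis])
    trailing = tuple(input_shape[start_axis + num_index_tensors:])
    # Common iteration range: leading dims, index-tensor dims, trailing dims.
    iter_shape = leading + tuple(index_shape) + trailing
    for idx in product(*(range(extent) for extent in iter_shape)):
        lead = idx[:start_axis]
        mid = idx[start_axis:start_axis + index_rank]
        trail = idx[start_axis + index_rank:]
        # Each index tensor supplies one coordinate of the write location.
        target = lead + tuple(t[mid] for t in index_tensors) + trail
        # The iteration index itself is the read location in updates.
        input_tensor[target] = updates[idx]
    return iter_shape


# Assumed shapes: input [A, B, C, D] = [2, 3, 3, 2]; I1, I2 [M, N, K] = [1, 2, 1].
input_shape = (2, 3, 3, 2)
index_shape = (1, 2, 1)
data = {idx: 0 for idx in product(*(range(s) for s in input_shape))}
I1 = {(0, 0, 0): 2, (0, 1, 0): 0}
I2 = {(0, 0, 0): 1, (0, 1, 0): 2}
# updates has the shape of the common iteration range, [A, M, N, K, D].
updates_shape = (2, 1, 2, 1, 2)
updates = {
    idx: 100 + n
    for n, idx in enumerate(product(*(range(s) for s in updates_shape)))
}

iter_shape = setitem_sketch(data, input_shape, [I1, I2], index_shape,
                            updates, start_axis=1)
```

Here start_axis is 1 and two index tensors are passed, matching the [A, B, C, D] example. Positions of data that no (I1, I2) pair selects keep their original value.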

TODO(GEX-1951): Support boolean tensor masks
TODO(GEX-1952): Support non-contiguous indexing tensor case
TODO(GEX-1953): Support fusion (especially view-fusion)
TODO(GEX-1954): Unify getitem and setitem using generic views. (Requires non-strided view functions).

Parameters:

  • ​index_rank (Int): The rank of the indexing tensors.
  • ​updates_rank (Int): The rank of the updates tensor.
  • ​input_type (DType): The dtype of the input tensor.
  • ​index_type (DType): The dtype of the indexing tensors.
  • ​start_axis (Int): The first dimension in input where the indexing tensors are applied. It is assumed the indexing tensors are applied in consecutive dimensions.
  • ​num_index_tensors (Int): The number of indexing tensors.
  • ​target (StringSlice[StaticConstantOrigin]): The target architecture to operate on.
  • ​trace_description (StringSlice[StaticConstantOrigin]): For profiling, the trace name the operation will appear under.
  • ​updates_tensor_fn (def[width: Int](IndexList[updates_rank]) capturing -> SIMD[input_type, width]): Fusion lambda for the update tensor.
  • ​indices_fn (def[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]): Fusion lambda for the indices tensors.

Args: