
Mojo function

advanced_indexing_getitem

def advanced_indexing_getitem[input_rank: Int, index_rank: Int, input_type: DType, index_type: DType, //, start_axis: Int, num_index_tensors: Int, target: StringSlice[StaticConstantOrigin], trace_description: StringSlice[StaticConstantOrigin], input_tensor_fn: def[width: Int](IndexList[input_rank]) capturing -> SIMD[input_type, width], indices_fn: def[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]](out_tensor: TileTensor[input_type, address_space=out_tensor.address_space, linear_idx_type=out_tensor.linear_idx_type, element_size=out_tensor.element_size], in_tensor_strides: IndexList[input_rank], ctx: DeviceContext)

Implement basic numpy-style advanced indexing.

This is designed to be fused with other view-producing operations to implement full numpy-indexing semantics.
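The fusion remark can be read as a decomposition. A NumPy expression that mixes slices with an index array can be split into two steps. The first is a view-producing slice. The second is the all-":" advanced indexing that this function handles. This is an interpretation of the sentence above, not a description of how the compiler performs fusion. The shapes and the array `idx` below are arbitrary choices for the sketch.

```python
import numpy as np

x = np.arange(4 * 6 * 5).reshape(4, 6, 5)
idx = np.array([[0, 4], [2, 2]])  # illustrative index tensor for the last axis

# Step 1: a view-producing op (basic slicing). NumPy returns a view, not a copy.
view = x[1:, ::2]
assert np.shares_memory(view, x)

# Step 2: advanced indexing with ":" on every non-indexed axis of the view,
# i.e. start_axis = 2 and num_index_tensors = 1 in this function's terms.
fused = view[:, :, idx]

# Composing the two steps reproduces the full NumPy expression.
direct = x[1:, ::2, idx]
```

Here `fused` and `direct` hold the same values, with shape (3, 3, 2, 2).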

This assumes that every dimension of input_tensor not indexed by an index tensor is ":", i.e., all indices along that dimension are selected. For example, in NumPy:

# rank(indices1) == 3
# rank(indices2) == 3
out_tensor = input_tensor[:, :, :, indices1, indices2, :, :]

For all valid values of the indexing variables, we compute:

out_tensor[a, b, c, i, j, k, d, e] = input_tensor[
    a, b, c,
    indices1[i, j, k],
    indices2[i, j, k],
    d, e
]

In this example start_axis = 3 and num_index_tensors = 2.
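The formula above can be checked with a deliberately naive loop. This Python sketch visits every output coordinate and splits it into the leading axes (a, b, c), the index axes (i, j, k), and the trailing axes (d, e). It then gathers one input coordinate from each index tensor. `advanced_index_reference` is a hypothetical helper written for this illustration. It assumes all index tensors share one shape, and the loop is there for clarity rather than as a model of the kernel's implementation.

```python
import itertools

import numpy as np


def advanced_index_reference(input_tensor, index_tensors, start_axis):
    """Hypothetical loop-based reference for the gather described above.

    Assumes all index tensors share one shape and are applied to the
    consecutive axes start_axis .. start_axis + len(index_tensors) - 1.
    """
    num_index_tensors = len(index_tensors)
    index_shape = index_tensors[0].shape
    index_rank = len(index_shape)
    out_shape = (input_tensor.shape[:start_axis] + index_shape
                 + input_tensor.shape[start_axis + num_index_tensors:])
    out = np.empty(out_shape, dtype=input_tensor.dtype)
    for out_idx in itertools.product(*(range(n) for n in out_shape)):
        # Split the output coordinate into (a, b, c), (i, j, k), (d, e).
        head = out_idx[:start_axis]
        mid = out_idx[start_axis:start_axis + index_rank]
        tail = out_idx[start_axis + index_rank:]
        # Each index tensor contributes one coordinate of the input.
        gathered = tuple(int(t[mid]) for t in index_tensors)
        out[out_idx] = input_tensor[head + gathered + tail]
    return out


# The example above: rank-7 input, two rank-3 index tensors, start_axis = 3.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 1, 2, 5, 4, 3, 2))
indices1 = rng.integers(0, 5, size=(2, 2, 2))
indices2 = rng.integers(0, 4, size=(2, 2, 2))

result = advanced_index_reference(x, [indices1, indices2], start_axis=3)
expected = x[:, :, :, indices1, indices2, :, :]  # NumPy's own semantics
```

`result` and `expected` should compare equal with `np.array_equal`. Both have the 8-axis shape (2, 1, 2, 2, 2, 2, 3, 2).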

TODO(GEX-1951): Support boolean tensor masks.
TODO(GEX-1952): Support the non-contiguous indexing tensor case.
TODO(GEX-1953): Support fusion (especially view fusion).

Parameters:

  • input_rank (Int): The rank of the input tensor.
  • index_rank (Int): The rank of the indexing tensors.
  • input_type (DType): The dtype of the input tensor.
  • index_type (DType): The dtype of the indexing tensors.
  • start_axis (Int): The first dimension of the input to which the indexing tensors are applied. The indexing tensors are assumed to apply to consecutive dimensions.
  • num_index_tensors (Int): The number of indexing tensors.
  • target (StringSlice[StaticConstantOrigin]): The target architecture to operate on.
  • trace_description (StringSlice[StaticConstantOrigin]): For profiling, the trace name under which the operation appears.
  • input_tensor_fn (def[width: Int](IndexList[input_rank]) capturing -> SIMD[input_type, width]): Fusion lambda for the input tensor.
  • indices_fn (def[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]): Fusion lambda for the indexing tensors.
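Together, these parameters determine the shape of out_tensor. The num_index_tensors consecutive axes starting at start_axis are replaced by the shape of the indexing tensors. The output rank is therefore input_rank - num_index_tensors + index_rank, which gives 7 - 2 + 3 = 8 in the example above. The Python sketch below shows that shape rule. `advanced_indexing_out_shape` is a hypothetical helper, not part of the Mojo API, and it assumes all indexing tensors share one shape.

```python
import numpy as np


def advanced_indexing_out_shape(input_shape, index_shape, start_axis,
                                num_index_tensors):
    """Hypothetical helper: out_tensor shape implied by the semantics above.

    Axes [start_axis, start_axis + num_index_tensors) of the input are
    replaced by the (shared) shape of the indexing tensors.
    """
    input_rank = len(input_shape)
    if not 0 <= start_axis <= input_rank - num_index_tensors:
        raise ValueError("index tensors must fit in consecutive input axes")
    return (tuple(input_shape[:start_axis])
            + tuple(index_shape)
            + tuple(input_shape[start_axis + num_index_tensors:]))


# Same layout as the example: rank-7 input, two rank-3 index tensors at axis 3.
input_shape = (2, 3, 4, 10, 11, 5, 6)
index_shape = (7, 8, 9)
out_shape = advanced_indexing_out_shape(input_shape, index_shape,
                                        start_axis=3, num_index_tensors=2)

# Cross-check against NumPy using small int8 arrays and all-zero indices.
x = np.zeros(input_shape, dtype=np.int8)
i1 = np.zeros(index_shape, dtype=np.int64)
numpy_shape = x[:, :, :, i1, i1, :, :].shape
```

Both `out_shape` and `numpy_shape` come out as (2, 3, 4, 7, 8, 9, 5, 6).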

Args: