
Mojo function

index_tensor

index_tensor[dtype: DType, indices_type: DType, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](data: TileTensor[dtype, data.LayoutType, data.origin, address_space=data.address_space, linear_idx_type=data.linear_idx_type, element_shape_types=data.element_shape_types], indices: TileTensor[indices_type, indices.LayoutType, indices.origin, address_space=indices.address_space, linear_idx_type=indices.linear_idx_type, element_shape_types=indices.element_shape_types], output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_shape_types=output.element_shape_types], ctx: DeviceContextPtr)

The index_tensor operation, implemented as a modified version of gather_nd.
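The gather semantics this operation builds on can be illustrated with a small reference model. This is a minimal sketch on nested Python lists that assumes ONNX GatherND-style semantics, because the documented output rank matches GatherND. gather_nd here is a hypothetical helper, not part of the Mojo API, and the real kernel operates on TileTensor buffers rather than lists.

```python
def gather_nd(data, indices, batch_dims=0):
    """Hypothetical reference model of gather_nd on nested lists.

    Assumes ONNX GatherND-style semantics: the innermost lists of `indices`
    are index tuples into `data`, after skipping `batch_dims` leading
    dimensions that `data` and `indices` share.
    """
    if batch_dims > 0:
        # Leading batch dimensions are shared: pair each data slice with
        # the indices slice at the same batch position and recurse.
        if len(data) != len(indices):
            raise ValueError("batch dimensions of data and indices must match")
        return [gather_nd(d, i, batch_dims - 1) for d, i in zip(data, indices)]
    if indices and not isinstance(indices[0], list):
        # Innermost level: `indices` is a single index tuple into `data`.
        out = data
        for idx in indices:
            # Python's negative indexing accepts [-s, s-1] for an axis of
            # size s and raises IndexError outside that range.
            out = out[idx]
        return out
    # Outer levels of `indices` are preserved in the output shape.
    return [gather_nd(data, sub, 0) for sub in indices]
```

With batch_dims=1, each index tuple only addresses the slice of data at the same batch position. For example, gather_nd([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], [[1], [0]], batch_dims=1) selects row 1 of the first batch and row 0 of the second, giving [[2, 3], [4, 5]].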

Parameters:

  • dtype (DType): Type of data tensor.
  • indices_type (DType): Type of indices tensor.
  • batch_dims (Int): Number of batch dimensions. Indexing starts at dimension batch_dims of data, so the gather operates on data[batch_dims:].
  • target (StringSlice): The target architecture to execute on.
  • single_thread_blocking_override (Bool): If True, then the operation is run synchronously using a single thread.

Args:

  • data (TileTensor): Tensor of rank data_rank >= 1.
  • indices (TileTensor): Tensor of rank indices_rank >= 1. All index values are expected to be within the bounds [-s, s-1] along an axis of size s. It is an error if any index value is out of bounds.
  • output (TileTensor): Tensor of rank data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims.
  • ctx (DeviceContextPtr): The DeviceContextPtr as prepared by the graph compiler.
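The documented output rank and index bounds can be checked with a small shape helper. The sketch below assumes GatherND-style shape rules: index tuples have length indices_shape[-1], and the leading batch_dims dimensions are shared by data and indices. index_tensor_output_shape and normalize_index are hypothetical names used only for illustration.

```python
def index_tensor_output_shape(data_shape, indices_shape, batch_dims):
    """Hypothetical output-shape helper under assumed GatherND-style rules."""
    r, q = len(data_shape), len(indices_shape)
    if r < 1 or q < 1:
        raise ValueError("data and indices must both have rank >= 1")
    if not 0 <= batch_dims < min(r, q):
        raise ValueError("batch_dims must be in [0, min(data_rank, indices_rank) - 1]")
    k = indices_shape[-1]  # length of each index tuple
    if not 1 <= k <= r - batch_dims:
        raise ValueError("indices_shape[-1] must be in [1, data_rank - batch_dims]")
    if list(data_shape[:batch_dims]) != list(indices_shape[:batch_dims]):
        raise ValueError("leading batch dimensions of data and indices must match")
    # Keep every indices dimension except the tuple axis, then append the
    # data dimensions that the index tuples do not consume.
    shape = list(indices_shape[:-1]) + list(data_shape[batch_dims + k:])
    # Consistent with: data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims
    assert len(shape) == r + q - k - 1 - batch_dims
    return shape


def normalize_index(i, s):
    """Map an index in [-s, s-1] to [0, s-1]; reject anything else."""
    if not -s <= i <= s - 1:
        raise IndexError(f"index {i} out of bounds for axis of size {s}")
    return i + s if i < 0 else i
```

For example, with data shape [2, 2, 2], indices shape [2, 1] and batch_dims = 1, the output shape is [2, 2], which has rank 3 + 2 - 1 - 1 - 1 = 2.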
