
Mojo function

gather_nd

gather_nd[type: DType, indices_type: DType, data_rank: Int, indices_rank: Int, output_rank: Int, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](data: NDBuffer[type, data_rank, origin], indices: NDBuffer[indices_type, indices_rank, origin], output: NDBuffer[type, output_rank, origin], ctx: DeviceContextPtr)

Implements the GatherND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND. Based on the reference implementation: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gathernd.py.
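
The semantics can be illustrated with a minimal pure-Python sketch that follows the linked ONNX definition, not the Mojo kernel itself. It assumes rectangular nested lists in place of NDBuffer, and the names gather_nd and _lookup below are hypothetical; there is no DeviceContextPtr, target, or threading. Each innermost row of indices is one index tuple into data; a tuple shorter than the remaining rank of data gathers a whole slice.

```python
import copy
from typing import Any, List


def _lookup(data: Any, index: List[int]) -> Any:
    """Follow one index tuple into `data`, normalizing negative indices."""
    x = data
    for i in index:
        if not isinstance(x, list):
            raise IndexError("index tuple is longer than the remaining data rank")
        s = len(x)
        j = i + s if i < 0 else i  # map [-s, -1] onto [0, s-1]
        if not 0 <= j < s:
            raise IndexError(f"index {i} is out of bounds for an axis of size {s}")
        x = x[j]
    # Copy so the result does not alias sub-lists of `data`.
    return copy.deepcopy(x)


def gather_nd(data: Any, indices: Any, batch_dims: int = 0) -> Any:
    """Hypothetical pure-Python GatherND over rectangular nested lists."""
    if batch_dims > 0:
        # Leading batch dimensions must match; gather independently per batch entry.
        if len(data) != len(indices):
            raise ValueError("batch dimensions of data and indices must match")
        return [gather_nd(d, i, batch_dims - 1) for d, i in zip(data, indices)]
    if indices and all(isinstance(v, int) for v in indices):
        # Innermost level: `indices` is one index tuple of length indices_shape[-1].
        return _lookup(data, indices)
    # Otherwise recurse over the outer (non-index) dimensions of `indices`.
    return [gather_nd(data, sub, 0) for sub in indices]


print(gather_nd([[0, 1], [2, 3]], [[0, 0], [1, 1]]))  # [0, 3]
print(gather_nd([[0, 1], [2, 3]], [[1], [0]]))  # [[2, 3], [0, 1]]
print(gather_nd([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], [[1], [0]], batch_dims=1))  # [[2, 3], [4, 5]]
```

The inputs above are the small examples from the ONNX operator documentation; the recursion peels off one batch dimension per level before switching to plain index-tuple lookups.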

Parameters:

  • type (DType): Element type of the data tensor (and of the output tensor).
  • indices_type (DType): Element type of the indices tensor.
  • data_rank (Int): Rank of the data tensor (data_rank >= 1).
  • indices_rank (Int): Rank of the indices tensor (indices_rank >= 1).
  • output_rank (Int): Rank of the output tensor.
  • batch_dims (Int): Number of leading batch dimensions shared by data and indices. Indexing starts at dimension batch_dims of data, i.e., the gather is applied to data[batch_dims:].
  • target (StringSlice[StaticConstantOrigin]): The target architecture to execute on (defaults to "cpu").
  • single_thread_blocking_override (Bool): If True, the operation runs synchronously on a single thread.
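
Because output_rank is a compile-time parameter, the rank parameters together imply the size of the last indices dimension. The hypothetical Python helper below (not part of the Mojo API) rearranges output_rank = data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims and applies the ONNX GatherND constraints batch_dims < min(data_rank, indices_rank) and 1 <= indices_shape[-1] <= data_rank - batch_dims.

```python
def implied_index_depth(data_rank: int, indices_rank: int, output_rank: int, batch_dims: int) -> int:
    """Hypothetical helper: return indices_shape[-1] implied by the rank parameters."""
    if data_rank < 1 or indices_rank < 1:
        raise ValueError("data_rank and indices_rank must both be >= 1")
    # ONNX GatherND requires batch_dims < min(data_rank, indices_rank).
    if not 0 <= batch_dims < min(data_rank, indices_rank):
        raise ValueError("batch_dims must satisfy 0 <= batch_dims < min(data_rank, indices_rank)")
    # Rearranged from output_rank = data_rank + indices_rank - k - 1 - batch_dims.
    k = data_rank + indices_rank - 1 - batch_dims - output_rank
    # ONNX GatherND requires 1 <= indices_shape[-1] <= data_rank - batch_dims.
    if not 1 <= k <= data_rank - batch_dims:
        raise ValueError(f"inconsistent ranks: implied indices_shape[-1] = {k}")
    return k


print(implied_index_depth(2, 2, 1, 0))  # 2: full-element lookup into a 2-D tensor
print(implied_index_depth(3, 2, 2, 1))  # 1: one index per batch entry
```

A check like this could catch an inconsistent combination of rank parameters before any data is touched.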

Args:

  • data (NDBuffer[type, data_rank, origin]): Tensor of rank data_rank >= 1.
  • indices (NDBuffer[indices_type, indices_rank, origin]): Tensor of rank indices_rank >= 1. Every index value must lie within [-s, s-1] along an axis of size s; it is an error if any index value is out of bounds.
  • output (NDBuffer[type, output_rank, origin]): Tensor of rank data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims, where indices_shape[-1] is the size of the last dimension of indices.
  • ctx (DeviceContextPtr): The DeviceContextPtr as prepared by the graph compiler.
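
The full output shape, not just its rank, follows from the ONNX definition: all dimensions of indices except the last, followed by the dimensions of data after the first batch_dims + indices_shape[-1]. The sketch below is a hypothetical Python helper, assuming shapes are given as plain integer sequences; its final assert checks the result against the output-rank formula above.

```python
from typing import List, Sequence


def gather_nd_output_shape(data_shape: Sequence[int], indices_shape: Sequence[int], batch_dims: int = 0) -> List[int]:
    """Hypothetical helper: ONNX GatherND output shape for given input shapes."""
    r, q, b = len(data_shape), len(indices_shape), batch_dims
    if r < 1 or q < 1 or not 0 <= b < min(r, q):
        raise ValueError("need data_rank, indices_rank >= 1 and 0 <= batch_dims < min(ranks)")
    if list(data_shape[:b]) != list(indices_shape[:b]):
        raise ValueError("the leading batch_dims dimensions of data and indices must match")
    k = indices_shape[-1]
    if not 1 <= k <= r - b:
        raise ValueError("indices_shape[-1] must lie in [1, data_rank - batch_dims]")
    # Batch dims and outer dims of indices, then the trailing slice dims of data.
    shape = list(indices_shape[:-1]) + list(data_shape[b + k:])
    assert len(shape) == r + q - k - 1 - b  # the documented output rank
    return shape


print(gather_nd_output_shape([2, 2], [2, 2]))        # [2]
print(gather_nd_output_shape([2, 2, 2], [2, 1], 1))  # [2, 2]
```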
