
Mojo function

gather_nd

def gather_nd[dtype: DType, indices_type: DType, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](data: TileTensor[dtype, address_space=data.address_space, linear_idx_type=data.linear_idx_type, element_size=data.element_size], indices: TileTensor[indices_type, address_space=indices.address_space, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], ctx: DeviceContext)

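GatherND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND. Each innermost vector of indices (of length indices.shape[-1]) is treated as a multi-dimensional index into data, starting after its first batch_dims dimensions, and the selected elements or slices are gathered into output. Based on the reference implementation: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gathernd.py.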
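The ONNX reference implementation cited above is written in Python, so a minimal NumPy sketch of the same semantics can help when checking what this kernel should produce. It is an illustration only: gather_nd_reference is a hypothetical name, it does not use TileTensor or DeviceContext, and it is not how the Mojo kernel is implemented.

```python
import numpy as np


# Hypothetical helper sketching ONNX GatherND semantics in NumPy.
# It is not part of the MAX/Mojo API and ignores device placement.
def gather_nd_reference(data, indices, batch_dims=0):
    k = indices.shape[-1]  # length of each index tuple
    if not 0 <= batch_dims < min(data.ndim, indices.ndim):
        raise ValueError("batch_dims must be smaller than both input ranks")
    if data.shape[:batch_dims] != indices.shape[:batch_dims]:
        raise ValueError("leading batch dimensions of data and indices must match")
    if not 1 <= k <= data.ndim - batch_dims:
        raise ValueError("indices.shape[-1] must be in [1, data.ndim - batch_dims]")

    # Each gathered item is a scalar or a slice over data's trailing dimensions.
    output_shape = indices.shape[:-1] + data.shape[batch_dims + k:]

    # Collapse the batch dimensions into one so the loop is rank-agnostic.
    batch_size = int(np.prod(data.shape[:batch_dims], dtype=np.int64))
    flat_data = data.reshape((batch_size,) + data.shape[batch_dims:])
    flat_indices = indices.reshape(batch_size, -1, k)

    gathered = []
    for b in range(batch_size):
        for index_tuple in flat_indices[b]:
            # Index the b-th batch of data with one k-element index tuple.
            gathered.append(flat_data[(b, *index_tuple.tolist())])
    return np.asarray(gathered, dtype=data.dtype).reshape(output_shape)


data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
print(gather_nd_reference(data, np.array([[0, 1], [1, 0]])).tolist())
# [[2, 3], [4, 5]]
print(gather_nd_reference(data, np.array([[1], [0]]), batch_dims=1).tolist())
# [[2, 3], [4, 5]]
```

Both calls return the same values for different reasons: with batch_dims=0 each pair such as [0, 1] indexes the full data tensor, while with batch_dims=1 each single index [1] or [0] selects a row within its own batch.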

Parameters:

  • dtype (DType): Element type of the data and output tensors.
  • indices_type (DType): Element type of the indices tensor.
  • batch_dims (Int): Number of batch dimensions. The first batch_dims dimensions of data and indices must match, and indexing starts at dimension batch_dims of data (that is, it applies to data[batch_dims:]).
  • target (StringSlice[StaticConstantOrigin]): The target architecture to execute on. Defaults to "cpu".
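The effect of batch_dims is easiest to see in the output shape. Under the ONNX GatherND rules, the result has shape indices.shape[:-1] + data.shape[batch_dims + indices.shape[-1]:], so its rank is rank(indices) + rank(data) - indices.shape[-1] - 1 - batch_dims. The sketch below computes that shape in plain Python; gather_nd_output_shape is a hypothetical helper, and the assumption that the Mojo output tensor must have exactly this shape is inferred from the ONNX definition rather than stated on this page.

```python
# Hypothetical helper: output shape of ONNX GatherND for given input shapes.
# Assumes the Mojo kernel follows the ONNX shape rules exactly.
def gather_nd_output_shape(data_shape, indices_shape, batch_dims=0):
    data_shape, indices_shape = tuple(data_shape), tuple(indices_shape)
    k = indices_shape[-1]  # length of each index tuple
    if not 0 <= batch_dims < min(len(data_shape), len(indices_shape)):
        raise ValueError("batch_dims must be smaller than both input ranks")
    if data_shape[:batch_dims] != indices_shape[:batch_dims]:
        raise ValueError("leading batch dimensions of data and indices must match")
    if not 1 <= k <= len(data_shape) - batch_dims:
        raise ValueError("indices.shape[-1] must be in [1, rank(data) - batch_dims]")
    # Leading index dimensions, followed by the data dimensions not consumed by indexing.
    return indices_shape[:-1] + data_shape[batch_dims + k:]


print(gather_nd_output_shape((2, 2, 2), (2, 1), batch_dims=0))  # (2, 2, 2)
print(gather_nd_output_shape((2, 2, 2), (2, 1), batch_dims=1))  # (2, 2)
```

With batch_dims=0, each one-element index selects a whole 2x2 block of data; with batch_dims=1, the first dimension is shared, so each index selects a single row inside its own batch and the output loses one dimension.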

Args: