Mojo function
gather_nd
gather_nd[dtype: DType, indices_type: DType, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](data: TileTensor[dtype, address_space=data.address_space, linear_idx_type=data.linear_idx_type, element_size=data.element_size], indices: TileTensor[indices_type, address_space=indices.address_space, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], ctx: DeviceContextPtr)
Implements the GatherND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND. Based on the reference implementation at https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gathernd.py.
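The following is a minimal pure-Python sketch of GatherND for the case batch_dims = 0. It loosely follows the linked ONNX Python reference. Nested Python lists stand in for TileTensor, and `slice_at` and `gather_nd_sketch` are hypothetical names used only for illustration. They are not part of the Mojo API. Each index tuple along the last axis of `indices` selects one of two things. Full coordinates select a single element. Partial coordinates select a slice of the trailing data dimensions.

```python
# Hypothetical pure-Python sketch of GatherND semantics with batch_dims = 0.
# Nested Python lists stand in for TileTensor; this is not the Mojo API.


def slice_at(data, index_tuple):
    """Follow one index tuple into `data`, returning a scalar or a sub-list."""
    node = data
    for i in index_tuple:
        # Python lists accept negative indices in [-s, s-1] and raise
        # IndexError outside that range, matching the bounds rule.
        node = node[i]
    return node


def gather_nd_sketch(data, indices):
    """Gather one slice of `data` per index tuple in the last axis of `indices`."""
    # An innermost list of ints is a single index tuple.
    if all(isinstance(i, int) for i in indices):
        return slice_at(data, indices)
    # Otherwise map over the leading axes of `indices`.
    return [gather_nd_sketch(data, sub) for sub in indices]


data = [[0, 1], [2, 3]]
print(gather_nd_sketch(data, [[0, 0], [1, 1]]))  # [0, 3]
print(gather_nd_sketch(data, [[1], [0]]))        # [[2, 3], [0, 1]]
```

Index tuples of length 2 fully address this rank-2 `data`, so each one yields a scalar. Tuples of length 1 stop after the first axis, so each one yields a whole row.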
Parameters:
- dtype (DType): Type of the data tensor.
- indices_type (DType): Type of the indices tensor.
- batch_dims (Int): Number of leading batch dimensions shared by data and indices. Indexing starts at dimension batch_dims of data, that is, it selects into data[batch_dims:].
- target (StringSlice[StaticConstantOrigin]): The target architecture to execute on. Defaults to "cpu".
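The next hypothetical pure-Python sketch shows how batch_dims changes the result. Nested lists again stand in for TileTensor, and `gather_nd_batched` is an illustrative name, not the Mojo API. The first batch_dims axes of data and indices are walked in lockstep. Batch j of indices therefore indexes only into batch j of data. The remaining indexing works as in the batch_dims = 0 case.

```python
# Hypothetical pure-Python sketch of how batch_dims changes GatherND.
# Nested lists stand in for TileTensor; this is not the Mojo API.


def gather_nd_batched(data, indices, batch_dims):
    """GatherND where the first `batch_dims` axes of data and indices are paired."""
    if batch_dims > 0:
        # Batch axes must match: batch j of indices only indexes batch j of data.
        if len(data) != len(indices):
            raise ValueError("data and indices disagree on a batch dimension")
        return [
            gather_nd_batched(d, idx, batch_dims - 1)
            for d, idx in zip(data, indices)
        ]
    # No batch axes left: an innermost list of ints is one index tuple.
    if all(isinstance(i, int) for i in indices):
        node = data
        for i in indices:
            node = node[i]
        return node
    # Otherwise map over the remaining leading axes of `indices`.
    return [gather_nd_batched(data, sub, 0) for sub in indices]


data = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]  # shape [2, 2, 2]
indices = [[1], [0]]                          # shape [2, 1]
print(gather_nd_batched(data, indices, batch_dims=1))  # [[2, 3], [4, 5]]
print(gather_nd_batched(data, indices, batch_dims=0))  # [[[4, 5], [6, 7]], [[0, 1], [2, 3]]]
```

With batch_dims=1, each one-element index tuple picks a row inside its own batch. With batch_dims=0, the same indices pick whole batches of data.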
Args:
- data (TileTensor[dtype, address_space=data.address_space, linear_idx_type=data.linear_idx_type, element_size=data.element_size]): Tensor of rank data_rank >= 1.
- indices (TileTensor[indices_type, address_space=indices.address_space, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size]): Tensor of rank indices_rank >= 1. All index values are expected to lie within bounds [-s, s-1] along an axis of size s. It is an error if any index value is out of bounds.
- output (TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Tensor of rank data_rank + indices_rank - indices_shape[-1] - 1 - b, where b is batch_dims.
- ctx (DeviceContextPtr): The DeviceContextPtr as prepared by the graph compiler.
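Two small hypothetical helpers can check the rank formula for output and the bounds rule for indices. They are sketched in Python and assume that shapes are plain lists of ints. The output shape keeps every indices axis except the last one. It then appends the data axes that each index tuple leaves unconsumed. The range check on indices_shape[-1] follows the linked ONNX specification. Neither helper is part of the Mojo API.

```python
# Hypothetical helpers (not part of the Mojo API) that check the shape and
# bounds rules stated for gather_nd's arguments.


def gather_nd_output_shape(data_shape, indices_shape, batch_dims):
    """Output shape, whose rank is
    data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims."""
    k = indices_shape[-1]  # coordinates per index tuple
    # ONNX requires 1 <= indices_shape[-1] <= data_rank - batch_dims.
    if not 1 <= k <= len(data_shape) - batch_dims:
        raise ValueError("indices_shape[-1] must be in [1, data_rank - batch_dims]")
    # Leading axes come from indices (dropping the tuple axis); trailing axes
    # are the data axes that an index tuple does not consume.
    return list(indices_shape[:-1]) + list(data_shape[batch_dims + k:])


def normalize_index(idx, size):
    """Map an index in [-size, size - 1] to [0, size - 1]; reject anything else."""
    if not -size <= idx < size:
        raise IndexError(f"index {idx} is out of bounds for an axis of size {size}")
    return idx + size if idx < 0 else idx


print(gather_nd_output_shape([2, 2, 2], [2, 1], batch_dims=1))  # [2, 2]
print(normalize_index(-1, 4))                                   # 3
```

The returned list has (indices_rank - 1) + (data_rank - b - indices_shape[-1]) entries, which is the same rank given for output.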