Mojo function
gather_nd
gather_nd[dtype: DType, indices_type: DType, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](data: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)
Implements the GatherND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND. Based on the reference implementation: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gathernd.py.
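To make the semantics concrete, here is a minimal Python sketch of GatherND as the ONNX specification describes it. It works on nested lists so it needs no third-party packages. It is not the Mojo implementation and does not call the gather_nd API documented here. The function name, its list-based inputs, and the default batch_dims=0 are illustrative assumptions.

```python
def gather_nd(data, indices, batch_dims=0):
    """Illustrative sketch (not the Mojo API): ONNX GatherND on nested lists.

    data: nested list of rank r >= 1.
    indices: nested list of rank q >= 1 whose innermost lists are index
        tuples of length k, with 1 <= k <= r - batch_dims.
    """
    if batch_dims > 0:
        # The leading batch_dims axes of data and indices must match; pair
        # them up and peel off one batch axis per recursion level.
        if len(data) != len(indices):
            raise ValueError("batch dimensions of data and indices differ")
        return [gather_nd(d, i, batch_dims - 1) for d, i in zip(data, indices)]
    if isinstance(indices[0], list):
        # Not yet at an index tuple: map over this axis of indices.
        return [gather_nd(data, i) for i in indices]
    # indices is a single index tuple: walk k axes into data.
    out = data
    for idx in indices:
        size = len(out)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} out of bounds for axis of size {size}")
        out = out[idx]  # a negative idx selects idx + size, as in ONNX
    return out


# Examples taken from the ONNX GatherND operator documentation.
print(gather_nd([[0, 1], [2, 3]], [[0, 0], [1, 1]]))  # [0, 3]
print(gather_nd([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], [[1], [0]], batch_dims=1))  # [[2, 3], [4, 5]]
```

Each index tuple of length k selects a slice of rank r - batch_dims - k. This is why a full-length tuple gathers scalars and a shorter tuple gathers whole rows or sub-tensors.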
Parameters:
- dtype (DType): Type of the data tensor.
- indices_type (DType): Type of the indices tensor.
- batch_dims (Int): Number of leading batch dimensions shared by data and indices. Indexing starts after these dimensions, that is, over data[batch_dims:].
- target (StringSlice): The target architecture to execute on.
- single_thread_blocking_override (Bool): If True, the operation runs synchronously on a single thread.
Args:
- data (LayoutTensor): Tensor of rank data_rank >= 1.
- indices (LayoutTensor): Tensor of rank indices_rank >= 1. Every index value must lie within [-s, s-1] along an axis of size s. It is an error if any index value is out of bounds.
- output (LayoutTensor): Tensor of rank data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims.
- ctx (DeviceContextPtr): The DeviceContextPtr as prepared by the graph compiler.
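The caller passes output in, so its shape has to be known before the call. The ONNX specification gives that shape as indices_shape[:-1] + data_shape[batch_dims + k:], where k = indices_shape[-1]. The Python sketch below computes it and checks the constraints listed above. It is written against the ONNX rules, not the Mojo source, and gather_nd_output_shape is a hypothetical helper rather than part of this API.

```python
def gather_nd_output_shape(data_shape, indices_shape, batch_dims=0):
    """Hypothetical helper: GatherND output shape per the ONNX spec."""
    r, q, b = len(data_shape), len(indices_shape), batch_dims
    if r < 1 or q < 1:
        raise ValueError("data and indices must both have rank >= 1")
    if not 0 <= b < min(r, q):
        raise ValueError("batch_dims must be in [0, min(data_rank, indices_rank))")
    k = indices_shape[-1]  # length of each index tuple
    if not 1 <= k <= r - b:
        raise ValueError("indices_shape[-1] must be in [1, data_rank - batch_dims]")
    if list(data_shape[:b]) != list(indices_shape[:b]):
        raise ValueError("leading batch_dims dimensions of data and indices must match")
    shape = list(indices_shape[:-1]) + list(data_shape[b + k:])
    # Agrees with the documented rank: data_rank + indices_rank - k - 1 - batch_dims.
    assert len(shape) == r + q - k - 1 - b
    return shape


print(gather_nd_output_shape([2, 2, 2], [2, 1], batch_dims=1))  # [2, 2]
```

A shape helper like this is only a planning aid. The bounds rule for individual index values still has to be checked against the contents of indices, not just its shape.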