
Mojo function

scatter_nd_generator

scatter_nd_generator[output_type: DType, indices_type: DType, single_thread_blocking_override: Bool, oob_index_strategy: ScatterOobIndexStrategy = ScatterOobIndexStrategy.UNDEFINED, target: StringSlice[StaticConstantOrigin] = "cpu", reduce_fn: OptionalReg[fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, *, _trace_description: StringSlice[StaticConstantOrigin] = "scatter_nd"](data: TileTensor[output_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], indices: TileTensor[indices_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], updates: TileTensor[output_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[output_type, LayoutType, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], context: DeviceContextPtr = DeviceContextPtr())

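Implements the ONNX ScatterND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND. The output starts as a copy of data, and each slice of it selected by an index tuple in indices is then overwritten (or reduced) with the corresponding slice of updates.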
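To make the semantics concrete, here is a minimal pure-Python sketch of what ScatterND computes. It follows the reference semantics in the ONNX specification, not this Mojo kernel. Tensors are modelled as nested lists, the helpers `shape_of`, `get`, `combine`, and `scatter_nd` are hypothetical and not part of the Mojo API, and the sketch assumes the last dimension of indices is at least 1. It does not model `oob_index_strategy` or any parallel execution.

```python
import copy
import itertools


def shape_of(t):
    """Shape of a rectangular nested list (a scalar has shape [])."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return shape


def get(t, idx):
    """Follow a sequence of integer indices into a nested list."""
    for i in idx:
        t = t[i]
    return t


def combine(old, new, fn):
    """Apply a binary reduction element-wise over matching nested lists."""
    if isinstance(old, list):
        return [combine(o, n, fn) for o, n in zip(old, new)]
    return fn(old, new)


def scatter_nd(data, indices, updates, reduction=None):
    """Reference ScatterND: copy data, then write each update slice.

    The last dimension k of `indices` selects a slice of `data`; every other
    indices dimension enumerates one update. With reduction=None the update
    overwrites the slice; otherwise reduction(old, new) is applied element-wise.
    """
    out = copy.deepcopy(data)  # data itself is left untouched
    update_dims = shape_of(indices)[:-1]
    for pos in itertools.product(*(range(n) for n in update_dims)):
        target = get(indices, pos)  # list of k ints addressing a slice of data
        upd = copy.deepcopy(get(updates, pos))
        parent = get(out, target[:-1])
        last = target[-1]
        if reduction is None:
            parent[last] = upd
        else:
            parent[last] = combine(parent[last], upd, reduction)
    return out
```

With the first example from the ONNX specification, `scatter_nd([1, 2, 3, 4, 5, 6, 7, 8], [[4], [3], [1], [7]], [9, 10, 11, 12])` returns `[1, 11, 3, 10, 9, 6, 7, 12]`.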

Parameters:

  • output_type (DType): Element type of the data, updates, and output tensors.
  • indices_type (DType): Element type of the indices tensor.
  • single_thread_blocking_override (Bool): If True, the operation runs synchronously on a single thread.
  • oob_index_strategy (ScatterOobIndexStrategy): Strategy for handling out-of-bounds indices.
  • target (StringSlice): Target device to run on: "cpu" (the default) or "cuda".
  • reduce_fn (OptionalReg): Optional reduction function that combines an existing value with its update. When None (the default), updates overwrite existing values; the ONNX reductions are add, mul, max, and min.
  • _trace_description (StringSlice): A description of the function, used for profiling and tracing.
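The reduce_fn signature takes two SIMD[dtype, width] vectors and returns one, so a reduction combines values lane by lane. The Python sketch below imitates that with equal-length lists standing in for SIMD vectors; `REDUCTIONS` and `reduce_lanes` are hypothetical names, and treating the first argument as the existing value and the second as the update is an assumption (it does not matter for add, mul, max, or min, which are all commutative).

```python
import operator

# Hypothetical Python stand-ins for the four ONNX reductions. In the Mojo API,
# reduce_fn is instead a single function taking two SIMD[dtype, width]
# vectors and returning one combined vector.
REDUCTIONS = {
    "add": operator.add,
    "mul": operator.mul,
    "max": max,
    "min": min,
}


def reduce_lanes(existing, update, name):
    """Combine two equal-width 'vectors' lane by lane, like a SIMD binary op.

    Assumes argument order (existing value, update value).
    """
    if len(existing) != len(update):
        raise ValueError("vectors must have the same width")
    fn = REDUCTIONS[name]
    return [fn(a, b) for a, b in zip(existing, update)]
```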

Args:

  • data (TileTensor): Input tensor of rank data_rank >= 1.
  • indices (TileTensor): Tensor of rank indices_rank >= 1 holding the indices for the scatter operation. Its last dimension k (k <= data_rank) is the length of each index tuple into data.
  • updates (TileTensor): Tensor of update values, of shape indices.shape[:-1] + data.shape[k:] (rank indices_rank + data_rank - k - 1).
  • output (TileTensor): Output tensor of rank data_rank, with the same shape as data.
  • context (DeviceContextPtr): Pointer to the DeviceContext on which the operation runs.
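The shape relationships among data, indices, updates, and output follow the ONNX ScatterND rules. The sketch below checks them in Python, assuming this kernel expects the same constraints as the ONNX specification; `check_scatter_nd_shapes` is a hypothetical helper, not part of the Mojo API.

```python
def check_scatter_nd_shapes(data_shape, indices_shape, updates_shape, output_shape):
    """Validate shapes per the ONNX ScatterND rules; return k on success."""
    r, q = len(data_shape), len(indices_shape)
    if r < 1 or q < 1:
        raise ValueError("data and indices must both have rank >= 1")
    k = indices_shape[-1]  # length of each index tuple into data
    if k > r:
        raise ValueError(f"index tuples of length {k} exceed data rank {r}")
    # updates: one slice data.shape[k:] per index tuple in indices.shape[:-1]
    expected = tuple(indices_shape[:-1]) + tuple(data_shape[k:])
    if tuple(updates_shape) != expected:
        raise ValueError(f"updates shape {tuple(updates_shape)} != expected {expected}")
    if tuple(output_shape) != tuple(data_shape):
        raise ValueError("output must have the same shape as data")
    return k
```

For example, with data of shape (4, 4, 4) and indices of shape (2, 1), each index tuple selects a (4, 4) slice, so updates must have shape (2, 4, 4).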
