
Mojo function

scatter_nd_generator

scatter_nd_generator[output_type: DType, indices_type: DType, data_rank: Int, indices_rank: Int, updates_rank: Int, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu", /, reduce_fn: OptionalReg[fn[DType, Int](SIMD[$0, $1], SIMD[$0, $1]) capturing -> SIMD[$0, $1]] = None, *, _trace_description: StringSlice[StaticConstantOrigin] = "scatter_nd"](data: NDBuffer[output_type, data_rank, origin], indices: NDBuffer[indices_type, indices_rank, origin], updates: NDBuffer[output_type, updates_rank, origin], output: NDBuffer[output_type, data_rank, origin], context: DeviceContextPtr = DeviceContextPtr())

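Implements the ONNX ScatterND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND: output receives a copy of data in which the slices addressed by indices are overwritten with (or, when reduce_fn is set, combined with) the corresponding values from updates.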
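The semantics can be modeled with a small pure-Python reference, which may help when checking the kernel's output on small inputs. This is a sketch, not the Mojo implementation: the function name `scatter_nd` and its flat-list, row-major calling convention are hypothetical, and it assumes non-negative, in-range indices. Its `reduce_fn` argument plays the role of the parameter of the same name, with `None` meaning plain overwrite.

```python
from math import prod
from typing import Callable, List, Optional, Sequence


def scatter_nd(
    data_shape: Sequence[int],
    data: Sequence[float],
    indices_shape: Sequence[int],
    indices: Sequence[int],
    updates: Sequence[float],
    reduce_fn: Optional[Callable[[float, float], float]] = None,
) -> List[float]:
    """Hypothetical reference model of ONNX ScatterND on row-major flat lists.

    Assumes every index is non-negative and in range (no wrap-around).
    """
    k = indices_shape[-1]  # length of each index tuple
    # Number of data elements addressed by one index tuple: prod(data_shape[k:]).
    slice_size = prod(data_shape[k:])
    # Row-major strides (in elements) of the first k data dimensions.
    strides = [prod(data_shape[d + 1:]) for d in range(k)]
    output = list(data)  # output starts as a copy of data
    # indices_shape[:-1] enumerates the index tuples (prod([]) == 1 for rank-1 indices).
    for t in range(prod(indices_shape[:-1])):
        index_tuple = indices[t * k:(t + 1) * k]
        base = sum(i * s for i, s in zip(index_tuple, strides))
        for j in range(slice_size):
            update = updates[t * slice_size + j]
            if reduce_fn is None:
                output[base + j] = update  # no reduction: overwrite
            else:
                output[base + j] = reduce_fn(output[base + j], update)
    return output


# 1-D example from the ONNX operator documentation.
print(scatter_nd([8], [1, 2, 3, 4, 5, 6, 7, 8], [4, 1], [4, 3, 1, 7], [9, 10, 11, 12]))
# prints [1, 11, 3, 10, 9, 6, 7, 12]
```

Passing a binary function such as `lambda a, b: a + b` as `reduce_fn` models the "add" reduction; repeated index tuples then accumulate instead of overwriting each other.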

Parameters:

  • output_type (DType): Type of data, updates, and output tensors.
  • indices_type (DType): Type of the indices tensor.
  • data_rank (Int): Rank of input (data) tensor (data_rank >= 1).
  • indices_rank (Int): Rank of input (data) tensor (indices_rank >= 1).
  • updates_rank (Int): Rank of updates tensor (updates_rank = data_rank + indices_rank - indices_shape[-1] - 1).
  • single_thread_blocking_override (Bool): If True, then the operation is run synchronously using a single thread.
  • target (StringSlice[StaticConstantOrigin]): Target device, "cpu" (default) or "cuda".
  • reduce_fn (OptionalReg[fn[DType, Int](SIMD[$0, $1], SIMD[$0, $1]) capturing -> SIMD[$0, $1]]): Optional reduction used to combine the existing output value with the update: none (default, plain overwrite), add, mul, max, or min.
  • _trace_description (StringSlice[StaticConstantOrigin]): A description of the function, used for profiling and tracing.
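The updates_rank formula follows from how the updates tensor is laid out: its leading indices_rank - 1 dimensions enumerate index tuples, and each tuple of length k = indices_shape[-1] selects a slice with data_rank - k remaining dimensions. A minimal sketch of that arithmetic, using a hypothetical helper name; the requirement 1 <= k <= data_rank is an assumption of this sketch:

```python
def updates_rank(data_rank: int, indices_rank: int, index_depth: int) -> int:
    """Hypothetical helper: rank of updates for ScatterND.

    index_depth is indices_shape[-1], the length of each index tuple.
    """
    if data_rank < 1 or indices_rank < 1:
        raise ValueError("data_rank and indices_rank must both be >= 1")
    # Assumption of this sketch: each index tuple has between 1 and data_rank entries.
    if not 1 <= index_depth <= data_rank:
        raise ValueError("indices_shape[-1] must be in [1, data_rank]")
    # (indices_rank - 1) dims enumerate tuples; (data_rank - index_depth) dims per slice.
    # This equals data_rank + indices_rank - index_depth - 1.
    return (indices_rank - 1) + (data_rank - index_depth)


print(updates_rank(1, 2, 1))  # prints 1  (data [8], indices [4, 1])
print(updates_rank(3, 2, 1))  # prints 3  (data [4, 4, 4], indices [2, 1])
```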

Args:

  • data (NDBuffer[output_type, data_rank, origin]): Tensor of rank data_rank >= 1.
  • indices (NDBuffer[indices_type, indices_rank, origin]): Tensor of rank indices_rank containing indices for the scatter operation.
  • updates (NDBuffer[output_type, updates_rank, origin]): Tensor of rank updates_rank containing the values written into the output tensor at the positions given by the indices tensor.
  • output (NDBuffer[output_type, data_rank, origin]): Tensor of rank data_rank, shaped the same as data tensor.
  • context (DeviceContextPtr): Pointer to DeviceContext.
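Before calling the kernel, it can help to confirm that the four buffers have mutually consistent shapes. The sketch below checks the constraints listed above (output matches data, and updates holds one data_shape[k:] slice per index tuple). The function name and the use of plain Python lists for shapes are hypothetical; it does not call any Mojo API.

```python
from typing import Sequence


def check_scatter_nd_shapes(
    data_shape: Sequence[int],
    indices_shape: Sequence[int],
    updates_shape: Sequence[int],
    output_shape: Sequence[int],
) -> None:
    """Hypothetical check: raise ValueError if the shapes are inconsistent."""
    if len(data_shape) < 1 or len(indices_shape) < 1:
        raise ValueError("data and indices must have rank >= 1")
    if list(output_shape) != list(data_shape):
        raise ValueError(f"output shape {list(output_shape)} must equal data shape {list(data_shape)}")
    k = indices_shape[-1]  # length of each index tuple
    if k > len(data_shape):
        raise ValueError("indices_shape[-1] must not exceed the rank of data")
    # One slice of shape data_shape[k:] per index tuple in indices_shape[:-1].
    expected = list(indices_shape[:-1]) + list(data_shape[k:])
    if list(updates_shape) != expected:
        raise ValueError(f"updates shape {list(updates_shape)} must be {expected}")


# Consistent shapes: returns None without raising.
check_scatter_nd_shapes([4, 4, 4], [2, 1], [2, 4, 4], [4, 4, 4])
```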
