Mojo function
scatter_nd_generator
scatter_nd_generator[
    output_type: DType,
    indices_type: DType,
    data_rank: Int,
    indices_rank: Int,
    updates_rank: Int,
    single_thread_blocking_override: Bool,
    target: StringSlice[StaticConstantOrigin] = "cpu",
    /,
    reduce_fn: OptionalReg[fn[DType, Int](SIMD[$0, $1], SIMD[$0, $1]) capturing -> SIMD[$0, $1]] = None,
    *,
    _trace_description: StringSlice[StaticConstantOrigin] = "scatter_nd",
](
    data: NDBuffer[output_type, data_rank, origin],
    indices: NDBuffer[indices_type, indices_rank, origin],
    updates: NDBuffer[output_type, updates_rank, origin],
    output: NDBuffer[output_type, data_rank, origin],
    context: DeviceContextPtr = DeviceContextPtr(),
)
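Implements the ONNX ScatterND operation, as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND. The output starts as a copy of data; each index tuple along the last axis of indices selects an element or slice of output, which is overwritten by (or, when reduce_fn is set, combined with) the corresponding slice of updates.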
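To make the semantics concrete, here is a minimal reference sketch in plain Python. It is not Mojo and does not use this function's API or NDBuffer; it assumes tensors are rectangular nested lists and covers only the default behavior with no reduction. The helper names shape_of, at, and scatter_nd are hypothetical.

```python
import copy
from itertools import product

# Hypothetical Python reference of ONNX ScatterND (no reduction).
# Not the Mojo implementation; tensors are rectangular nested lists.


def shape_of(t):
    """Shape of a rectangular nested list (a scalar has shape ())."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def at(t, idx):
    """Follow a sequence of indices into a nested list."""
    for i in idx:
        t = t[i]
    return t


def scatter_nd(data, indices, updates):
    """Copy data, then let each index tuple in indices overwrite one
    element (k == data_rank) or one trailing slice (k < data_rank)."""
    output = copy.deepcopy(data)  # data itself is left untouched
    indices_shape = shape_of(indices)
    # Visit every position of indices except its last axis.
    for pos in product(*(range(n) for n in indices_shape[:-1])):
        target = at(indices, pos)         # k coordinates into data
        parent = at(output, target[:-1])  # container of the target slot
        parent[target[-1]] = copy.deepcopy(at(updates, pos))
    return output
```

For example, scatter_nd([1, 2, 3, 4, 5, 6, 7, 8], [[4], [3], [1], [7]], [9, 10, 11, 12]) returns [1, 11, 3, 10, 9, 6, 7, 12], matching Example 1 in the ONNX operator documentation.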
Parameters:
- output_type (DType): Type of the data, updates, and output tensors.
- indices_type (DType): Type of the indices tensor.
- data_rank (Int): Rank of the input (data) tensor (data_rank >= 1).
- indices_rank (Int): Rank of the indices tensor (indices_rank >= 1).
- updates_rank (Int): Rank of the updates tensor (updates_rank = data_rank + indices_rank - indices_shape[-1] - 1).
- single_thread_blocking_override (Bool): If True, the operation runs synchronously on a single thread.
- target (StringSlice[StaticConstantOrigin]): Target device, "cpu" or "cuda".
- reduce_fn (OptionalReg[fn[DType, Int](SIMD[$0, $1], SIMD[$0, $1]) capturing -> SIMD[$0, $1]]): Reduction function used to combine updates with existing output values: none (default, plain overwrite), add, mul, max, or min.
- _trace_description (StringSlice[StaticConstantOrigin]): Description of the function, used for profiling and tracing.
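The updates_rank formula follows from the shape ONNX requires for updates: indices_shape[:-1] followed by data_shape[k:], where k = indices_shape[-1]. The small Python sketch below computes that shape and checks its length against the formula; the function name expected_updates_shape is hypothetical and not part of this API.

```python
# Hypothetical helper illustrating the ONNX ScatterND shape rule for updates.


def expected_updates_shape(data_shape, indices_shape):
    """Return the shape ScatterND requires for updates:
    indices_shape[:-1] followed by data_shape[k:], with k = indices_shape[-1]."""
    data_rank, indices_rank = len(data_shape), len(indices_shape)
    if data_rank < 1 or indices_rank < 1:
        raise ValueError("data and indices must both have rank >= 1")
    k = indices_shape[-1]  # length of each index tuple
    if k > data_rank:
        raise ValueError("indices_shape[-1] cannot exceed data_rank")
    shape = tuple(indices_shape[:-1]) + tuple(data_shape[k:])
    # Matches updates_rank = data_rank + indices_rank - indices_shape[-1] - 1.
    assert len(shape) == data_rank + indices_rank - k - 1
    return shape
```

For instance, data of shape (4, 4, 4) with indices of shape (2, 1) needs updates of shape (2, 4, 4), and 3 + 2 - 1 - 1 = 3 agrees with the formula.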
Args:
- data (NDBuffer[output_type, data_rank, origin]): Input tensor of rank data_rank >= 1.
- indices (NDBuffer[indices_type, indices_rank, origin]): Tensor of rank indices_rank containing the indices for the scatter operation.
- updates (NDBuffer[output_type, updates_rank, origin]): Tensor of values scattered into the output tensor at the positions given by indices.
- output (NDBuffer[output_type, data_rank, origin]): Output tensor of rank data_rank, with the same shape as data.
- context (DeviceContextPtr): Pointer to the DeviceContext.
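When a reduction is set, ONNX combines each update with the current output value instead of overwriting it, so duplicate index tuples accumulate; with no reduction, ONNX expects indices to contain no duplicates, since the result would otherwise depend on iteration order. The self-contained Python sketch below illustrates the reduced case under the assumption that tensors are nested lists. It uses a hypothetical REDUCTIONS table keyed by the ONNX reduction names in place of the Mojo reduce_fn function parameter, and the helper names are likewise hypothetical.

```python
import copy
import operator
from itertools import product

# Hypothetical Python reference of ONNX ScatterND with a reduction.
# Not the Mojo implementation: reduce_fn is modeled by a name lookup.
REDUCTIONS = {"add": operator.add, "mul": operator.mul, "max": max, "min": min}


def shape_of(t):
    """Shape of a rectangular nested list (a scalar has shape ())."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def at(t, idx):
    """Follow a sequence of indices into a nested list."""
    for i in idx:
        t = t[i]
    return t


def combine(old, new, fn):
    """Apply fn elementwise to two nested lists (or scalars) of equal shape."""
    if isinstance(old, list):
        return [combine(o, n, fn) for o, n in zip(old, new)]
    return fn(old, new)


def scatter_nd_reduce(data, indices, updates, reduction):
    """output = copy of data; output[idx] = fn(output[idx], update) per tuple."""
    fn = REDUCTIONS[reduction]
    output = copy.deepcopy(data)
    indices_shape = shape_of(indices)
    for pos in product(*(range(n) for n in indices_shape[:-1])):
        target = at(indices, pos)
        parent = at(output, target[:-1])
        parent[target[-1]] = combine(parent[target[-1]], at(updates, pos), fn)
    return output
```

For example, scattering updates [10, 20, 5] at indices [[0], [0], [2]] into [1, 2, 3, 4] with "add" gives [31, 2, 8, 4]: both updates aimed at position 0 are summed into the original value.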