Mojo function
scatter_nd_generator
scatter_nd_generator[output_type: DType, indices_type: DType, //, oob_index_strategy: ScatterOobIndexStrategy = ScatterOobIndexStrategy.UNDEFINED, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu"), reduce_fn: OptionalReg[def[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, *, _trace_description: StringSlice[StaticConstantOrigin] = StringSlice("scatter_nd")](data: TileTensor[output_type, linear_idx_type=data.linear_idx_type, element_size=data.element_size], indices: TileTensor[indices_type, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], updates: TileTensor[output_type, linear_idx_type=updates.linear_idx_type, element_size=updates.element_size], output: TileTensor[output_type, linear_idx_type=output.linear_idx_type, element_size=output.element_size], context: DeviceContextPtr = DeviceContextPtr())
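Implements the ONNX ScatterND operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND. The output is a copy of data in which the slices addressed by indices are overwritten with (or, when reduce_fn is set, combined with) the corresponding entries of updates.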
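To make the ScatterND semantics concrete, here is a minimal reference sketch in plain Python. It is not the Mojo implementation and does not use TileTensor or any Mojo API. It assumes tensors are stored as flat, row-major Python lists with their shapes passed separately, and the names scatter_nd and reduce are hypothetical. Each k-tuple in indices (k = indices.shape[-1]) selects a slice covering the trailing rank(data) - k dimensions of data, and that slice is overwritten or reduced with the matching slice of updates.

```python
from math import prod


def scatter_nd(data, data_shape, indices, indices_shape, updates, reduce=None):
    """Hypothetical reference ScatterND on flat, row-major Python lists.

    data    -- flat list holding a tensor of shape data_shape (rank r >= 1)
    indices -- flat list of ints holding a tensor of shape indices_shape;
               its last dimension k is the length of each index tuple
    updates -- flat list holding a tensor of shape
               indices_shape[:-1] + data_shape[k:]
    reduce  -- None to overwrite, or a callable reduce(existing, update)
    """
    r = len(data_shape)
    k = indices_shape[-1]
    if not 1 <= k <= r:
        raise ValueError("indices.shape[-1] must be in [1, rank(data)]")

    # Each index tuple selects a slice covering the trailing r - k dims of data.
    slice_size = prod(data_shape[k:])  # prod of an empty shape is 1
    num_updates = prod(indices_shape[:-1])
    if len(updates) != num_updates * slice_size:
        raise ValueError("updates has the wrong number of elements")

    # Row-major strides of data: strides[d] is the product of the dims after d.
    strides = [prod(data_shape[d + 1:]) for d in range(r)]

    output = list(data)  # start from a copy of data
    for u in range(num_updates):
        offset = 0
        for d, i in enumerate(indices[u * k:(u + 1) * k]):
            if i < 0:
                i += data_shape[d]  # ONNX allows negative (from-the-end) indices
            if not 0 <= i < data_shape[d]:
                raise IndexError(f"index {i} out of bounds for dim {d}")
            offset += i * strides[d]
        # In row-major order the selected slice is one contiguous run.
        for j in range(slice_size):
            new = updates[u * slice_size + j]
            old = output[offset + j]
            output[offset + j] = new if reduce is None else reduce(old, new)
    return output


# Element updates: data shape [8], indices shape [4, 1].
print(scatter_nd([1, 2, 3, 4, 5, 6, 7, 8], [8], [4, 3, 1, 7], [4, 1], [9, 10, 11, 12]))
# [1, 11, 3, 10, 9, 6, 7, 12]

# Slice update: replace row 1 of a 2x3 tensor.
print(scatter_nd([1, 2, 3, 4, 5, 6], [2, 3], [1], [1, 1], [7, 8, 9]))
# [1, 2, 3, 7, 8, 9]
```

This sketch raises IndexError for out-of-range indices; in the Mojo function that case is governed by the oob_index_strategy parameter instead.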
Parameters:
- output_type (DType): Type of the data, updates, and output tensors.
- indices_type (DType): Type of the indices tensor.
- oob_index_strategy (ScatterOobIndexStrategy): Strategy for handling out-of-bounds indices.
- target (StringSlice[StaticConstantOrigin]): Target device, "cpu" (default) or "cuda".
- reduce_fn (OptionalReg[def[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]): Reduction function to apply: none (default), add, mul, max, or min.
- _trace_description (StringSlice[StaticConstantOrigin]): A description of the function, used for profiling and tracing.
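The effect of the reductions listed for reduce_fn is easiest to see when an index repeats. The plain-Python sketch below uses scalar callables in place of the Mojo SIMD function type, handles only the simplest case (rank-1 data, indices of shape [n, 1]), and assumes the reduction is called as reduce_fn(existing, update); this page does not state the argument order, so treat that as an assumption. The names REDUCTIONS and scatter_1d are hypothetical. ONNX expects duplicate-free indices when no reduction is used; with add, mul, max, or min, updates to the same position accumulate.

```python
import operator

# Hypothetical scalar stand-ins for the reductions reduce_fn can express.
REDUCTIONS = {
    "add": operator.add,
    "mul": operator.mul,
    "max": max,
    "min": min,
}


def scatter_1d(data, indices, updates, reduce_fn=None):
    """Element-wise ScatterND for rank-1 data and indices of shape [n, 1].

    Assumes reduce_fn(existing, update) argument order.
    """
    out = list(data)
    for (i,), u in zip(indices, updates):
        out[i] = u if reduce_fn is None else reduce_fn(out[i], u)
    return out


data = [1, 2, 3, 4]
indices = [[0], [0], [2]]  # position 0 is updated twice
updates = [5, 6, 7]
for name, fn in REDUCTIONS.items():
    print(name, scatter_1d(data, indices, updates, fn))
# add [12, 2, 10, 4]
# mul [30, 2, 21, 4]
# max [6, 2, 7, 4]
# min [1, 2, 3, 4]
```

Because add, mul, max, and min are commutative and associative, the result with duplicates does not depend on the order in which updates are applied.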
Args:
- data (TileTensor[output_type, linear_idx_type=data.linear_idx_type, element_size=data.element_size]): Input tensor of rank data_rank >= 1.
- indices (TileTensor[indices_type, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size]): Tensor of rank indices_rank containing the indices for the scatter operation. Its last dimension, k = indices.shape[-1], is the length of each index tuple and must not exceed data_rank.
- updates (TileTensor[output_type, linear_idx_type=updates.linear_idx_type, element_size=updates.element_size]): Values written into (or, with reduce_fn, combined with) the output tensor at the positions selected by indices. Its shape must be indices.shape[:-1] + data.shape[k:].
- output (TileTensor[output_type, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor of rank data_rank, with the same shape as data.
- context (DeviceContextPtr): Pointer to the DeviceContext.
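The argument shapes are tied together by the ONNX rules: updates must have shape indices.shape[:-1] + data.shape[k:], and output must match data. The Python sketch below checks those constraints on plain shape tuples before a call; it is an illustration only, the function names are hypothetical, and the requirement 1 <= k <= data_rank is this sketch's own assumption.

```python
def expected_updates_shape(data_shape, indices_shape):
    """Shape ONNX ScatterND requires for updates: indices.shape[:-1] + data.shape[k:]."""
    data_shape, indices_shape = tuple(data_shape), tuple(indices_shape)
    r = len(data_shape)
    if r < 1 or len(indices_shape) < 1:
        raise ValueError("data and indices must both have rank >= 1")
    k = indices_shape[-1]  # length of each index tuple
    if not 1 <= k <= r:
        raise ValueError(f"indices.shape[-1] = {k} must be in [1, {r}]")
    return indices_shape[:-1] + data_shape[k:]


def check_scatter_nd_shapes(data_shape, indices_shape, updates_shape, output_shape):
    """Raise ValueError if the four argument shapes are inconsistent."""
    expected = expected_updates_shape(data_shape, indices_shape)
    if tuple(updates_shape) != expected:
        raise ValueError(f"updates shape {tuple(updates_shape)} != expected {expected}")
    if tuple(output_shape) != tuple(data_shape):
        raise ValueError("output must have the same shape as data")


print(expected_updates_shape((4, 4, 4), (2, 1)))  # (2, 4, 4): two 4x4 slices
print(expected_updates_shape((2, 3), (3, 2)))     # (3,): three single elements
```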