Mojo function
scatter_nd_generator
scatter_nd_generator[output_type: DType, indices_type: DType, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu", /, reduce_fn: OptionalReg[fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, *, _trace_description: StringSlice[StaticConstantOrigin] = "scatter_nd"](data: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], updates: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())
Implements the ONNX ScatterND operation, as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND.
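To make the ONNX semantics concrete, here is a minimal pure-Python sketch of ScatterND with no reduction. It uses nested lists as tensors. It is not the Mojo implementation and does not use `LayoutTensor`, and the helper names `shape_of`, `get`, and `scatter_nd` are hypothetical.

Each innermost row of `indices` is an index tuple of length k. When k equals the rank of `data`, the tuple selects one element. When k is smaller, it selects a whole slice, which is replaced by the matching entry of `updates`.

```python
import copy
import itertools


def shape_of(t):
    """Shape of a rectangular, non-empty nested list (a scalar has shape [])."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return dims


def get(t, idx):
    """Walk a nested list along a sequence of indices."""
    for i in idx:
        t = t[i]
    return t


def scatter_nd(data, indices, updates):
    """Illustrative ONNX ScatterND with no reduction (hypothetical helper)."""
    output = copy.deepcopy(data)  # data itself is left untouched
    idx_shape = shape_of(indices)
    # Every position in indices except the last axis holds one index tuple.
    for pos in itertools.product(*(range(d) for d in idx_shape[:-1])):
        target = get(indices, pos)          # e.g. [i0, i1, ..., i_{k-1}]
        parent = get(output, target[:-1])   # container of the selected element/slice
        # Overwrite one element (k == rank of data) or one slice (k < rank).
        parent[target[-1]] = copy.deepcopy(get(updates, pos))
    return output
```

For example, with `data = [1, 2, 3, 4, 5, 6, 7, 8]`, `indices = [[4], [3], [1], [7]]`, and `updates = [9, 10, 11, 12]`, the sketch returns `[1, 11, 3, 10, 9, 6, 7, 12]`.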
Parameters:
- output_type (DType): Type of the data, updates, and output tensors.
- indices_type (DType): Type of the indices tensor.
- single_thread_blocking_override (Bool): If True, the operation runs synchronously on a single thread.
- target (StringSlice): Target device, either "cpu" (the default) or "cuda".
- reduce_fn (OptionalReg): Optional reduction used to combine each update with the existing value: none (the default, which overwrites), add, mul, max, or min.
- _trace_description (StringSlice): A description of the function, used for profiling and tracing.
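The `reduce_fn` parameter controls how an update is combined with the value already in the output. The Python sketch below illustrates the ONNX reduction semantics; it is not the Mojo API. `REDUCTIONS`, `combine`, and `scatter_nd_reduce` are hypothetical names.

The sketch applies a binary function elementwise, so repeated indices accumulate instead of overwriting. In the Mojo signature, `reduce_fn` is itself a binary SIMD function rather than a name. The table here only labels the four ONNX options.

```python
import copy
import itertools
import operator

# Illustrative table mapping the ONNX reduction names to binary functions.
REDUCTIONS = {
    "add": operator.add,
    "mul": operator.mul,
    "max": max,
    "min": min,
}


def shape_of(t):
    """Shape of a rectangular, non-empty nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return dims


def get(t, idx):
    """Walk a nested list along a sequence of indices."""
    for i in idx:
        t = t[i]
    return t


def combine(old, new, fn):
    """Apply fn elementwise to two same-shaped nested lists (or two scalars)."""
    if isinstance(old, list):
        return [combine(o, n, fn) for o, n in zip(old, new)]
    return fn(old, new)


def scatter_nd_reduce(data, indices, updates, reduce_fn=None):
    """Illustrative ScatterND; reduce_fn=None means plain overwrite."""
    output = copy.deepcopy(data)
    idx_shape = shape_of(indices)
    # Updates are applied sequentially in index order in this sketch.
    for pos in itertools.product(*(range(d) for d in idx_shape[:-1])):
        target = get(indices, pos)
        parent = get(output, target[:-1])
        update = copy.deepcopy(get(updates, pos))
        if reduce_fn is None:
            parent[target[-1]] = update
        else:
            parent[target[-1]] = combine(parent[target[-1]], update, reduce_fn)
    return output
```

With `data = [1, 2, 3, 4]`, `indices = [[1], [1], [3]]`, and `updates = [10, 20, 5]`, passing `operator.add` gives `[1, 32, 3, 9]`. The sketch processes updates one at a time; it does not model the order in which a parallel kernel might apply them.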
Args:
- data (LayoutTensor): Input tensor of rank data_rank >= 1.
- indices (LayoutTensor): Tensor of rank indices_rank whose last dimension holds the index tuples for the scatter operation.
- updates (LayoutTensor): Tensor of values scattered into the output at the positions selected by indices.
- output (LayoutTensor): Output tensor of rank data_rank, with the same shape as data.
- context (DeviceContextPtr): Pointer to the DeviceContext.
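ONNX ties the argument shapes together. Let r = rank(data) and k = indices.shape[-1], where k is at most r. Then `updates` must have shape indices.shape[:-1] + data.shape[k:], and `output` has the shape of `data`.

The following Python helpers check these constraints on plain shape tuples. They are hypothetical and not part of the Mojo library.

```python
def expected_updates_shape(data_shape, indices_shape):
    """Shape that ONNX ScatterND requires for `updates` (hypothetical helper)."""
    r, q = len(data_shape), len(indices_shape)
    if r < 1 or q < 1:
        raise ValueError("data and indices must both have rank >= 1")
    k = indices_shape[-1]  # length of each index tuple
    if k > r:
        raise ValueError(f"indices.shape[-1] = {k} exceeds data rank {r}")
    # One update per index tuple, each shaped like the slice data[i0, ..., i_{k-1}].
    return tuple(indices_shape[:-1]) + tuple(data_shape[k:])


def check_scatter_nd_shapes(data_shape, indices_shape, updates_shape, output_shape):
    """Raise ValueError if the four shapes are inconsistent."""
    want = expected_updates_shape(data_shape, indices_shape)
    if tuple(updates_shape) != want:
        raise ValueError(f"updates shape {tuple(updates_shape)} != expected {want}")
    if tuple(output_shape) != tuple(data_shape):
        raise ValueError("output must have the same shape as data")
```

For instance, take `data` of shape (4, 4, 4) and `indices` of shape (2, 1). Each index tuple selects a 4 by 4 slice, so `updates` must have shape (2, 4, 4).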