
Mojo function

non_max_suppression

non_max_suppression[dtype: DType](boxes: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], scores: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], output: TileTensor[DType.int64, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], max_output_boxes_per_class: Int, iou_threshold: Float32, score_threshold: Float32)

Perform Non-Maximum Suppression (NMS) on bounding boxes.

This overload has buffer semantics: it writes its results directly into a caller-provided output tensor. NMS repeatedly selects the highest-scoring remaining box and suppresses other boxes whose overlap with it, measured as intersection over union (IoU), is too high.

Parameters:

  • dtype (DType): The data type for box coordinates and scores.

Args:

  • boxes (TileTensor): Rank-3 tensor of bounding boxes with shape (batch, num_boxes, 4). Each box is [y1, x1, y2, x2].
  • scores (TileTensor): Rank-3 tensor of scores with shape (batch, num_classes, num_boxes).
  • output (TileTensor): Rank-2 output tensor of shape (N, 3) that receives the selected boxes. Each row is [batch_idx, class_idx, box_idx].
  • max_output_boxes_per_class (Int): Maximum number of boxes to select per class.
  • iou_threshold (Float32): IoU threshold for suppression. A box is suppressed when its IoU with an already selected box is > threshold.
  • score_threshold (Float32): Minimum score threshold. Boxes with score < threshold are filtered out.
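The selection rule described by these arguments can be sketched in plain Python. This is a minimal reference sketch and not the Mojo implementation: the names `iou` and `nms_reference` are hypothetical, nested Python lists stand in for the tensors, and the result is returned as a list of tuples rather than written into an output buffer. It assumes boxes are given as [y1, x1, y2, x2] and tolerates corners in either order.

```python
from typing import List, Sequence, Tuple

Box = Sequence[float]  # [y1, x1, y2, x2]


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two [y1, x1, y2, x2] boxes."""
    # Normalise each box so the smaller coordinate comes first
    # (an assumption of this sketch, so corner order does not matter).
    ay1, ay2 = min(a[0], a[2]), max(a[0], a[2])
    ax1, ax2 = min(a[1], a[3]), max(a[1], a[3])
    by1, by2 = min(b[0], b[2]), max(b[0], b[2])
    bx1, bx2 = min(b[1], b[3]), max(b[1], b[3])

    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter = inter_h * inter_w
    union = (ay2 - ay1) * (ax2 - ax1) + (by2 - by1) * (bx2 - bx1) - inter
    return inter / union if union > 0 else 0.0


def nms_reference(
    boxes: Sequence[Sequence[Box]],               # (batch, num_boxes, 4)
    scores: Sequence[Sequence[Sequence[float]]],  # (batch, num_classes, num_boxes)
    max_output_boxes_per_class: int,
    iou_threshold: float,
    score_threshold: float,
) -> List[Tuple[int, int, int]]:
    """Return selected boxes as (batch_idx, class_idx, box_idx) rows."""
    selected: List[Tuple[int, int, int]] = []
    for b, batch_boxes in enumerate(boxes):
        for c, class_scores in enumerate(scores[b]):
            # Filter out boxes with score < threshold, then visit best-first.
            candidates = [i for i, s in enumerate(class_scores) if s >= score_threshold]
            candidates.sort(key=lambda i: class_scores[i], reverse=True)

            kept: List[int] = []
            for i in candidates:
                if len(kept) >= max_output_boxes_per_class:
                    break
                # Keep the box only if it does not overlap any kept box too much.
                if all(iou(batch_boxes[i], batch_boxes[k]) <= iou_threshold for k in kept):
                    kept.append(i)
            selected.extend((b, c, i) for i in kept)
    return selected


# One batch, one class, three boxes. Boxes 0 and 1 overlap with IoU = 81/119.
example_boxes = [[[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]]
example_scores = [[[0.9, 0.8, 0.7]]]
print(nms_reference(example_boxes, example_scores, 10, 0.5, 0.0))
# [(0, 0, 0), (0, 0, 2)]
```

In the example, box 1 is suppressed because its IoU with the higher-scoring box 0 (about 0.68) exceeds the 0.5 threshold, while box 2 does not overlap either and is kept.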

non_max_suppression[dtype: DType, func: fn(Int64, Int64, Int64) capturing -> None](boxes: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], scores: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], max_output_boxes_per_class: Int, iou_threshold: Float32, score_threshold: Float32)

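Implements the NonMaxSuppression operator from the ONNX specification (https://github.com/onnx/onnx/blob/main/docs/Operators.md#nonmaxsuppression). This overload takes no output tensor; it has a func parameter of type fn(Int64, Int64, Int64) instead.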
