Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for IOS Matching Metric. Introduced the mask_non_max_merge function for handling non-maximum merging of masks #1774

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from supervision.detection.overlap_filter import (
box_non_max_merge,
box_non_max_suppression,
mask_non_max_merge,
mask_non_max_suppression,
)
from supervision.detection.tools.transformers import (
Expand All @@ -32,6 +33,7 @@
get_data_item,
is_data_equal,
is_metadata_equal,
mask_iou_batch,
mask_to_xyxy,
merge_data,
merge_metadata,
Expand Down Expand Up @@ -1281,7 +1283,10 @@ def box_area(self) -> np.ndarray:
return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])

def with_nms(
self, threshold: float = 0.5, class_agnostic: bool = False
self,
threshold: float = 0.5,
class_agnostic: bool = False,
match_metric: str = "IOU",
) -> Detections:
"""
Performs non-max suppression on detection set. If the detections result
Expand All @@ -1294,6 +1299,8 @@ def with_nms(
class_agnostic (bool): Whether to perform class-agnostic
non-maximum suppression. If True, the class_id of each detection
will be ignored. Defaults to False.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
Detections: A new Detections object containing the subset of detections
Expand Down Expand Up @@ -1327,17 +1334,25 @@ def with_nms(

if self.mask is not None:
indices = mask_non_max_suppression(
predictions=predictions, masks=self.mask, iou_threshold=threshold
predictions=predictions,
masks=self.mask,
iou_threshold=threshold,
match_metric=match_metric,
)
else:
indices = box_non_max_suppression(
predictions=predictions, iou_threshold=threshold
predictions=predictions,
iou_threshold=threshold,
match_metric=match_metric,
)

return self[indices]

def with_nmm(
self, threshold: float = 0.5, class_agnostic: bool = False
self,
threshold: float = 0.5,
class_agnostic: bool = False,
match_metric: str = "IOU",
) -> Detections:
"""
Perform non-maximum merging on the current set of object detections.
Expand All @@ -1348,6 +1363,8 @@ def with_nmm(
class_agnostic (bool): Whether to perform class-agnostic
non-maximum merging. If True, the class_id of each detection
will be ignored. Defaults to False.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
Detections: A new Detections object containing the subset of detections
Expand Down Expand Up @@ -1381,15 +1398,25 @@ def with_nmm(
)
)

merge_groups = box_non_max_merge(
predictions=predictions, iou_threshold=threshold
)
if self.mask is not None:
merge_groups = mask_non_max_merge(
predictions=predictions,
masks=self.mask,
iou_threshold=threshold,
match_metric=match_metric,
)
else:
merge_groups = box_non_max_merge(
predictions=predictions,
iou_threshold=threshold,
match_metric=match_metric,
)

result = []
for merge_group in merge_groups:
unmerged_detections = [self[i] for i in merge_group]
merged_detections = merge_inner_detections_objects(
unmerged_detections, threshold
unmerged_detections, threshold, match_metric
)
result.append(merged_detections)

Expand Down Expand Up @@ -1489,7 +1516,7 @@ def merge_inner_detection_object_pair(


def merge_inner_detections_objects(
detections: List[Detections], threshold=0.5
detections: List[Detections], threshold=0.5, match_metric: str = "IOU"
) -> Detections:
"""
Given N detections each of length 1 (exactly one object inside), combine them into a
Expand All @@ -1501,8 +1528,11 @@ def merge_inner_detections_objects(
"""
detections_1 = detections[0]
for detections_2 in detections[1:]:
box_iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy)[0]
if box_iou < threshold:
if detections_1.mask is not None and detections_2.mask is not None:
iou = mask_iou_batch(detections_1.mask, detections_2.mask, match_metric)[0]
else:
iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy, match_metric)[0]
if iou < threshold:
break
detections_1 = merge_inner_detection_object_pair(detections_1, detections_2)
return detections_1
Expand Down
137 changes: 130 additions & 7 deletions supervision/detection/overlap_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def mask_non_max_suppression(
predictions: np.ndarray,
masks: np.ndarray,
iou_threshold: float = 0.5,
match_metric: str = "IOU",
mask_dimension: int = 640,
) -> np.ndarray:
"""
Expand All @@ -57,6 +58,8 @@ def mask_non_max_suppression(
dimensions of each mask.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".
mask_dimension (int): The dimension to which the masks should be
resized before computing IOU values. Defaults to 640.

Expand All @@ -81,7 +84,7 @@ def mask_non_max_suppression(
predictions = predictions[sort_index]
masks = masks[sort_index]
masks_resized = resize_masks(masks, mask_dimension)
ious = mask_iou_batch(masks_resized, masks_resized)
ious = mask_iou_batch(masks_resized, masks_resized, match_metric)
categories = predictions[:, 5]

keep = np.ones(rows, dtype=bool)
Expand All @@ -94,7 +97,7 @@ def mask_non_max_suppression(


def box_non_max_suppression(
predictions: np.ndarray, iou_threshold: float = 0.5
predictions: np.ndarray, iou_threshold: float = 0.5, match_metric: str = "IOU"
) -> np.ndarray:
"""
Perform Non-Maximum Suppression (NMS) on object detection predictions.
Expand All @@ -105,6 +108,8 @@ def box_non_max_suppression(
or `(x_min, y_min, x_max, y_max, score, class)`.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
np.ndarray: A boolean array indicating which predictions to keep after n
Expand All @@ -130,7 +135,7 @@ def box_non_max_suppression(

boxes = predictions[:, :4]
categories = predictions[:, 5]
ious = box_iou_batch(boxes, boxes)
ious = box_iou_batch(boxes, boxes, match_metric)
ious = ious - np.eye(rows)

keep = np.ones(rows, dtype=bool)
Expand All @@ -148,7 +153,9 @@ def box_non_max_suppression(


def group_overlapping_boxes(
predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
predictions: npt.NDArray[np.float64],
iou_threshold: float = 0.5,
match_metric: str = "IOU",
) -> List[List[int]]:
"""
Apply greedy version of non-maximum merging to avoid detecting too many
Expand All @@ -160,6 +167,8 @@ def group_overlapping_boxes(
and the confidence scores.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression. Defaults to 0.5.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
List[List[int]]: Groups of prediction indices be merged.
Expand All @@ -179,7 +188,9 @@ def group_overlapping_boxes(
break

merge_candidate = np.expand_dims(predictions[idx], axis=0)
ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
ious = box_iou_batch(
predictions[order][:, :4], merge_candidate[:, :4], match_metric
)
ious = ious.flatten()

above_threshold = ious >= iou_threshold
Expand All @@ -189,9 +200,72 @@ def group_overlapping_boxes(
return merge_groups


def mask_non_max_merge(
predictions: np.ndarray,
masks: np.ndarray,
iou_threshold: float = 0.5,
mask_dimension: int = 640,
match_metric: str = "IOU",
) -> np.ndarray:
"""
Perform Non-Maximum Merging (NMM) on segmentation predictions.

Args:
predictions (np.ndarray): A 2D array of object detection predictions in
the format of `(x_min, y_min, x_max, y_max, score)`
or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
`(N, 6)`, where N is the number of predictions.
masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
dimensions of each mask.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
mask_dimension (int): The dimension to which the masks should be
resized before computing IOU values. Defaults to 640.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
np.ndarray: A boolean array indicating which predictions to keep after
non-maximum suppression.

Raises:
AssertionError: If `iou_threshold` is not within the closed
range from `0` to `1`.
"""
masks_resized = resize_masks(masks, mask_dimension)
if predictions.shape[1] == 5:
return group_overlapping_masks(
predictions, masks_resized, iou_threshold, match_metric
)

category_ids = predictions[:, 5]
merge_groups = []
for category_id in np.unique(category_ids):
curr_indices = np.where(category_ids == category_id)[0]
merge_class_groups = group_overlapping_masks(
predictions[curr_indices],
masks_resized[curr_indices],
iou_threshold,
match_metric,
)

for merge_class_group in merge_class_groups:
merge_groups.append(curr_indices[merge_class_group].tolist())

for merge_group in merge_groups:
if len(merge_group) == 0:
raise ValueError(
f"Empty group detected when non-max-merging "
f"detections: {merge_groups}"
)
return merge_groups


def box_non_max_merge(
predictions: npt.NDArray[np.float64],
iou_threshold: float = 0.5,
match_metric: str = "IOU",
) -> List[List[int]]:
"""
Apply greedy version of non-maximum merging per category to avoid detecting
Expand All @@ -204,20 +278,22 @@ def box_non_max_merge(
detections of different classes to be merged.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression. Defaults to 0.5.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
List[List[int]]: Groups of prediction indices be merged.
Each group may have 1 or more elements.
"""
if predictions.shape[1] == 5:
return group_overlapping_boxes(predictions, iou_threshold)
return group_overlapping_boxes(predictions, iou_threshold, match_metric)

category_ids = predictions[:, 5]
merge_groups = []
for category_id in np.unique(category_ids):
curr_indices = np.where(category_ids == category_id)[0]
merge_class_groups = group_overlapping_boxes(
predictions[curr_indices], iou_threshold
predictions[curr_indices], iou_threshold, match_metric
)

for merge_class_group in merge_class_groups:
Expand All @@ -232,6 +308,53 @@ def box_non_max_merge(
return merge_groups


def group_overlapping_masks(
predictions: npt.NDArray[np.float64],
masks: npt.NDArray[np.float64],
iou_threshold: float = 0.5,
match_metric: str = "IOU",
) -> List[List[int]]:
"""
Apply greedy version of non-maximum merging to avoid detecting too many

Args:
predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing
the bounding boxes coordinates in format `[x1, y1, x2, y2]`
and the confidence scores.
masks (npt.NDArray[np.float64]): A 3D array of binary masks corresponding to the predictions.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression. Defaults to 0.5.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
List[List[int]]: Groups of prediction indices be merged.
Each group may have 1 or more elements.
"""
merge_groups: List[List[int]] = []

scores = predictions[:, 4]
order = scores.argsort()

while len(order) > 0:
idx = int(order[-1])

order = order[:-1]
if len(order) == 0:
merge_groups.append([idx])
break

merge_candidate = np.expand_dims(masks[idx], axis=0)
ious = mask_iou_batch(masks[order], merge_candidate, match_metric)
ious = ious.flatten()

above_threshold = ious >= iou_threshold
merge_group = [idx, *np.flip(order[above_threshold]).tolist()]
merge_groups.append(merge_group)
order = order[~above_threshold]
return merge_groups


class OverlapFilter(Enum):
"""
Enum specifying the strategy for filtering overlapping detections.
Expand Down
Loading