Add mAR docstrings & examples
LinasKo committed Nov 7, 2024
1 parent 83d4386 commit 0bfcdea
Showing 6 changed files with 208 additions and 6 deletions.
1 change: 0 additions & 1 deletion docs/metrics/f1_score.md
@@ -1,6 +1,5 @@
---
comments: true
status: new
---

# F1 Score
1 change: 0 additions & 1 deletion docs/metrics/mean_average_precision.md
@@ -1,6 +1,5 @@
---
comments: true
status: new
---

# Mean Average Precision
18 changes: 18 additions & 0 deletions docs/metrics/mean_average_recall.md
@@ -0,0 +1,18 @@
---
comments: true
status: new
---

# Mean Average Recall

<div class="md-typeset">
<h2><a href="#supervision.metrics.mean_average_recall.MeanAverageRecall">MeanAverageRecall</a></h2>
</div>

:::supervision.metrics.mean_average_recall.MeanAverageRecall

<div class="md-typeset">
<h2><a href="#supervision.metrics.mean_average_recall.MeanAverageRecallResult">MeanAverageRecallResult</a></h2>
</div>

:::supervision.metrics.mean_average_recall.MeanAverageRecallResult
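
A minimal usage sketch for the metric documented on this page. The `sv.Detections(...)` placeholders stand in for real predictions and ground truth, the printed value is only illustrative, and the `MetricTarget` import is assumed to mirror the other metric pages rather than being introduced by this commit:

```python
import supervision as sv
from supervision.metrics import MeanAverageRecall, MetricTarget

# Accumulate predictions and ground truth, then compute the result once.
predictions = sv.Detections(...)  # placeholder
targets = sv.Detections(...)      # placeholder

mar_metric = MeanAverageRecall(metric_target=MetricTarget.BOXES)
mar_result = mar_metric.update(predictions, targets).compute()

print(mar_result.mAR_at_100)  # e.g. 0.5241
mar_result.plot()             # bar chart of mAR @ 1 / 10 / 100, split by object size
```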
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -66,6 +66,7 @@ nav:
- Utils: datasets/utils.md
- Metrics:
- mAP: metrics/mean_average_precision.md
- mAR: metrics/mean_average_recall.md
- Precision: metrics/precision.md
- Recall: metrics/recall.md
- F1 Score: metrics/f1_score.md
189 changes: 187 additions & 2 deletions supervision/metrics/mean_average_recall.py
@@ -27,10 +27,64 @@


class MeanAverageRecall(Metric):
"""
Mean Average Recall (mAR) measures how well the model detects
and retrieves relevant objects by averaging recall over multiple
IoU thresholds, classes, and detection limits.

Intuitively, while Recall measures the ability to find all relevant
objects, mAR narrows down how many detections are considered for each
class. For example, mAR @ 100 considers the top 100 highest confidence
detections for each class. mAR @ 1 considers only the highest
confidence detection for each class.

Example:
```python
import supervision as sv
from supervision.metrics import MeanAverageRecall
predictions = sv.Detections(...)
targets = sv.Detections(...)
mar_metric = MeanAverageRecall()
mar_result = mar_metric.update(predictions, targets).compute()
print(mar_result.mAR_at_100)
# 0.5241
print(mar_result)
# MeanAverageRecallResult:
# Metric target: MetricTarget.BOXES
# mAR @ 1: 0.1362
# mAR @ 10: 0.4239
# mAR @ 100: 0.5241
# max detections: [1 10 100]
# IoU thresh: [0.5 0.55 0.6 ...]
# mAR per class:
# 0: [0.78571 0.78571 0.78571 ...]
# ...
# Small objects: ...
# Medium objects: ...
# Large objects: ...
mar_result.plot()
```
![example_plot](\
https://media.roboflow.com/supervision-docs/metrics/mAR_plot_example.png\
){ align=center width="800" }
"""

def __init__(
self,
metric_target: MetricTarget = MetricTarget.BOXES,
):
"""
Initialize the Mean Average Recall metric.
Args:
metric_target (MetricTarget): The type of detection data to use.
"""
self._metric_target = metric_target

self._predictions_list: List[Detections] = []
@@ -39,6 +93,9 @@ def __init__(
self.max_detections = np.array([1, 10, 100])

def reset(self) -> None:
"""
Reset the metric to its initial state, clearing all stored data.
"""
self._predictions_list = []
self._targets_list = []

@@ -47,6 +104,16 @@ def update(
predictions: Union[Detections, List[Detections]],
targets: Union[Detections, List[Detections]],
) -> MeanAverageRecall:
"""
Add new predictions and targets to the metric, but do not compute the result.
Args:
predictions (Union[Detections, List[Detections]]): The predicted detections.
targets (Union[Detections, List[Detections]]): The target detections.
Returns:
(MeanAverageRecall): The updated metric instance.
"""
if not isinstance(predictions, list):
predictions = [predictions]
if not isinstance(targets, list):
@@ -64,6 +131,13 @@ def update(
return self

def compute(self) -> MeanAverageRecallResult:
"""
Calculate the Mean Average Recall metric based on the stored predictions
and ground-truth, at different IoU thresholds and maximum detection counts.
Returns:
(MeanAverageRecallResult): The Mean Average Recall metric result.
"""
result = self._compute(self._predictions_list, self._targets_list)

small_predictions, small_targets = self._filter_predictions_and_targets_by_size(
@@ -237,6 +311,29 @@ def _compute_confusion_matrix(
class_counts: np.ndarray,
max_detections: Optional[int] = None,
) -> np.ndarray:
"""
Compute the confusion matrix for each class and IoU threshold.
Assumes the matches and prediction_class_ids are sorted by confidence
in descending order.
Args:
sorted_matches: np.ndarray, bool, shape (P, Th). True where the
prediction is a true positive at the given IoU threshold.
sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing
the class id for each prediction.
unique_classes: np.ndarray, int, shape (C,), containing the unique
class ids.
class_counts: np.ndarray, int, shape (C,), containing the number
of true instances for each class.
max_detections: Optional[int], the maximum number of detections to
consider for each class. Extra detections are considered false
positives. By default, all detections are considered.
Returns:
np.ndarray, shape (C, Th, 3), containing the true positives, false
positives, and false negatives for each class and IoU threshold.
"""
num_thresholds = sorted_matches.shape[1]
num_classes = unique_classes.shape[0]

@@ -364,6 +461,61 @@ def _filter_predictions_and_targets_by_size(

@dataclass
class MeanAverageRecallResult:
# """
# The results of the recall metric calculation.

# Defaults to `0` if no detections or targets were provided.

# Attributes:
# metric_target (MetricTarget): the type of data used for the metric -
# boxes, masks or oriented bounding boxes.
# averaging_method (AveragingMethod): the averaging method used to compute the
# recall. Determines how the recall is aggregated across classes.
# recall_at_50 (float): the recall at IoU threshold of `0.5`.
# recall_at_75 (float): the recall at IoU threshold of `0.75`.
# recall_scores (np.ndarray): the recall scores at each IoU threshold.
# Shape: `(num_iou_thresholds,)`
# recall_per_class (np.ndarray): the recall scores per class and IoU threshold.
# Shape: `(num_target_classes, num_iou_thresholds)`
# iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
# matched_classes (np.ndarray): the class IDs of all matched classes.
# Corresponds to the rows of `recall_per_class`.
# small_objects (Optional[RecallResult]): the Recall metric results
# for small objects.
# medium_objects (Optional[RecallResult]): the Recall metric results
# for medium objects.
# large_objects (Optional[RecallResult]): the Recall metric results
# for large objects.
# """
"""
The results of the Mean Average Recall metric calculation.

Defaults to `0` if no detections or targets were provided.

Attributes:
metric_target (MetricTarget): the type of data used for the metric -
boxes, masks or oriented bounding boxes.
mAR_at_1 (float): the Mean Average Recall when considering only the single
highest confidence detection for each class.
mAR_at_10 (float): the Mean Average Recall when considering the top 10
highest confidence detections for each class.
mAR_at_100 (float): the Mean Average Recall when considering the top 100
highest confidence detections for each class.
recall_per_class (np.ndarray): the recall scores per class and IoU threshold.
Shape: `(num_target_classes, num_iou_thresholds)`
max_detections (np.ndarray): the array with maximum number of detections
considered.
iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
matched_classes (np.ndarray): the class IDs of all matched classes.
Corresponds to the rows of `recall_per_class`.
small_objects (Optional[MeanAverageRecallResult]): the Mean Average Recall
metric results for small objects (area < 32²).
medium_objects (Optional[MeanAverageRecallResult]): the Mean Average Recall
metric results for medium objects (32² ≤ area < 96²).
large_objects (Optional[MeanAverageRecallResult]): the Mean Average Recall
metric results for large objects (area ≥ 96²).
"""

metric_target: MetricTarget

@property
@@ -389,9 +541,29 @@ def mAR_at_100(self) -> float:
large_objects: Optional[MeanAverageRecallResult]

def __str__(self) -> str:
"""
Format as a pretty string.
Example:
```python
# MeanAverageRecallResult:
# Metric target: MetricTarget.BOXES
# mAR @ 1: 0.1362
# mAR @ 10: 0.4239
# mAR @ 100: 0.5241
# max detections: [1 10 100]
# IoU thresh: [0.5 0.55 0.6 ...]
# mAR per class:
# 0: [0.78571 0.78571 0.78571 ...]
# ...
# Small objects: ...
# Medium objects: ...
# Large objects: ...
```
"""
out_str = (
f"{self.__class__.__name__}:\n"
f"Metric target: {self.metric_target}\n"
f"Metric target: {self.metric_target}\n"
f"mAR @ 1: {self.mAR_at_1:.4f}\n"
f"mAR @ 10: {self.mAR_at_10:.4f}\n"
f"mAR @ 100: {self.mAR_at_100:.4f}\n"
@@ -420,6 +592,12 @@ def __str__(self) -> str:
return out_str

def to_pandas(self) -> "pd.DataFrame":
"""
Convert the result to a pandas DataFrame.
Returns:
(pd.DataFrame): The result as a DataFrame.
"""
ensure_pandas_installed()
import pandas as pd

Expand All @@ -445,6 +623,13 @@ def to_pandas(self) -> "pd.DataFrame":
return pd.DataFrame(pandas_data, index=[0])

def plot(self):
"""
Plot the Mean Average Recall results.
![example_plot](\
https://media.roboflow.com/supervision-docs/metrics/mAR_plot_example.png\
){ align=center width="800" }
"""
labels = ["mAR @ 1", "mAR @ 10", "mAR @ 100"]
values = [self.mAR_at_1, self.mAR_at_10, self.mAR_at_100]
colors = [LEGACY_COLOR_PALETTE[0]] * 3
@@ -486,7 +671,7 @@ def plot(self):
ax.set_ylabel("Value", fontweight="bold")
title = (
f"Mean Average Recall, by Object Size"
f"\n(target: {self.metric_target.value}"
f"\n(target: {self.metric_target.value})"
)
ax.set_title(title, fontweight="bold")

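To make the shapes documented in `_compute_confusion_matrix` and the averaging described in the class docstring concrete, here is a small NumPy sketch. It is illustrative only: the counts are made up and `confusion_matrix` stands in for the `(C, Th, 3)` array the helper returns, so none of this is library code.

```python
import numpy as np

# Two classes (C=2), three IoU thresholds (Th=3).
# confusion_matrix[c, t] = (TP, FP, FN) for class c at IoU threshold t,
# mirroring the (C, Th, 3) shape described in _compute_confusion_matrix.
confusion_matrix = np.array(
    [
        [[8, 2, 2], [7, 3, 3], [5, 5, 5]],  # class 0
        [[4, 1, 1], [3, 2, 2], [2, 3, 3]],  # class 1
    ],
    dtype=float,
)

tp = confusion_matrix[..., 0]
fn = confusion_matrix[..., 2]

# Recall per class and IoU threshold, shape (C, Th).
recall_per_class = tp / np.maximum(tp + fn, 1)  # guard against empty classes

# For a fixed detection limit (e.g. the top 100 detections per class),
# mAR is the mean recall across classes and IoU thresholds.
mar_at_limit = recall_per_class.mean()
print(recall_per_class)  # [[0.8 0.7 0.5], [0.8 0.6 0.4]]
print(mar_at_limit)      # ~0.6333
```

The metric itself repeats this aggregation for each value in `max_detections` ([1, 10, 100]) and again for the small, medium, and large object subsets reported in `MeanAverageRecallResult`.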
4 changes: 2 additions & 2 deletions supervision/metrics/recall.py
@@ -115,11 +115,11 @@ def update(

def compute(self) -> RecallResult:
"""
-Calculate the precision metric based on the stored predictions and ground-truth
+Calculate the recall metric based on the stored predictions and ground-truth
data, at different IoU thresholds.
Returns:
-(RecallResult): The precision metric result.
+(RecallResult): The recall metric result.
"""
result = self._compute(self._predictions_list, self._targets_list)

