diff --git a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py index 38ddc7abc8..34c14687ae 100644 --- a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py +++ b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py @@ -10,7 +10,6 @@ import pickle import shutil import warnings -from collections import defaultdict from enum import Enum from pathlib import Path from typing import Any, Optional @@ -22,8 +21,6 @@ from ml_wrappers import wrap_model from ml_wrappers.common.constants import Device from torchmetrics.detection.mean_ap import MeanAveragePrecision -from vision_explanation_methods.error_labeling.error_labeling import ( - ErrorLabeling, ErrorLabelType) from erroranalysis._internal.cohort_filter import FilterDataWithCohortFilters from raiutils.data_processing import convert_to_list @@ -47,7 +44,8 @@ from responsibleai_vision.utils.image_reader import ( get_base64_string_from_path, get_image_from_path, is_automl_image_model) from responsibleai_vision.utils.image_utils import ( - convert_images, get_images, transform_object_detection_labels) + convert_images, generate_od_error_labels, get_images, + transform_object_detection_labels) IMAGE = ImageColumns.IMAGE.value IMAGE_URL = ImageColumns.IMAGE_URL.value @@ -85,10 +83,6 @@ _TIME_SERIES_ID_FEATURES = 'time_series_id_features' _CATEGORICAL_FEATURES = 'categorical_features' _DROPPED_FEATURES = 'dropped_features' -_INCORRECT = 'incorrect' -_CORRECT = 'correct' -_AGGREGATE_LABEL = 'aggregate' -_NOLABEL = '(none)' def reshape_image(image): @@ -701,7 +695,7 @@ def _get_dataset(self): ) dashboard_dataset.object_detection_labels = \ - self._generate_od_error_labels( + generate_od_error_labels( dashboard_dataset.object_detection_true_y, dashboard_dataset.object_detection_predicted_y, class_names=dashboard_dataset.class_names @@ -709,68 +703,6 @@ def _get_dataset(self): return dashboard_dataset - def _generate_od_error_labels(self, true_y, pred_y, class_names): - """Utilized Error Labeling to generate labels - with correct and incorrect objects. - - :param true_y: The true labels. - :type true_y: list - :param pred_y: The predicted labels. - :type pred_y: list - :param class_names: The class labels in the dataset. - :type class_names: list - :return: The aggregated labels. - :rtype: List[str] - """ - object_detection_labels = [] - for image_idx in range(len(true_y)): - image_labels = defaultdict(lambda: defaultdict(int)) - rendered_labels = {} - error_matrix = ErrorLabeling( - ModelTask.OBJECT_DETECTION, - pred_y[image_idx], - true_y[image_idx] - ).compute_error_labels() - - for label_idx in range(len(error_matrix)): - object_label = class_names[ - int(true_y[image_idx][label_idx][0] - 1)] - if ErrorLabelType.MATCH in error_matrix[label_idx]: - image_labels[_CORRECT][object_label] += 1 - else: - image_labels[_INCORRECT][object_label] += 1 - - duplicate_detections = np.count_nonzero( - error_matrix[label_idx] == - ErrorLabelType.DUPLICATE_DETECTION) - if duplicate_detections > 0: - image_labels[_INCORRECT][object_label] += \ - duplicate_detections - - correct_labels = sorted(image_labels[_CORRECT].items(), - key=lambda x: class_names.index(x[0])) - incorrect_labels = sorted(image_labels[_INCORRECT].items(), - key=lambda x: class_names.index(x[0])) - - rendered_labels[_CORRECT] = ', '.join( - f'{value} {key}' for key, value in - correct_labels) - if len(rendered_labels[_CORRECT]) == 0: - rendered_labels[_CORRECT] = _NOLABEL - rendered_labels[_INCORRECT] = ', '.join( - f'{value} {key}' for key, value in - incorrect_labels) - if len(rendered_labels[_INCORRECT]) == 0: - rendered_labels[_INCORRECT] = _NOLABEL - rendered_labels[_AGGREGATE_LABEL] = \ - f"{sum(image_labels[_CORRECT].values())} {_CORRECT}, \ - {sum(image_labels[_INCORRECT].values())} \ - {_INCORRECT}" - - object_detection_labels.append(rendered_labels) - - return object_detection_labels - def _format_od_labels(self, y, class_names): """Formats the Object Detection label representation to multi-label image classification to follow the UI format diff --git a/responsibleai_vision/responsibleai_vision/utils/image_utils.py b/responsibleai_vision/responsibleai_vision/utils/image_utils.py index f5904ceec2..dd9c23f768 100644 --- a/responsibleai_vision/responsibleai_vision/utils/image_utils.py +++ b/responsibleai_vision/responsibleai_vision/utils/image_utils.py @@ -3,9 +3,13 @@ """Contains image handling utilities.""" +from collections import defaultdict + import numpy as np +from vision_explanation_methods.error_labeling.error_labeling import ( + ErrorLabeling, ErrorLabelType) -from responsibleai_vision.common.constants import ImageColumns +from responsibleai_vision.common.constants import ImageColumns, ModelTask from responsibleai_vision.utils.image_reader import get_image_from_path IMAGE = ImageColumns.IMAGE.value @@ -19,6 +23,10 @@ BOTTOM_X = 'bottomX' BOTTOM_Y = 'bottomY' IS_CROWD = 'isCrowd' +_INCORRECT = 'incorrect' +_CORRECT = 'correct' +_AGGREGATE_LABEL = 'aggregate' +_NOLABEL = '(none)' def convert_images(dataset, image_mode): @@ -141,3 +149,62 @@ def transform_object_detection_labels(test, target_column, classes): err = invalid_msg + 'Image details and label must be present' raise ValueError(err) return test + + +def generate_od_error_labels(true_y, pred_y, class_names): + """Utilized Error Labeling to generate labels + with correct and incorrect objects. + + :param true_y: The true labels. + :type true_y: list + :param pred_y: The predicted labels. + :type pred_y: list + :param class_names: The class labels in the dataset. + :type class_names: list + :return: The aggregated labels. + :rtype: List[str] + """ + object_detection_labels = [] + for image_idx in range(len(true_y)): + image_labels = defaultdict(lambda: defaultdict(int)) + rendered_labels = {} + error_matrix = ErrorLabeling( + ModelTask.OBJECT_DETECTION, + pred_y[image_idx], + true_y[image_idx] + ).compute_error_labels() + for label_idx in range(len(error_matrix)): + object_label = class_names[ + int(true_y[image_idx][label_idx][0] - 1)] + if ErrorLabelType.MATCH in error_matrix[label_idx]: + image_labels[_CORRECT][object_label] += 1 + else: + image_labels[_INCORRECT][object_label] += 1 + + duplicate_detections = np.count_nonzero( + error_matrix[label_idx] == + ErrorLabelType.DUPLICATE_DETECTION) + if duplicate_detections > 0: + image_labels[_INCORRECT][object_label] += \ + duplicate_detections + correct_labels = sorted(image_labels[_CORRECT].items(), + key=lambda x: class_names.index(x[0])) + incorrect_labels = sorted(image_labels[_INCORRECT].items(), + key=lambda x: class_names.index(x[0])) + rendered_labels[_CORRECT] = ', '.join( + f'{value} {key}' for key, value in + correct_labels) + if len(rendered_labels[_CORRECT]) == 0: + rendered_labels[_CORRECT] = _NOLABEL + rendered_labels[_INCORRECT] = ', '.join( + f'{value} {key}' for key, value in + incorrect_labels) + if len(rendered_labels[_INCORRECT]) == 0: + rendered_labels[_INCORRECT] = _NOLABEL + num_correct = sum(image_labels[_CORRECT].values()) + num_incorrect = sum(image_labels[_INCORRECT].values()) + agg_label = f"{num_correct} {_CORRECT}, {num_incorrect} {_INCORRECT}" + rendered_labels[_AGGREGATE_LABEL] = agg_label + object_detection_labels.append(rendered_labels) + + return object_detection_labels diff --git a/responsibleai_vision/tests/test_image_utils.py b/responsibleai_vision/tests/test_image_utils.py index 7e749d6935..db137c8e26 100644 --- a/responsibleai_vision/tests/test_image_utils.py +++ b/responsibleai_vision/tests/test_image_utils.py @@ -17,7 +17,8 @@ _requests_sessions as image_reader_requests_sessions from responsibleai_vision.utils.image_reader import get_all_exif_feature_names from responsibleai_vision.utils.image_utils import ( - BOTTOM_X, BOTTOM_Y, HEIGHT, IS_CROWD, TOP_X, TOP_Y, WIDTH, classes_to_dict, + _NOLABEL, BOTTOM_X, BOTTOM_Y, HEIGHT, IS_CROWD, TOP_X, TOP_Y, WIDTH, + classes_to_dict, generate_od_error_labels, transform_object_detection_labels) LABEL = ImageColumns.LABEL.value @@ -99,3 +100,35 @@ def test_get_all_exif_feature_names(self): set(['Orientation', 'ExifOffset', 'ImageWidth', 'GPSInfo', 'Model', 'DateTime', 'YCbCrPositioning', 'ImageLength', 'ResolutionUnit', 'Software', 'Make']) + + def test_generate_od_error_labels(self): + true_y = np.array([[[3, 142, 257, 395, 463, 0]], + [[3, 107, 272, 240, 501, 0], + [1, 261, 274, 393, 449, 0]], + [[4, 139, 253, 339, 506, 0]], + [[2, 100, 173, 233, 521, 0]], + [[1, 175, 253, 355, 416, 0]], + [[2, 86, 102, 216, 439, 0], + [3, 150, 377, 445, 490, 0]], + [[3, 103, 272, 358, 475, 0]], + [[4, 65, 289, 436, 414, 0]], + [[1, 130, 271, 367, 467, 0]], + [[1, 144, 260, 318, 429, 0]]]) + pred_y = np.array([[[3, 140, 260, 396, 469, 0]], + [[3, 108, 270, 237, 505, 0], + [1, 259, 271, 401, 450, 0]], + [[4, 131, 250, 330, 485, 0]], + [[2, 97, 170, 241, 516, 0]], + [[1, 175, 250, 354, 414, 0]], + [[2, 83, 98, 222, 445, 0], + [3, 130, 366, 438, 496, 0]], + [[3, 104, 265, 360, 468, 0]], + [[4, 58, 284, 483, 420, 0]], + [[1, 128, 265, 367, 471, 0]], + [[1, 137, 260, 325, 430, 0]]]) + class_names = ["can", "carton", "milk_bottle", "water_bottle"] + error_labels = generate_od_error_labels(true_y, pred_y, class_names) + assert len(error_labels) == 10 + assert error_labels[0]['aggregate'] == "1 correct, 0 incorrect" + assert error_labels[0]['correct'] == "1 milk_bottle" + assert error_labels[0]['incorrect'] == _NOLABEL