Skip to content

Commit

Permalink
Merge pull request #54 from Giskard-AI/fix-perturbation-report
Browse files Browse the repository at this point in the history
Fix perturbation report
  • Loading branch information
rabah-khalek authored Aug 13, 2024
2 parents b9ab6b8 + f8b216c commit 0af2c9b
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
22 changes: 15 additions & 7 deletions giskard_vision/core/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ class ScanResult:
relative_delta: float
issue_group: Optional[IssueGroup] = None

def get_meta_required(self) -> dict:
# Get the meta required by the original scan API
def get_meta_required(self, include_slice_size=True) -> dict:
deviation = f"{self.relative_delta * 100:+.2f}% than global"
return {

meta_info = {
"metric": self.metric_name,
"metric_value": self.metric_value,
"metric_reference_value": self.metric_reference_value,
"deviation": deviation,
"slice_size": self.slice_size,
}

if include_slice_size:
meta_info["slice_size"] = self.slice_size

return meta_info


class DetectorVisionBase(DetectorSpecsBase):
"""
Expand Down Expand Up @@ -115,16 +119,20 @@ def get_issues(

for result in results:
if result.issue_level in issue_levels:
extra_args = {
"slicing_fn" if self.slicing else "transformation_fn": result.name,
"meta": result.get_meta_required(self.slicing),
}

issues.append(
Issue(
model,
dataset,
level=result.issue_level,
slicing_fn=result.name,
group=result.issue_group if result.issue_group else self.issue_group,
meta=result.get_meta_required(),
group=result.issue_group or self.issue_group,
scan_examples=ImagesScanExamples(result.filename_examples, embed=embed),
display_footer_info=False,
**extra_args,
)
)

Expand Down
9 changes: 4 additions & 5 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import cv2

from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import Robustness
from giskard_vision.core.tests.base import TestDiffBase
Expand All @@ -29,6 +28,7 @@ class PerturbationBaseDetector(DetectorVisionBase):
"""

issue_group = Robustness
slicing = False

def set_specs_from_model_type(self, model_type):
module = import_module(f"giskard_vision.{model_type}.detectors.specs")
Expand Down Expand Up @@ -64,10 +64,9 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:

index_worst = 0 if test_result.indexes_examples is None else test_result.indexes_examples[0]

if isinstance(dl, FilteredDataLoader):
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
filename_examples.append(filename_example_dataloader_ref)
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
filename_examples.append(filename_example_dataloader_ref)

filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png")
cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
Expand Down
1 change: 1 addition & 0 deletions giskard_vision/core/detectors/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class DetectorSpecsBase:
deviation_threshold: float = 0.10
issue_level_threshold: float = 0.05
num_images: int = 0
slicing: bool = True
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from .perturbation import PerturbationBaseDetector


@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
@maybe_detector(
"coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection", "coloring"]
)
class TransformationColorDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance depending on images in grayscale
Expand Down

0 comments on commit 0af2c9b

Please sign in to comment.