Establish e2e benchmarking framework (#318)
* added report-vision to gitignore

* added segmentation ui utility file

* added new functions to batch process multiple files

* added new functions to run ocr on multiple files for e2e metrics only

* added time measurement to seg step.

* moved seg file to tests folder

* added confidence to total metrics

* edited main to create metrics file with all metrics for individual files

* created e2e benchmark flow

* added avg confidence and time to final metrics document

* linting

* added e2e as benchmark test

* replaced benchmark test with benchmark main

* removed unnecessary files and updated readme

* added problematic segments logic

* updated readme

* comments

* added options to run pipeline separately

* added updates to gitignore

* removed segmentation files

* added medical_report import to tests folder

* added medical reports to tests directory

* linting

* linting 2

* linting 3

* edited comments

* linting 5

* fixed bug with matching files not being properly processed

* imported mean (avg) from the statistics module for total_metrics

* minor edit to readme file

* readme edits for clarity

* added extract core keys for medical report/medical prescription files

* moved file
arinkulshi-skylight authored Oct 25, 2024
1 parent b0e1f41 commit 3ea8ded
Showing 11 changed files with 669 additions and 31 deletions.
10 changes: 10 additions & 0 deletions .idea/.gitignore


1 change: 1 addition & 0 deletions OCR/.gitignore
@@ -1 +1,2 @@
.idea
OCR/ocr/reportvision-dataset-1
28 changes: 28 additions & 0 deletions OCR/README.md
@@ -86,6 +86,34 @@ poetry run api

You can also run the script `reportvision-dataset-1/medical_report_import.py` with pytest to pull in all relevant data.


### Run E2E Benchmark Main
This will:
1. Segment and run OCR on a folder of images using the given segmentation template and labels file.
2. Compare OCR outputs to ground truth by searching for matching file names.
3. Write metrics (confidence, raw distance, Hamming distance, Levenshtein distance), as well as total metrics, to a CSV file.
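
For intuition, the sketch below shows roughly how a single field could be scored once the OCR and ground-truth values are normalized. It is an illustrative approximation, not the exact implementation in `ocr/services/metrics_analysis.py`, and the sample strings are placeholders.

```python
import Levenshtein  # same library the metrics service uses

ocr_text = "jane doe"      # placeholder normalized OCR value
ground_truth = "jane dol"  # placeholder normalized ground-truth value

# Levenshtein distance: minimum number of edits to turn one string into the other.
levenshtein_distance = Levenshtein.distance(ocr_text, ground_truth)

# Hamming distance: count of differing positions, defined only for equal-length strings.
if len(ocr_text) == len(ground_truth):
    hamming_distance = sum(a != b for a, b in zip(ocr_text, ground_truth))
else:
    hamming_distance = "undefined for unequal lengths"

print(levenshtein_distance, hamming_distance)
```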


To run:
1. Locate the file `benchmark_main.py`.
2. Ensure all the paths/folders exist. All segmentation/labels files are available here: https://drive.google.com/drive/folders/1WS2FYn0BTxWv0juh7lblzdMaFlI7zbDd?usp=sharing
3. Ensure the ground truth folder and files exist.
4. Ensure `labels.json` is in the correct format (see `tax_form_segmented_labels.json` as an example).
5. When running, make sure to pass arguments in this order (see the example invocation after the notes below):

/path/to/image/folder (path to the original image files which we need to run OCR on)
/path/to/segmentation_template.png (single file)
/path/to/labels.json (single file)
/path/to/output/folder (path to the folder where the output will be written; this should exist but can be empty)
/path/to/ground/truth_folder (path to the folder of ground-truth files that metrics are compared against)
/path/to/csv_out_folder (path to the folder where all metrics will be written; this should exist but can be empty)
The last argument is a number: 1 runs segmentation and OCR, 2 runs metrics analysis, and 3 runs both.

Notes:
- The benchmark takes one second per segment for OCR, so please be patient, or set a counter to limit the number of files processed.
- Only one segment can be supplied at a time.
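
For example, an invocation that runs both segmentation/OCR and metrics analysis might look like the following. All paths are placeholders, and depending on your environment you may need to prefix the command with `poetry run`.

```bash
python benchmark_main.py \
  /data/report_images \
  /data/templates/segmentation_template.png \
  /data/templates/labels.json \
  /data/benchmark_output \
  /data/ground_truth \
  /data/csv_out \
  3
```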

### Dockerized Development

It is also possible to run the entire project in a collection of docker containers. This is useful for development and testing purposes as it doesn't require any additional dependencies to be installed on your local machine.
34 changes: 21 additions & 13 deletions OCR/ocr/metrics_main.py
Expand Up @@ -2,16 +2,24 @@
import os


current_script_dir = os.path.dirname(os.path.abspath(__file__))
file_relative_path_ground_truth = "../tests/assets/ltbi_legacy.json"
file_relative_path_ocr = "../tests/assets/ltbi_legacy_ocr.json"
ground_truth_json_path = os.path.join(current_script_dir, file_relative_path_ground_truth)
ocr_json_path = os.path.join(current_script_dir, file_relative_path_ocr)

ocr_metrics = OCRMetrics(ocr_json_path, ground_truth_json_path)
metrics = ocr_metrics.calculate_metrics()
for m in metrics:
    print(m)
overall_metrics = ocr_metrics.total_metrics(metrics)
print("Overall Metrics:", overall_metrics)
OCRMetrics.save_metrics_to_csv(metrics, "new.csv")
def main():
    current_script_dir = os.path.dirname(os.path.abspath(__file__))

    file_relative_path_ground_truth = ".."
    file_relative_path_ocr = ".."

    ground_truth_json_path = os.path.join(current_script_dir, file_relative_path_ground_truth)
    ocr_json_path = os.path.join(current_script_dir, file_relative_path_ocr)

    ocr_metrics = OCRMetrics(ocr_json_path, ground_truth_json_path)

    metrics = ocr_metrics.calculate_metrics()

    total_metrics = ocr_metrics.total_metrics(metrics)

    output_csv_path = os.path.join(current_script_dir, "metrics_file.csv")
    OCRMetrics.save_metrics_to_csv(metrics, total_metrics, output_csv_path)


if __name__ == "__main__":
    main()
55 changes: 42 additions & 13 deletions OCR/ocr/services/metrics_analysis.py
@@ -1,7 +1,7 @@
import json

import csv
import Levenshtein
from statistics import mean


class OCRMetrics:
@@ -35,6 +35,9 @@ def load_json_file(self, file_path):
    def normalize(text):
        if text is None:
            return ""

        text = str(text)

        return " ".join(text.strip().lower().split())

    @staticmethod
@@ -52,35 +55,48 @@ def levenshtein_distance(ocr_text, ground_truth):
        return Levenshtein.distance(ocr_text, ground_truth)

    def extract_values_from_json(self, json_data):
        if json_data is None:
            return {}
        extracted_values = {}
        for item in json_data:
            if isinstance(item, dict) and "key" in item and "value" in item:
                key = self.normalize(item["key"])
                value = self.normalize(item["value"])
                extracted_values[key] = value
        for key, value in json_data.items():
            if isinstance(value, list) and len(value) >= 2:
                extracted_value, confidence = value[0], value[1]
            else:
                raise ValueError("Invalid JSON format")
                extracted_value, confidence = value, 0  # defaults to 0% if no confidence provided.

            normalized_key = self.normalize(key)
            normalized_value = self.normalize(extracted_value)

            extracted_values[normalized_key] = {
                "value": normalized_value,
                "confidence": confidence,
            }

        return extracted_values

    def calculate_metrics(self):
        ocr_values = self.extract_values_from_json(self.ocr_json)
        ground_truth_values = self.extract_values_from_json(self.ground_truth_json)

        metrics = []
        for key in ground_truth_values:
            ocr_text = ocr_values.get(key, "")
            ground_truth = ground_truth_values[key]
            ocr_entry = ocr_values.get(key, {"value": "", "confidence": 0.0})
            ocr_text = ocr_entry["value"]
            confidence = ocr_entry["confidence"]
            ground_truth = ground_truth_values[key]["value"]

            raw_dist = self.raw_distance(ocr_text, ground_truth)
            try:
                ham_dist = self.hamming_distance(ocr_text, ground_truth)
            except ValueError as e:
                ham_dist = str(e)
            lev_dist = self.levenshtein_distance(ocr_text, ground_truth)

            metrics.append(
                {
                    "key": key,
                    "ocr_text": ocr_text,
                    "ground_truth": ground_truth,
                    "confidence": confidence,
                    "raw_distance": raw_dist,
                    "hamming_distance": ham_dist,
                    "levenshtein_distance": lev_dist,
@@ -94,6 +110,7 @@ def total_metrics(metrics):
        total_levenshtein_distance = sum(
            item["levenshtein_distance"] for item in metrics if isinstance(item["levenshtein_distance"], int)
        )
        avg_confidence = mean(item["confidence"] for item in metrics) if metrics else 0

        try:
            total_hamming_distance = sum(
@@ -113,12 +130,24 @@ def total_metrics(metrics):
"total_hamming_distance": total_hamming_distance,
"total_levenshtein_distance": total_levenshtein_distance,
"levenshtein_accuracy": f"{accuracy:.2f}%",
"avg_confidence": f"{avg_confidence:.2f}",
}

@staticmethod
def save_metrics_to_csv(metrics, file_path):
keys = metrics[0].keys()
def save_metrics_to_csv(metrics, total_metrics, file_path):
metric_keys = metrics[0].keys()

total_metric_keys = total_metrics.keys()

with open(file_path, "w", newline="") as output_file:
dict_writer = csv.DictWriter(output_file, fieldnames=keys)
dict_writer = csv.DictWriter(output_file, fieldnames=metric_keys)
dict_writer.writeheader()
dict_writer.writerows(metrics)

output_file.write("\n")

total_writer = csv.DictWriter(output_file, fieldnames=total_metric_keys)

total_writer.writeheader()

total_writer.writerow(total_metrics)
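
For context, here is a minimal usage sketch of the updated `OCRMetrics` interface. The JSON layout (OCR values stored as `[value, confidence]` pairs, ground-truth values stored as plain strings) follows `extract_values_from_json` above; the sample data, temporary file names, and field names are illustrative assumptions, not part of the commit.

```python
import json
import tempfile
from pathlib import Path

from ocr.services.metrics_analysis import OCRMetrics

# Illustrative inputs: OCR output stores [value, confidence] pairs;
# ground truth stores plain values (confidence then defaults to 0).
ocr_data = {"patient_name": ["Jane Doe", 97.2], "date_of_birth": ["01/02/1980", 88.5]}
ground_truth_data = {"patient_name": "Jane Doe", "date_of_birth": "01/02/1990"}

with tempfile.TemporaryDirectory() as tmp:
    ocr_path = str(Path(tmp) / "sample_ocr.json")
    gt_path = str(Path(tmp) / "sample_ground_truth.json")
    Path(ocr_path).write_text(json.dumps(ocr_data))
    Path(gt_path).write_text(json.dumps(ground_truth_data))

    ocr_metrics = OCRMetrics(ocr_path, gt_path)
    metrics = ocr_metrics.calculate_metrics()    # per-key distances plus confidence
    totals = ocr_metrics.total_metrics(metrics)  # aggregates, including avg_confidence
    OCRMetrics.save_metrics_to_csv(metrics, totals, str(Path(tmp) / "metrics.csv"))
```
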
147 changes: 147 additions & 0 deletions OCR/tests/batch_metrics.py
@@ -0,0 +1,147 @@
from ocr.services.metrics_analysis import OCRMetrics
import os
import csv


class BatchMetricsAnalysis:
    def __init__(self, ocr_folder, ground_truth_folder, csv_output_folder):
        self.ocr_folder = ocr_folder
        self.ground_truth_folder = ground_truth_folder
        self.csv_output_folder = csv_output_folder

        os.makedirs(self.csv_output_folder, exist_ok=True)

    def calculate_batch_metrics(self, ocr_results=None):
        """
        Processes OCR and ground truth files and saves individual CSVs.
        Ensures only matching files are processed.
        """
        print(f"Loading OCR files from: {self.ocr_folder}")
        print(f"Loading ground truth files from: {self.ground_truth_folder}")
        print(f"Saving individual CSVs to: {self.csv_output_folder}")

        total_metrics_summary = {}
        problematic_segments = []

        ocr_files = self.get_files_in_directory(self.ocr_folder)
        ground_truth_files = self.get_files_in_directory(self.ground_truth_folder)

        # Create dicts for matching files by name
        ocr_dict = {os.path.splitext(f)[0]: f for f in ocr_files}
        ground_truth_dict = {os.path.splitext(f)[0]: f for f in ground_truth_files}
        # Find the intersection of matching file names
        matching_files = ocr_dict.keys() & ground_truth_dict.keys()
        # Process only matching files
        print(f"Processing matching files: {matching_files}")
        for file_name in matching_files:
            ocr_file = ocr_dict[file_name]
            ground_truth_file = ground_truth_dict[file_name]

            print(f"Processing OCR: {ocr_file} with Ground Truth: {ground_truth_file}")

            ocr_path = os.path.join(self.ocr_folder, ocr_file)
            ground_truth_path = os.path.join(self.ground_truth_folder, ground_truth_file)

            ocr_metrics = OCRMetrics(ocr_json_path=ocr_path, ground_truth_json_path=ground_truth_path)
            metrics = ocr_metrics.calculate_metrics()

            self.extract_problematic_segments(metrics, ocr_file, problematic_segments)

            total_metrics = ocr_metrics.total_metrics(metrics)

            # Create a CSV path for this specific file pair
            csv_file_name = f"{file_name}_metrics.csv"
            csv_output_path = os.path.join(self.csv_output_folder, csv_file_name)

            print(f"Saving metrics to: {csv_output_path}")
            self.save_metrics_to_csv(metrics, total_metrics, csv_output_path)

            # Store total metrics
            total_metrics_summary[ocr_file] = total_metrics

        # Save problematic segments to CSV
        problematic_csv_path = os.path.join(self.csv_output_folder, "problematic_segments.csv")
        print(f"Saving problematic segments to: {problematic_csv_path}")
        self.save_problematic_segments_to_csv(problematic_segments, problematic_csv_path)

        print("Finished processing all files.")
        return total_metrics_summary

    @staticmethod
    def save_metrics_to_csv(metrics, total_metrics, file_path):
        """
        Saves individual and total metrics to a CSV file, including time taken.
        """
        print(metrics)
        metric_keys = list(metrics[0].keys())
        total_metric_keys = list(total_metrics.keys())

        with open(file_path, "w", newline="") as output_file:
            # Write individual metrics
            dict_writer = csv.DictWriter(output_file, fieldnames=metric_keys)
            dict_writer.writeheader()
            dict_writer.writerows(metrics)

            output_file.write("\n")

            # Write total metrics
            total_writer = csv.DictWriter(output_file, fieldnames=total_metric_keys)
            total_writer.writeheader()
            total_writer.writerow(total_metrics)

        print(f"Metrics saved to {file_path}")

    @staticmethod
    def save_problematic_segments_to_csv(segments, file_path):
        """
        Saves problematic segments (Levenshtein distance >= 1) to a CSV file.
        """
        if not segments:
            print("No problematic segments found.")
            return

        with open(file_path, "w", newline="") as output_file:
            fieldnames = ["file", "key", "ocr_value", "ground_truth", "confidence", "levenshtein_distance"]
            writer = csv.DictWriter(output_file, fieldnames=fieldnames)

            writer.writeheader()
            writer.writerows(segments)

        print(f"Problematic segments saved to {file_path}")

    def extract_problematic_segments(self, metrics, ocr_file, problematic_segments):
        """
        Extracts segments with Levenshtein distance >= 1 and stores them.
        """
        for metric in metrics:
            if metric["levenshtein_distance"] >= 1:
                problematic_segments.append(
                    {
                        "file": ocr_file,
                        "key": metric["key"],
                        "ocr_value": metric["ocr_text"],
                        "ground_truth": metric["ground_truth"],
                        "confidence": metric["confidence"],
                        "levenshtein_distance": metric["levenshtein_distance"],
                    }
                )

    @staticmethod
    def get_files_in_directory(directory):
        """
        Returns a sorted list of files in the specified directory.
        Assumes that files are named consistently for OCR and ground truth.
        """
        try:
            files = sorted(
                [
                    f
                    for f in os.listdir(directory)
                    if os.path.isfile(os.path.join(directory, f)) and not f.startswith(".")
                ]
            )
            print(f"Files found in {directory}: {files}")
            return files
        except FileNotFoundError as e:
            print(f"Error: {e}")
            return []
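
And here is a minimal driver sketch for `BatchMetricsAnalysis`. The import path and folder locations are assumptions about the local layout (the class lives in `OCR/tests/batch_metrics.py`), so adjust them as needed; the CSV output folder is created by the constructor if it does not exist.

```python
# Hypothetical driver; import path and folders are placeholders.
from tests.batch_metrics import BatchMetricsAnalysis  # assumed to be run from the OCR/ directory

ocr_folder = "/data/ocr_outputs"            # OCR result JSON files
ground_truth_folder = "/data/ground_truth"  # matching ground-truth JSON files
csv_output_folder = "/data/csv_out"         # per-file metrics CSVs are written here

analysis = BatchMetricsAnalysis(ocr_folder, ground_truth_folder, csv_output_folder)
summary = analysis.calculate_batch_metrics()

# summary maps each OCR file name to its total metrics
# (total distances, levenshtein_accuracy, avg_confidence).
for file_name, totals in summary.items():
    print(file_name, totals)
```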