-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #97 from huridocs/segment-selector-performance
Segment selector performance
- Loading branch information
Showing
9 changed files
with
263 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
``` | ||
+----------------------------------------------------------------------------------------+ | ||
|method |dataset | precision| recall| seconds| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |date | 70.37| 27.54| 1| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |date | 97.1| 97.1| 6| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |decides | 66.67| 17.07| 1| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |decides | 84.62| 73.33| 42| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |first_paragraph_having_seen | 66.67| 2.9| 2| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |first_paragraph_having_seen | 100.0| 94.2| 18| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |plan_many_date | 51.35| 23.75| 3| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |plan_many_date | 97.56| 100.0| 12| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |plan_many_title | 90.48| 43.18| 3| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |plan_many_title | 86.67| 88.64| 11| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |president | 90.48| 27.94| 1| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |president | 90.28| 95.59| 6| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |rightdocs_titles | 0.0| 0.0| 0| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |rightdocs_titles | 96.97| 88.89| 3| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |secretary | 0.0| 0.0| 1| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |secretary | 97.01| 97.01| 6| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |semantic_president | 0.0| 0.0| 18| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |semantic_president | 98.96| 60.2| 210| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |signatories | 0.0| 0.0| 1| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |signatories | 97.58| 94.15| 7| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|FastSegmentSelector |Average | 43.6| 14.24| 3| | ||
+------------------------+--------------------------------+----------+----------+--------+ | ||
|SegmentSelector |Average | 94.67| 88.91| 32| | ||
+----------------------------------------------------------------------------------------+ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from data.ExtractionIdentifier import ExtractionIdentifier | ||
from data.PdfData import PdfData | ||
|
||
|
||
class SegmentSelectorBase(ABC):
    """Abstract interface shared by the segment selector implementations.

    The base class only stores the extraction identifier and an optional
    method name; subclasses supply the model-folder preparation and the
    train/predict step.
    """

    def __init__(self, extraction_identifier: ExtractionIdentifier, method_name: str = ""):
        """Store the identifier of the extraction and an optional method name.

        Args:
            extraction_identifier: Identifier of the extraction this selector works for.
            method_name: Optional free-form name for the method; defaults to "".
        """
        self.extraction_identifier = extraction_identifier
        self.method_name = method_name

    @abstractmethod
    def prepare_model_folder(self) -> None:
        """Prepare the selector's model folder; what that entails is up to the subclass."""

    @abstractmethod
    def get_predictions_for_performance(self, training_set: list[PdfData], test_set: list[PdfData]) -> list[int]:
        """Return the predicted labels for ``test_set`` given ``training_set``.

        NOTE(review): presumably one label per segment of ``test_set``, in
        segment order, since the result is scored against per-segment truth
        labels. Confirm against the concrete implementations.
        """

    def get_name(self) -> str:
        """Return the name of the concrete class."""
        return self.__class__.__name__
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class SegmentSelectorResults(BaseModel):
    """One row of the segment selector benchmark report."""

    # Name of the selector method the row belongs to.
    method: str
    # Name of the dataset the row was measured on (or an aggregate label).
    dataset: str
    # Precision and recall of the method on that dataset.
    precision: float
    recall: float
    # Elapsed time as a whole number of seconds.
    seconds: int

    @staticmethod
    def get_padding() -> dict[str, str]:
        """Return the padding side ("left" or "right") to use for each column.

        The keys match the model's field names, so the mapping can be handed
        to a table renderer together with the dumped rows.
        """
        return {"method": "right", "dataset": "right", "precision": "left", "recall": "left", "seconds": "left"}
69 changes: 69 additions & 0 deletions
69
src/extractors/segment_selector/get_data_for_performance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import json | ||
import pickle | ||
from os import listdir | ||
from os.path import join | ||
from pathlib import Path | ||
|
||
from config import ROOT_PATH | ||
from data.PdfData import PdfData | ||
from data.SegmentBox import SegmentBox | ||
from data.SegmentationData import SegmentationData | ||
from performance_pdf_to_multi_option_report import PDF_DATA_FOLDER_PATH, cache_pdf_data | ||
|
||
# Folder holding the labeled paragraph-selector data: one sub-folder per dataset,
# located in the "pdf-labeled-data" checkout next to this project's root.
LABELED_DATA_PATH = Path(ROOT_PATH.parent, "pdf-labeled-data", "labeled_data", "paragraph_selector")
|
||
# Names of the datasets available for the benchmark; each one is expected to be
# a sub-folder of the labeled-data folder.
DATASETS = [
    "date",
    "decides",
    "first_paragraph_having_seen",
    "plan_many_date",
    "plan_many_title",
    "president",
    "rightdocs_titles",
    "secretary",
    "semantic_president",
    "signatories",
]
|
||
|
||
def get_data_for_performance(filter_datasets: list[str] = None) -> dict[str, list[PdfData]]:
    """Load the labeled PDF data of every selected dataset.

    Args:
        filter_datasets: Dataset names to load. ``None`` or an empty list
            selects every dataset in ``DATASETS``; names that are not in
            ``DATASETS`` are ignored.

    Returns:
        A mapping from dataset name to the list of ``PdfData`` objects of that
        dataset, each with its ML labels set from the dataset's ``labels.json``.
        PDFs are listed in sorted name order.
    """
    if filter_datasets:
        filtered_datasets = [x for x in DATASETS if x in filter_datasets]
    else:
        filtered_datasets = DATASETS

    pdf_data_per_dataset = {}
    for dataset in filtered_datasets:
        pdf_data_per_dataset[dataset] = []
        # os.listdir returns entries in arbitrary order; sorting keeps the list
        # (and any positional train/test split made from it) reproducible.
        for pdf_name in sorted(listdir(Path(LABELED_DATA_PATH, dataset))):
            pickle_path = Path(PDF_DATA_FOLDER_PATH, f"{pdf_name}.pickle")
            if pickle_path.exists():
                # NOTE(review): pickle.load executes arbitrary code from the file;
                # this is only safe while the cache folder holds locally generated files.
                with open(pickle_path, mode="rb") as file:
                    pdf_data: PdfData = pickle.load(file)
            else:
                # No cached pickle yet: build the PdfData through the caching helper.
                pdf_data: PdfData = cache_pdf_data(pdf_name, pickle_path)

            # Only the label boxes matter here; page size and XML boxes are left empty.
            segmentation_data = SegmentationData(
                page_width=0, page_height=0, xml_segments_boxes=[], label_segments_boxes=get_labels(dataset, pdf_name)
            )
            pdf_data.set_ml_label_from_segmentation_data(segmentation_data)
            pdf_data_per_dataset[dataset].append(pdf_data)

    return pdf_data_per_dataset
|
||
|
||
def get_labels(dataset, pdf_name):
    """Read the label boxes of one PDF from its ``labels.json`` file.

    Args:
        dataset: Name of the dataset folder under ``LABELED_DATA_PATH``.
        pdf_name: Name of the PDF folder inside that dataset folder.

    Returns:
        A list with one ``SegmentBox`` per label, in file order, carrying the
        label's position, size and the number of the page it belongs to.
    """
    labels_path = Path(LABELED_DATA_PATH, dataset, pdf_name, "labels.json")
    # JSON files are UTF-8 by specification; do not depend on the locale encoding.
    labels = json.loads(labels_path.read_text(encoding="utf-8"))
    return [
        SegmentBox(
            left=label["left"],
            top=label["top"],
            width=label["width"],
            height=label["height"],
            page_number=page["number"],
        )
        for page in labels["pages"]
        for label in page["labels"]
    ]
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from time import time | ||
|
||
from sklearn.metrics import precision_score, recall_score, average_precision_score | ||
from py_markdown_table.markdown_table import markdown_table | ||
from data.ExtractionIdentifier import ExtractionIdentifier | ||
from extractors.segment_selector.FastSegmentSelector import FastSegmentSelector | ||
from extractors.segment_selector.SegmentSelector import SegmentSelector | ||
from extractors.segment_selector.SegmentSelectorBase import SegmentSelectorBase | ||
from extractors.segment_selector.SegmentSelectorResults import SegmentSelectorResults | ||
from extractors.segment_selector.get_data_for_performance import get_data_for_performance | ||
|
||
# Identifier shared by every benchmarked selector.
extraction_identifier = ExtractionIdentifier(run_name="benchmark", extraction_name="segment_selector")

# Selector implementations compared by the benchmark, all built on the same identifier.
METHODS: list[SegmentSelectorBase] = [FastSegmentSelector(extraction_identifier), SegmentSelector(extraction_identifier)]
|
||
|
||
def get_train_test(pdfs_data, train_fraction: float = 0.5):
    """Split ``pdfs_data`` by position into a training part and a test part.

    Args:
        pdfs_data: Sequence of items to split; order is preserved.
        train_fraction: Share of the items that goes to the training part,
            between 0 and 1. The count is truncated, so with an odd number of
            items and the default 0.5 the test part gets the extra item.

    Returns:
        A ``(train_data, test_data)`` tuple of slices of ``pdfs_data``.

    Raises:
        ValueError: If ``train_fraction`` is outside the range [0, 1].
    """
    if not 0 <= train_fraction <= 1:
        raise ValueError(f"train_fraction must be between 0 and 1, got {train_fraction}")

    train_number = int(len(pdfs_data) * train_fraction)
    return pdfs_data[:train_number], pdfs_data[train_number:]
|
||
|
||
def print_results(results):
    """Print the benchmark results as a markdown table, followed by per-method averages.

    Rows are ordered by dataset and then method. For every method in
    ``METHODS`` that has at least one row, an extra "Average" row with its
    mean precision, recall and seconds is added at the end.

    The ``results`` list passed in is left untouched, so calling this function
    again with the same list does not duplicate the "Average" rows.

    Args:
        results: The ``SegmentSelectorResults`` rows to report.
    """
    rows = sorted(results, key=lambda x: (x.dataset, x.method))
    if not rows:
        # Nothing to average and nothing to render.
        print("No results to report")
        return

    averages = []
    for method in METHODS:
        method_name = method.get_name()
        method_rows = [x for x in rows if x.method == method_name]
        if not method_rows:
            # A method without rows would otherwise cause a division by zero.
            continue

        count = len(method_rows)
        averages.append(
            SegmentSelectorResults(
                method=method_name,
                dataset="Average",
                precision=round(sum(x.precision for x in method_rows) / count, 2),
                recall=round(sum(x.recall for x in method_rows) / count, 2),
                seconds=round(sum(x.seconds for x in method_rows) / count),
            )
        )

    data = [x.model_dump() for x in rows + averages]
    padding = SegmentSelectorResults.get_padding()
    markdown = markdown_table(data).set_params(padding_width=5, padding_weight=padding).get_markdown()
    print(markdown)
|
||
|
||
def get_performance_segment_selector():
    """Benchmark every selector in ``METHODS`` on every dataset and print a report.

    For each dataset the PDFs are split into a training and a test part. The
    true labels of the test segments are recorded and then reset to 0 so the
    selectors cannot read them. Each method is then timed while producing its
    predictions, and precision and recall (as percentages) are computed
    against the recorded truth.
    """
    data = get_data_for_performance(filter_datasets=[])
    print(f"Datasets: {data.keys()}")

    results: list[SegmentSelectorResults] = []
    for dataset, pdfs_data in data.items():
        training_set, test_set = get_train_test(pdfs_data)
        if not training_set or not test_set:
            # Fewer than two PDFs: there is nothing to train on or nothing to score.
            print(f"Skipping dataset {dataset}: not enough PDFs for a train/test split")
            continue

        test_segments = [x for pdf_data in test_set for x in pdf_data.pdf_data_segments]
        truth = [x.ml_label for x in test_segments]

        # Hide the ground truth from the selectors.
        for segment in test_segments:
            segment.ml_label = 0

        for method in METHODS:
            method.prepare_model_folder()
            start = time()
            predicted_labels = method.get_predictions_for_performance(training_set, test_set)
            # Stop the clock before scoring so the metric computation is not timed.
            seconds = round(time() - start)

            selector_results = SegmentSelectorResults(
                method=method.get_name(),
                dataset=dataset,
                # zero_division=0 reports 0 when the metric is undefined
                # (e.g. no positive predictions) without emitting a warning.
                precision=round(100 * precision_score(truth, predicted_labels, zero_division=0), 2),
                recall=round(100 * recall_score(truth, predicted_labels, zero_division=0), 2),
                seconds=seconds,
            )

            results.append(selector_results)

    print_results(results)
|
||
|
||
# Run the benchmark only when the module is executed as a script.
if __name__ == "__main__":
    get_performance_segment_selector()