Merge pull request #97 from huridocs/segment-selector-performance
Segment selector performance
gabriel-piles authored Oct 15, 2024
2 parents abb7f7a + be33e28 commit 341aab9
Showing 9 changed files with 263 additions and 7 deletions.
49 changes: 49 additions & 0 deletions performance_results/segment_selector.md
@@ -0,0 +1,49 @@
```
+----------------------------------------------------------------------------------------+
|method |dataset | precision| recall| seconds|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |date | 70.37| 27.54| 1|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |date | 97.1| 97.1| 6|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |decides | 66.67| 17.07| 1|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |decides | 84.62| 73.33| 42|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |first_paragraph_having_seen | 66.67| 2.9| 2|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |first_paragraph_having_seen | 100.0| 94.2| 18|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |plan_many_date | 51.35| 23.75| 3|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |plan_many_date | 97.56| 100.0| 12|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |plan_many_title | 90.48| 43.18| 3|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |plan_many_title | 86.67| 88.64| 11|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |president | 90.48| 27.94| 1|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |president | 90.28| 95.59| 6|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |rightdocs_titles | 0.0| 0.0| 0|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |rightdocs_titles | 96.97| 88.89| 3|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |secretary | 0.0| 0.0| 1|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |secretary | 97.01| 97.01| 6|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |semantic_president | 0.0| 0.0| 18|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |semantic_president | 98.96| 60.2| 210|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |signatories | 0.0| 0.0| 1|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |signatories | 97.58| 94.15| 7|
+------------------------+--------------------------------+----------+----------+--------+
|FastSegmentSelector |Average | 43.6| 14.24| 3|
+------------------------+--------------------------------+----------+----------+--------+
|SegmentSelector |Average | 94.67| 88.91| 32|
+----------------------------------------------------------------------------------------+
```
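For context, the precision and recall columns above are percentages over per-segment binary labels and seconds is wall-clock time per dataset; below is a minimal sketch of the metric computation, mirroring the sklearn calls in the benchmark script added in this commit (the label lists are invented):

```python
# Minimal sketch of the precision/recall computation used in
# src/performance_segment_selector.py; truth/predicted are made-up labels.
from sklearn.metrics import precision_score, recall_score

truth = [1, 0, 0, 1, 0, 1]      # assumed ground-truth ml_label per test segment
predicted = [1, 0, 1, 1, 0, 0]  # assumed selector output per test segment

precision = round(100 * precision_score(truth, predicted), 2)
recall = round(100 * recall_score(truth, predicted), 2)
print(f"precision={precision} recall={recall}")  # precision=66.67 recall=66.67
```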
8 changes: 7 additions & 1 deletion src/extractors/ToTextExtractor.py
@@ -5,6 +5,7 @@
from config import config_logger
from data.ExtractionData import ExtractionData
from data.ExtractionIdentifier import ExtractionIdentifier
from data.LogsMessage import Severity
from data.PredictionSample import PredictionSample
from data.Suggestion import Suggestion
from extractors.ExtractorBase import ExtractorBase
@@ -111,7 +112,12 @@ def get_best_method(self, extraction_data: ExtractionData):
for method in self.METHODS:
method_instance = method(self.extraction_identifier)
send_logs(self.extraction_identifier, f"Checking {method_instance.get_name()}")
performance = method_instance.performance(training_set, test_set)
try:
performance = method_instance.performance(training_set, test_set)
except Exception as e:
message = f"Error checking {method_instance.get_name()}: {e}"
send_logs(self.extraction_identifier, message, Severity.error)
performance = 0
performance_log += f"{method_instance.get_name()}: {round(performance, 2)}%\n"
send_logs(self.extraction_identifier, f"Performance {method_instance.get_name()}: {performance}%")
if performance == 100:
15 changes: 12 additions & 3 deletions src/extractors/segment_selector/FastSegmentSelector.py
@@ -9,16 +9,18 @@
from pdf_token_type_labels.TokenType import TokenType

from data.ExtractionIdentifier import ExtractionIdentifier
from data.PdfData import PdfData
from data.PdfDataSegment import PdfDataSegment
import lightgbm as lgb

from extractors.segment_selector.SegmentSelectorBase import SegmentSelectorBase

class FastSegmentSelector:

class FastSegmentSelector(SegmentSelectorBase):
def __init__(self, extraction_identifier: ExtractionIdentifier, method_name: str = ""):
self.extraction_identifier = extraction_identifier
super().__init__(extraction_identifier, method_name)
self.text_types = [TokenType.TEXT, TokenType.LIST_ITEM, TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.CAPTION]
self.previous_words, self.next_words, self.text_segments = [], [], []
self.method_name = method_name

if method_name:
self.fast_segment_selector_path = Path(
@@ -161,3 +163,10 @@ def load_repeated_words(self):

if exists(self.next_words_path):
self.next_words = json.loads(Path(self.next_words_path).read_text())

def get_predictions_for_performance(self, training_set: list[PdfData], test_set: list[PdfData]) -> list[int]:
training_segments = [x for pdf_data in training_set for x in pdf_data.pdf_data_segments]
test_segments = [x for pdf_data in test_set for x in pdf_data.pdf_data_segments]
self.create_model(training_segments)
predictions = self.predict(test_segments)
return [1 if segment in predictions else 0 for segment in test_segments]
12 changes: 9 additions & 3 deletions src/extractors/segment_selector/SegmentSelector.py
@@ -7,12 +7,13 @@

from data.ExtractionIdentifier import ExtractionIdentifier
from data.PdfData import PdfData
from extractors.segment_selector.SegmentSelectorBase import SegmentSelectorBase
from extractors.segment_selector.methods.lightgbm_frequent_words.LightgbmFrequentWords import LightgbmFrequentWords


class SegmentSelector:
def __init__(self, extraction_identifier: ExtractionIdentifier):
self.extraction_identifier = extraction_identifier
class SegmentSelector(SegmentSelectorBase):
def __init__(self, extraction_identifier: ExtractionIdentifier, method_name: str = ""):
super().__init__(extraction_identifier, method_name)
self.model_path = join(self.extraction_identifier.get_path(), "segment_predictor_model", "model.model")
self.model = self.load_model()

@@ -69,3 +70,8 @@ def set_extraction_segments(self, pdfs_data: list[PdfData]):
for segment in pdf_metadata.pdf_data_segments:
segment.ml_label = 1 if predictions[index] > 0.5 else 0
index += 1

def get_predictions_for_performance(self, training_set: list[PdfData], test_set: list[PdfData]) -> list[int]:
self.create_model(training_set)
self.set_extraction_segments(test_set)
return [segment.ml_label for pdf_data in test_set for segment in pdf_data.pdf_data_segments]
22 changes: 22 additions & 0 deletions src/extractors/segment_selector/SegmentSelectorBase.py
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod

from data.ExtractionIdentifier import ExtractionIdentifier
from data.PdfData import PdfData


class SegmentSelectorBase(ABC):

def __init__(self, extraction_identifier: ExtractionIdentifier, method_name: str = ""):
self.extraction_identifier = extraction_identifier
self.method_name = method_name

@abstractmethod
def prepare_model_folder(self):
pass

@abstractmethod
def get_predictions_for_performance(self, training_set: list[PdfData], test_set: list[PdfData]) -> list[int]:
pass

def get_name(self):
return self.__class__.__name__
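To illustrate the contract the new base class defines, here is a hypothetical subclass that is not part of this commit (repo-internal imports assumed, as in the modules above):

```python
# Hypothetical DummySegmentSelector: implements both abstract methods of
# SegmentSelectorBase with trivial behaviour.
from data.PdfData import PdfData
from extractors.segment_selector.SegmentSelectorBase import SegmentSelectorBase


class DummySegmentSelector(SegmentSelectorBase):
    def prepare_model_folder(self):
        # No model files to set up or clean for this dummy selector.
        pass

    def get_predictions_for_performance(self, training_set: list[PdfData], test_set: list[PdfData]) -> list[int]:
        # Mark every test segment as "not selected".
        return [0 for pdf_data in test_set for _ in pdf_data.pdf_data_segments]
```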
13 changes: 13 additions & 0 deletions src/extractors/segment_selector/SegmentSelectorResults.py
@@ -0,0 +1,13 @@
from pydantic import BaseModel


class SegmentSelectorResults(BaseModel):
method: str
dataset: str
precision: float
recall: float
seconds: int

@staticmethod
def get_padding():
return {"method": "right", "dataset": "right", "precision": "left", "recall": "left", "seconds": "left"}
69 changes: 69 additions & 0 deletions src/extractors/segment_selector/get_data_for_performance.py
@@ -0,0 +1,69 @@
import json
import pickle
from os import listdir
from os.path import join
from pathlib import Path

from config import ROOT_PATH
from data.PdfData import PdfData
from data.SegmentBox import SegmentBox
from data.SegmentationData import SegmentationData
from performance_pdf_to_multi_option_report import PDF_DATA_FOLDER_PATH, cache_pdf_data

LABELED_DATA_PATH = Path(ROOT_PATH.parent, "pdf-labeled-data", "labeled_data", "paragraph_selector")

DATASETS = [
"date",
"decides",
"first_paragraph_having_seen",
"plan_many_date",
"plan_many_title",
"president",
"rightdocs_titles",
"secretary",
"semantic_president",
"signatories",
]


def get_data_for_performance(filter_datasets: list[str] = None) -> dict[str, list[PdfData]]:
if filter_datasets:
filtered_datasets = [x for x in DATASETS if x in filter_datasets]
else:
filtered_datasets = DATASETS

pdf_data_per_dataset = {}
for dataset in filtered_datasets:
pdf_data_per_dataset[dataset] = []
for pdf_name in listdir(Path(LABELED_DATA_PATH, dataset)):
pickle_path = join(PDF_DATA_FOLDER_PATH, f"{pdf_name}.pickle")
if Path(pickle_path).exists():
with open(pickle_path, mode="rb") as file:
pdf_data: PdfData = pickle.load(file)
else:
pdf_data: PdfData = cache_pdf_data(pdf_name, Path(pickle_path))

segmentation_data = SegmentationData(
page_width=0, page_height=0, xml_segments_boxes=[], label_segments_boxes=get_labels(dataset, pdf_name)
)
pdf_data.set_ml_label_from_segmentation_data(segmentation_data)
pdf_data_per_dataset[dataset].append(pdf_data)

return pdf_data_per_dataset


def get_labels(dataset, pdf_name):
label_segments_boxes = []
labels = json.loads(Path(LABELED_DATA_PATH, dataset, pdf_name, "labels.json").read_text())
for page in labels["pages"]:
for label in page["labels"]:
label_segments_boxes.append(
SegmentBox(
left=label["left"],
top=label["top"],
width=label["width"],
height=label["height"],
page_number=page["number"],
)
)
return label_segments_boxes
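A sketch of calling this loader on its own; it assumes the pdf-labeled-data repository is checked out next to this one, which is where LABELED_DATA_PATH points:

```python
# Load one dataset and report how many PDFs and labeled segments it contains.
# Assumes the pdf-labeled-data checkout expected by LABELED_DATA_PATH.
from extractors.segment_selector.get_data_for_performance import get_data_for_performance

data = get_data_for_performance(filter_datasets=["date"])
for dataset, pdfs_data in data.items():
    labeled_segments = sum(
        segment.ml_label for pdf_data in pdfs_data for segment in pdf_data.pdf_data_segments
    )
    print(dataset, len(pdfs_data), "pdfs,", labeled_segments, "labeled segments")
```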
File renamed without changes.
82 changes: 82 additions & 0 deletions src/performance_segment_selector.py
@@ -0,0 +1,82 @@
from time import time

from sklearn.metrics import precision_score, recall_score, average_precision_score
from py_markdown_table.markdown_table import markdown_table
from data.ExtractionIdentifier import ExtractionIdentifier
from extractors.segment_selector.FastSegmentSelector import FastSegmentSelector
from extractors.segment_selector.SegmentSelector import SegmentSelector
from extractors.segment_selector.SegmentSelectorBase import SegmentSelectorBase
from extractors.segment_selector.SegmentSelectorResults import SegmentSelectorResults
from extractors.segment_selector.get_data_for_performance import get_data_for_performance

extraction_identifier = ExtractionIdentifier(run_name="benchmark", extraction_name="segment_selector")

METHODS: list[SegmentSelectorBase] = [FastSegmentSelector(extraction_identifier), SegmentSelector(extraction_identifier)]


def get_train_test(pdfs_data):
train_number = int(len(pdfs_data) * 0.5)
train_data = pdfs_data[:train_number]
test_data = pdfs_data[train_number:]
return train_data, test_data


def print_results(results):
results.sort(key=lambda x: (x.dataset, x.method))
for method in METHODS:
precisions = [x.precision for x in results if x.method == method.get_name()]
recalls = [x.recall for x in results if x.method == method.get_name()]
seconds = [x.seconds for x in results if x.method == method.get_name()]
average_precision = round(sum(precisions) / len(precisions), 2)
average_recall = round(sum(recalls) / len(recalls), 2)
average_seconds = round(sum(seconds) / len(seconds))

results.append(
SegmentSelectorResults(
method=method.get_name(),
dataset="Average",
precision=average_precision,
recall=average_recall,
seconds=average_seconds,
)
)

data = [x.model_dump() for x in results]
padding = SegmentSelectorResults.get_padding()
markdown = markdown_table(data).set_params(padding_width=5, padding_weight=padding).get_markdown()
print(markdown)


def get_performance_segment_selector():
data = get_data_for_performance(filter_datasets=[])
print(f"Datasets: {data.keys()}")

results: list[SegmentSelectorResults] = list()
for dataset, pdfs_data in data.items():
training_set, test_set = get_train_test(pdfs_data)

truth = [x.ml_label for pdf_data in test_set for x in pdf_data.pdf_data_segments]

for segment in [x for pdf_data in test_set for x in pdf_data.pdf_data_segments]:
segment.ml_label = 0

for method in METHODS:
method.prepare_model_folder()
start = time()
predicted_labels = method.get_predictions_for_performance(training_set, test_set)

selector_results = SegmentSelectorResults(
method=method.get_name(),
dataset=dataset,
precision=round(100 * precision_score(truth, predicted_labels), 2),
recall=round(100 * recall_score(truth, predicted_labels), 2),
seconds=round(time() - start),
)

results.append(selector_results)

print_results(results)


if __name__ == "__main__":
get_performance_segment_selector()
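The table in performance_results/segment_selector.md above was presumably generated by this script; running `python performance_segment_selector.py` from the `src/` directory (so the top-level `config`, `data`, and `extractors` imports resolve) should reproduce it, provided the labeled data expected by get_data_for_performance.py is available.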
