
Commit

Merge pull request #94 from huridocs/fix-few-samples-pdf-to-text
Fix few samples pdf to text
gabriel-piles authored Oct 11, 2024
2 parents 4d8f28e + acb689f commit 173ae5a
Showing 61 changed files with 680 additions and 185 deletions.
docker-compose.yml (2 changes: 1 addition & 1 deletion)

@@ -23,7 +23,7 @@ services:
   pdf_metadata_extraction_worker:
     container_name: pdf_metadata_extraction_worker
     init: true
-    entrypoint: [ "python3", "-m", "src.QueueProcessor" ]
+    entrypoint: [ "python3", "-m", "src.start_queue_processor" ]
     restart: unless-stopped
     build:
       context: .
gpu-docker-compose.yml (3 changes: 2 additions & 1 deletion)

@@ -20,6 +20,7 @@ services:
     env_file: .env
   pdf_metadata_extraction_worker:
     container_name: pdf_metadata_extraction_worker
+    restart: unless-stopped
     deploy:
       resources:
         reservations:
@@ -28,7 +29,7 @@
              count: 1
              capabilities: [ gpu ]
     init: true
-    entrypoint: [ "python", "-m", "src.QueueProcessor" ]
+    entrypoint: [ "python", "-m", "src.start_queue_processor" ]
     build:
       context: .
       dockerfile: Dockerfile
local-docker-compose.yml (3 changes: 2 additions & 1 deletion)

@@ -20,8 +20,9 @@ services:
     env_file: .env.local
   pdf_metadata_extraction_worker:
     container_name: pdf_metadata_extraction_worker
+    restart: unless-stopped
     init: true
-    entrypoint: [ "python", "-m", "src.QueueProcessor" ]
+    entrypoint: [ "python", "-m", "src.start_queue_processor" ]
     build:
       context: .
       dockerfile: Dockerfile
local-gpu-docker-compose.yml (3 changes: 2 additions & 1 deletion)

@@ -20,6 +20,7 @@ services:
     env_file: .env.local
   pdf_metadata_extraction_worker:
     container_name: pdf_metadata_extraction_worker
+    restart: unless-stopped
     deploy:
       resources:
         reservations:
@@ -28,7 +29,7 @@
              count: 1
              capabilities: [ gpu ]
     init: true
-    entrypoint: [ "python", "-m", "src.QueueProcessor" ]
+    entrypoint: [ "python", "-m", "src.start_queue_processor" ]
     build:
       context: .
       dockerfile: Dockerfile
requirements.txt (6 changes: 3 additions & 3 deletions)

@@ -1,5 +1,5 @@
-git+https://github.com/huridocs/pdf-document-layout-analysis@15b7116c10baa62d7278b870b6b87cd1f742695f
-git+https://github.com/huridocs/queue-processor@bab1f4419b0768df518d06795afd5df2ba0e331c
+git+https://github.com/huridocs/pdf-document-layout-analysis@67365bb133dab5826a3863d2e2cc551a21b89e81
+git+https://github.com/huridocs/queue-processor@1875372bf9f6dcd1995a32c4e50ff92aa45f9ea8
 slugify==0.0.1
 python-Levenshtein==0.25.1
 tdda==2.0.9
@@ -20,6 +20,6 @@ rapidfuzz==3.8.1
 sentry_sdk==1.44.0
 pymongo==4.6.3
 graypy==2.1.0
-setfit==1.0.3
+setfit==1.1.0
 fuzzywuzzy==0.18.0
 httpx==0.27.0
src/FilterValidSegmentsPages.py (12 changes: 5 additions & 7 deletions)

@@ -2,7 +2,7 @@
 import os
 import re
 from os.path import join, exists
-
+from pathlib import Path
 
 from data.ExtractionIdentifier import ExtractionIdentifier
 from data.LabeledData import LabeledData
@@ -13,7 +13,7 @@
 
 class FilterValidSegmentsPages:
     def __init__(self, extraction_identifier: ExtractionIdentifier):
-        self.labeled_data_json_path = join(extraction_identifier.get_path(), "filter_pages.json")
+        self.labeled_data_json_path = Path(extraction_identifier.get_path(), "filter_pages.json")
         self.start_gaps = []
         self.end_gaps = []
         self.valid_pages_ranges = []
@@ -37,12 +37,10 @@ def get_valid_pages(self, total_number_pages_per_document: list[int]) -> list[li
         return valid_page_numbers_from_the_end
 
     def for_training(self, labeled_data_list: list[LabeledData]):
-        if not exists(os.path.dirname(self.labeled_data_json_path)):
-            os.makedirs(os.path.dirname(self.labeled_data_json_path))
+        if not self.labeled_data_json_path.parent.exists():
+            os.makedirs(self.labeled_data_json_path.parent)
 
-        with open(self.labeled_data_json_path, "w") as file:
-            json.dump([x.model_dump_json() for x in labeled_data_list], file)
-
+        self.labeled_data_json_path.write_text(json.dumps([x.model_dump_json() for x in labeled_data_list]))
         self.set_parameters(labeled_data_list)
         pages_list = [[x.page_number for x in labeled_data.xml_segments_boxes] for labeled_data in labeled_data_list]
         total_number_pages_per_document = [max(pages) if pages else 1000 for pages in pages_list]
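Note: the for_training() change above swaps the os.path/open() idiom for pathlib. A minimal standalone sketch of the same pattern (the path below is hypothetical; the diff keeps os.makedirs, though Path.mkdir is equivalent):

import json
from pathlib import Path

# Hypothetical path, standing in for Path(extraction_identifier.get_path(), "filter_pages.json")
labeled_data_json_path = Path("models", "tenant", "extraction_id", "filter_pages.json")

# One-call replacement for the exists()/os.makedirs() guard in the diff
labeled_data_json_path.parent.mkdir(parents=True, exist_ok=True)

# write_text() replaces the explicit open()/json.dump() pair
labeled_data_json_path.write_text(json.dumps(["labeled-sample-1", "labeled-sample-2"]))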
src/data/PdfDataSegment.py (4 changes: 4 additions & 0 deletions)

@@ -61,3 +61,7 @@ def from_list_to_merge(pdf_segments_to_merge: list["PdfDataSegment"]):
     @staticmethod
     def from_texts(texts: list[str]):
         return [PdfDataSegment(i + 1, Rectangle(0, 0, 0, 0), text) for i, text in enumerate(texts)]
+
+    @staticmethod
+    def create_with_text(text: str):
+        return PdfDataSegment(0, Rectangle(0, 0, 0, 0), text)
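Note: a usage sketch of the new factory method (assuming, as from_texts() suggests, that the constructor arguments are page number, bounding box, and text):

from data.PdfDataSegment import PdfDataSegment

# Hypothetical usage: a segment carrying only text, with page number 0 and
# an empty bounding box, for callers that never look at layout.
segment = PdfDataSegment.create_with_text("Geneva, 12 March 2020")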
src/data/PredictionSample.py (5 changes: 4 additions & 1 deletion)

@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
 from data.PdfData import PdfData
+from data.PdfDataSegment import PdfDataSegment
 
 
 @dataclass
@@ -22,7 +23,9 @@ def from_pdf_data(pdf_data: PdfData):
 
     @staticmethod
     def from_text(text: str, entity_name: str = ""):
-        return PredictionSample(tags_texts=[text], entity_name=entity_name)
+        pdf_data = PdfData(None)
+        pdf_data.pdf_data_segments.append(PdfDataSegment.create_with_text(text))
+        return PredictionSample(tags_texts=[text], entity_name=entity_name, pdf_data=pdf_data)
 
     @staticmethod
     def from_texts(texts: list[str]):
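Note: with this change a text-only sample also carries a minimal PdfData, so methods that read sample.pdf_data no longer receive an empty value. A usage sketch:

from data.PredictionSample import PredictionSample

# Hypothetical usage of the reworked factory:
sample = PredictionSample.from_text("Resolution adopted on 12 March 2020", entity_name="doc-1")
# sample.tags_texts == ["Resolution adopted on 12 March 2020"], and sample.pdf_data
# now wraps the same text in a single synthetic PdfDataSegment.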
src/extractors/ExtractorBase.py (8 changes: 1 addition & 7 deletions)

@@ -42,9 +42,7 @@ def is_multilingual(multi_option_data: ExtractionData) -> bool:
         return False
 
     @staticmethod
-    def get_train_test_sets(
-        extraction_data: ExtractionData, limit_samples: bool = False
-    ) -> (ExtractionData, ExtractionData):
+    def get_train_test_sets(extraction_data: ExtractionData) -> (ExtractionData, ExtractionData):
         if len(extraction_data.samples) < 8:
             return extraction_data, extraction_data
 
@@ -57,10 +55,6 @@ def get_train_test_sets(
         else:
             test_set = extraction_data.samples[train_size:]
 
-        if limit_samples:
-            train_set = train_set[:80]
-            test_set = test_set[:30]
-
         train_extraction_data = ExtractorBase.get_extraction_data_from_samples(extraction_data, train_set)
         test_extraction_data = ExtractorBase.get_extraction_data_from_samples(extraction_data, test_set)
         return train_extraction_data, test_extraction_data
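Note: the removed limit_samples branch used to cap the sets at 80 train / 30 test samples; every caller now gets the full split. A hedged sketch of the visible contract (the exact train_size computation sits outside this hunk, so the ratio below is an assumption):

# Sketch only: mirrors the visible contract of get_train_test_sets, not the literal code.
def sketch_train_test_split(samples: list, train_ratio: float = 0.8) -> tuple[list, list]:
    if len(samples) < 8:
        # Too few samples to hold any out: evaluate on the training data itself
        return samples, samples
    train_size = int(len(samples) * train_ratio)  # assumed ratio
    return samples[:train_size], samples[train_size:]

train, test = sketch_train_test_split(list(range(10)))
print(len(train), len(test))  # 8 2 under the assumed ratio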
src/extractors/ToTextExtractor.py (33 changes: 24 additions & 9 deletions)

@@ -18,6 +18,8 @@
 from extractors.text_to_text_extractor.methods.SameInputOutputMethod import SameInputOutputMethod
 from send_logs import send_logs
 
+RETRAIN_SAMPLES_THRESHOLD = 250
+
 
 class ToTextExtractor(ExtractorBase):
     METHODS: list[type[ToTextExtractorMethod]] = []
@@ -68,9 +70,11 @@ def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]:
         if not extraction_data.samples:
             return False, "No samples to create model"
 
-        performance_train_set, performance_test_set = ExtractorBase.get_train_test_sets(extraction_data)
-        send_logs(self.extraction_identifier, f"Train set contains {len(performance_train_set.samples)} samples")
-        send_logs(self.extraction_identifier, f"Test set contains {len(performance_test_set.samples)} samples")
+        performance_train_set, performance_test_set = self.get_train_test_sets(extraction_data)
+
+        samples_info = f"Train: {len(performance_train_set.samples)} samples\n"
+        samples_info += f"Test: {len(performance_test_set.samples)} samples"
+        send_logs(self.extraction_identifier, samples_info)
 
         if len(extraction_data.samples) < 2:
             best_method_instance = self.METHODS[0](self.extraction_identifier)
@@ -80,23 +84,34 @@
 
         best_method_instance = self.get_best_method(extraction_data)
         self.method_name_path.write_text(json.dumps(best_method_instance.get_name()))
-        best_method_instance.train(extraction_data)
+
+        if len(extraction_data.samples) < RETRAIN_SAMPLES_THRESHOLD:
+            best_method_instance.train(extraction_data)
+
+        self.remove_data_from_methods_not_selected(best_method_instance)
 
         return True, ""
 
-    def remove_models(self):
-        for method in self.METHODS:
-            method_instance = method(self.extraction_identifier)
-            method_instance.remove_model()
+    @staticmethod
+    def get_train_test_sets(extraction_data: ExtractionData) -> (ExtractionData, ExtractionData):
+        return ExtractorBase.get_train_test_sets(extraction_data)
+
+    def remove_data_from_methods_not_selected(self, best_method_instance):
+        for method_to_remove in self.METHODS:
+            method_instance = method_to_remove(self.extraction_identifier)
+            if method_instance.get_name() != best_method_instance.get_name():
+                method_instance.remove_method_data()
 
     def get_best_method(self, extraction_data: ExtractionData):
         best_performance = 0
         best_method_instance = self.METHODS[0](self.extraction_identifier)
         performance_log = "Performance aggregation:\n"
+
+        training_set, test_set = self.get_train_test_sets(extraction_data)
         for method in self.METHODS:
             method_instance = method(self.extraction_identifier)
             send_logs(self.extraction_identifier, f"Checking {method_instance.get_name()}")
-            performance = method_instance.performance(extraction_data)
+            performance = method_instance.performance(training_set, test_set)
             performance_log += f"{method_instance.get_name()}: {round(performance, 2)}%\n"
             send_logs(self.extraction_identifier, f"Performance {method_instance.get_name()}: {performance}%")
             if performance == 100:
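Note: two behaviours are introduced above: every candidate method is scored on one shared train/test split, and the winner is only retrained on the full data below RETRAIN_SAMPLES_THRESHOLD (above it, the model already fitted during selection is kept). A condensed sketch with duck-typed method objects, not the literal code:

# Condensed sketch of get_best_method()/create_model().
RETRAIN_SAMPLES_THRESHOLD = 250

def pick_and_train(methods: list, extraction_data, training_set, test_set):
    best_method, best_performance = methods[0], 0.0
    for method in methods:
        performance = method.performance(training_set, test_set)  # train on one set, score on the other
        if performance == 100:
            best_method = method  # perfect score: no need to check further methods
            break
        if performance > best_performance:
            best_method, best_performance = method, performance
    if len(extraction_data.samples) < RETRAIN_SAMPLES_THRESHOLD:
        best_method.train(extraction_data)  # cheap enough to refit on all samples
    return best_method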
src/extractors/ToTextExtractorMethod.py (36 changes: 23 additions & 13 deletions)

@@ -2,46 +2,54 @@
 import os
 import shutil
 from abc import abstractmethod
+from copy import deepcopy
 from os.path import join, exists
 from pathlib import Path
 
 from config import config_logger
 from data.ExtractionData import ExtractionData
 from data.ExtractionIdentifier import ExtractionIdentifier
 from data.PredictionSample import PredictionSample
-from extractors.ExtractorBase import ExtractorBase
+from extractors.pdf_to_multi_option_extractor.filter_segments_methods.CleanBeginningDot250 import CleanBeginningDot250
 
 
 class ToTextExtractorMethod:
 
-    def __init__(self, extraction_identifier: ExtractionIdentifier):
+    def __init__(self, extraction_identifier: ExtractionIdentifier, from_class_name: str = ""):
+        self.from_class_name = from_class_name
         self.extraction_identifier = extraction_identifier
         os.makedirs(self.extraction_identifier.get_path(), exist_ok=True)
 
+    def get_path(self):
+        if self.from_class_name:
+            path = join(self.extraction_identifier.get_path(), self.from_class_name, self.get_name())
+        else:
+            path = join(self.extraction_identifier.get_path(), self.get_name())
+
+        os.makedirs(path, exist_ok=True)
+        return path
+
     def get_name(self):
         return self.__class__.__name__
 
     def save_json(self, file_name: str, data: any):
-        path = join(self.extraction_identifier.get_path(), self.get_name(), file_name)
+        path = join(self.get_path(), file_name)
 
         if not exists(Path(path).parent):
             os.makedirs(Path(path).parent)
 
         with open(path, "w") as file:
             json.dump(data, file)
 
     def load_json(self, file_name: str):
-        path = join(self.extraction_identifier.get_path(), self.get_name(), file_name)
+        path = join(self.get_path(), file_name)
 
         if not exists(path):
             return ""
 
         with open(path, "r") as file:
             return json.load(file)
 
-    def remove_model(self):
-        shutil.rmtree(join(self.extraction_identifier.get_path(), self.get_name()), ignore_errors=True)
-
     @abstractmethod
     def train(self, extraction_data: ExtractionData):
         pass
@@ -54,22 +62,21 @@ def predict(self, predictions_samples: list[PredictionSample]) -> list[str]:
     def clean_text(text: str) -> str:
         return " ".join(text.split())
 
-    def performance(self, extraction_data: ExtractionData) -> float:
-        if not extraction_data.samples:
+    def performance(self, performance_train_set: ExtractionData, performance_test_set: ExtractionData) -> float:
+        if not performance_train_set.samples:
             return 0
 
-        performance_train_set, performance_test_set = ExtractorBase.get_train_test_sets(extraction_data)
-
         self.train(performance_train_set)
         samples = performance_test_set.samples
-        predictions = self.predict([PredictionSample(pdf_data=x.pdf_data, tags_texts=x.tags_texts) for x in samples])
+        predictions = self.predict(
+            [PredictionSample(pdf_data=deepcopy(x.pdf_data), tags_texts=x.tags_texts) for x in samples]
+        )
 
         correct = [
             sample
             for sample, prediction in zip(performance_test_set.samples, predictions)
             if self.clean_text(sample.labeled_data.label_text) == self.clean_text(prediction)
         ]
-        self.remove_model()
         return 100 * len(correct) / len(performance_test_set.samples)
 
@@ -87,3 +94,6 @@ def log_performance_sample(self, extraction_data: ExtractionData, predictions: l
         message += f"document text : {document_text[:70].strip()}\n"
 
         config_logger.info(message)
+
+    def remove_method_data(self):
+        shutil.rmtree(self.get_path(), ignore_errors=True)
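Note: performance() now trains on the caller-supplied train set and scores exact-match accuracy, after whitespace normalisation, on the test set; deepcopy() keeps predict() from mutating the cached pdf_data between method evaluations. A worked example of the scoring arithmetic with made-up strings:

def clean_text(text: str) -> str:
    return " ".join(text.split())

predictions = ["Resolution 2020/5", "12  March 2020", "Geneva"]
labels = ["Resolution 2020/5", "12 March 2020", "Lagos"]

# Whitespace differences are forgiven; "Geneva" vs "Lagos" is a real miss
correct = [p for p, label in zip(predictions, labels) if clean_text(p) == clean_text(label)]
print(100 * len(correct) / len(labels))  # 66.66...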
FilterSegmentsMethod.py (3 changes: 3 additions & 0 deletions)
@@ -7,6 +7,9 @@
 
 
 class FilterSegmentsMethod(ABC):
+    def get_name(self):
+        return self.__class__.__name__
+
     @abstractmethod
     def filter_segments(self, pdf_data_segments: list[PdfDataSegment]) -> list[PdfDataSegment]:
         pass
src/extractors/pdf_to_multi_option_extractor/MultiLabelMethod.py (28 changes: 18 additions & 10 deletions)

@@ -16,14 +16,25 @@
 
 
 class MultiLabelMethod(ABC):
-    def __init__(self, extraction_identifier: ExtractionIdentifier, options: list[Option], multi_value: bool):
+    def __init__(
+        self, extraction_identifier: ExtractionIdentifier, options: list[Option], multi_value: bool, method_name: str = ""
+    ):
+        self.method_name = method_name
         self.extraction_identifier = extraction_identifier
         self.options = options
         self.multi_value = multi_value
-        self.base_path = extraction_identifier.get_path()
 
-        if not exists(self.base_path):
-            os.makedirs(self.base_path)
+    def get_name(self):
+        return self.__class__.__name__
+
+    def get_path(self):
+        if self.method_name:
+            path = join(self.extraction_identifier.get_path(), self.method_name)
+        else:
+            path = join(self.extraction_identifier.get_path(), self.get_name())
+
+        os.makedirs(path, exist_ok=True)
+        return path
 
     @abstractmethod
     def train(self, multi_option_data: ExtractionData):
@@ -33,25 +44,22 @@ def train(self, multi_option_data: ExtractionData):
         pass
 
     @abstractmethod
     def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]:
         pass
 
-    def get_name(self):
-        return self.__class__.__name__
-
     def save_json(self, file_name: str, data: any):
-        path = join(self.base_path, self.get_name(), file_name)
+        path = join(self.get_path(), file_name)
         if not exists(Path(path).parent):
             makedirs(Path(path).parent)
 
         with open(path, "w") as file:
             json.dump(data, file)
 
     def load_json(self, file_name: str):
-        path = join(self.base_path, self.get_name(), file_name)
+        path = join(self.get_path(), file_name)
 
         with open(path, "r") as file:
             return json.load(file)
 
     def remove_model(self):
-        shutil.rmtree(join(self.base_path, self.get_name()), ignore_errors=True)
+        shutil.rmtree(join(self.get_path()), ignore_errors=True)
 
     def get_texts_labels(self, multi_option_data: ExtractionData) -> (list[str], list[list[int]]):
         texts = list()
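Note: the new method_name parameter lets several configurations of one class keep separate model directories. A sketch of the path layout implied by get_path() (names hypothetical):

from os.path import join

# Sketch of get_path()'s branching; extraction_path stands in for extraction_identifier.get_path()
def sketch_get_path(extraction_path: str, class_name: str, method_name: str = "") -> str:
    return join(extraction_path, method_name if method_name else class_name)

print(sketch_get_path("models/tenant/date", "FastTextMultiLabel"))               # models/tenant/date/FastTextMultiLabel
print(sketch_get_path("models/tenant/date", "FastTextMultiLabel", "v2-tuned"))   # models/tenant/date/v2-tuned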