Skip to content

Commit

Permalink
Merge pull request #86 from huridocs/t5-not-defaulting
Browse files Browse the repository at this point in the history
Remove T5 as default if the other methods are not good
  • Loading branch information
gabriel-piles authored Sep 9, 2024
2 parents 12a3ac2 + f7cf8ef commit 51a30f6
Showing 1 changed file with 32 additions and 53 deletions.
85 changes: 32 additions & 53 deletions src/extractors/text_to_text_extractor/TextToTextExtractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import os
from os.path import exists, join
from pathlib import Path

from config import config_logger
from data.ExtractionData import ExtractionData
from data.ExtractionIdentifier import ExtractionIdentifier
from data.PredictionSample import PredictionSample
from data.Suggestion import Suggestion
from extractors.ExtractorBase import ExtractorBase
Expand All @@ -29,29 +32,45 @@ class TextToTextExtractor(ExtractorBase):
MT5TrueCaseEnglishSpanishMethod,
]

def __init__(self, extraction_identifier: ExtractionIdentifier):
    """Initialise the extractor and prepare the folder of the method-name file.

    ``method_name_path`` points to
    ``<extractor path>/text_to_text_extractor/method_name.json``; its parent
    folder is created here if it does not exist yet.
    """
    super().__init__(extraction_identifier)
    storage_folder = Path(self.extraction_identifier.get_path(), "text_to_text_extractor")
    self.method_name_path = storage_folder / "method_name.json"
    os.makedirs(storage_folder, exist_ok=True)

def get_suggestions(self, predictions_samples: list[PredictionSample]) -> list[Suggestion]:
    """Predict a text for every sample and wrap each result in a ``Suggestion``.

    The method returned by ``get_predictions_method`` does the prediction;
    a log line naming that method is sent before predicting.
    """
    chosen_method = self.get_predictions_method()
    send_logs(
        self.extraction_identifier,
        f"Predicting {len(predictions_samples)} documents with {chosen_method.get_name()}",
    )

    predicted_texts = chosen_method.predict(predictions_samples)

    # Predictions are paired with their samples by position.
    return [
        Suggestion.from_prediction_text(self.extraction_identifier, sample.entity_name, text)
        for text, sample in zip(predicted_texts, predictions_samples)
    ]

def get_predictions_method(self) -> TextToTextMethod:
    """Return the method instance whose name is stored in ``method_name_path``.

    Falls back to the first entry of ``METHODS`` when the file does not
    exist or when the stored name matches none of the known methods.
    """
    if not self.method_name_path.exists():
        return self.METHODS[0](self.extraction_identifier)

    method_name = json.loads(self.method_name_path.read_text())
    for method in self.METHODS:
        # A method's name is only available from an instance.
        method_instance = method(self.extraction_identifier)
        if method_instance.get_name() == method_name:
            return method_instance

    return self.METHODS[0](self.extraction_identifier)

def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]:
    """Pick the best method for the data, record its name and train it.

    With fewer than two samples nothing is selected or trained. Every path
    returns ``(True, "")``.
    """
    success = (True, "")

    enough_samples = len(extraction_data.samples) >= 2
    if not enough_samples:
        config_logger.info("\nBest method SameInputOutputMethod because no samples")
        return success

    chosen_method = self.get_best_method(extraction_data)
    # The name is written before training, so it is on disk even if training fails.
    chosen_name = chosen_method.get_name()
    self.method_name_path.write_text(json.dumps(chosen_name))
    chosen_method.train(extraction_data)
    return success

Expand All @@ -63,7 +82,7 @@ def remove_models(self):
def get_best_method(self, extraction_data: ExtractionData):
best_performance = 0
best_method_instance = self.METHODS[0](self.extraction_identifier)
for method in self.METHODS[:-1]:
for method in self.METHODS:
method_instance = method(self.extraction_identifier)
send_logs(self.extraction_identifier, f"Checking {method_instance.get_name()}")
performance = method_instance.performance(extraction_data)
Expand All @@ -76,48 +95,8 @@ def get_best_method(self, extraction_data: ExtractionData):
best_performance = performance
best_method_instance = method_instance

return self.decide_best_method_or_t5(best_performance, best_method_instance, extraction_data)

def decide_best_method_or_t5(
    self,
    best_performance: float,
    best_method_instance: TextToTextMethod,
    extraction_data: ExtractionData,
):
    """Choose between the best method found so far and the MT5 method.

    * above 85: the best method is kept and MT5 is never instantiated;
    * below 60: MT5 is returned without being evaluated;
    * otherwise MT5 is evaluated and returned only when its performance is
      strictly higher than ``best_performance``.
    """

    def keep_best():
        # Log and return the candidate that was passed in.
        send_logs(self.extraction_identifier, f"Best method {best_method_instance.get_name()} with {best_performance}%")
        return best_method_instance

    if best_performance > 85:
        return keep_best()

    fallback = MT5TrueCaseEnglishSpanishMethod(self.extraction_identifier)

    if best_performance < 60:
        send_logs(self.extraction_identifier, f"Best method {fallback.get_name()} because the others were bad")
        return fallback

    # ``performance`` returns a pair here; only its first element is used.
    fallback_performance, _ = fallback.performance(extraction_data)
    send_logs(self.extraction_identifier, f"Performance {fallback.get_name()} with {fallback_performance}%")

    if fallback_performance > best_performance:
        send_logs(self.extraction_identifier, f"Best method {fallback.get_name()} with {fallback_performance}%")
        return fallback

    return keep_best()

def suggestions_from_predictions(
    self, method_instance: TextToTextMethod, predictions_samples: list[PredictionSample]
) -> list[Suggestion]:
    """Predict with ``method_instance`` and wrap each result in a ``Suggestion``.

    ``method_instance`` is an instantiated method (``predict`` is called on
    it directly), not a method class. Predictions are paired with their
    samples by position.
    """
    predicted_texts = method_instance.predict(predictions_samples)

    return [
        Suggestion.from_prediction_text(self.extraction_identifier, sample.entity_name, text)
        for text, sample in zip(predicted_texts, predictions_samples)
    ]

def can_be_used(self, extraction_data: ExtractionData) -> bool:
for sample in extraction_data.samples:
if sample.tags_texts:
Expand Down

0 comments on commit 51a30f6

Please sign in to comment.