Skip to content

Commit

Permalink
Merge pull request #31 from alan-turing-institute/27-integrate-traine…
Browse files Browse the repository at this point in the history
…d-classifier

added functionality for loading pre-trained classification models
  • Loading branch information
eddableheath authored Dec 6, 2024
2 parents 0b79ed9 + 0f704b6 commit af48055
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 88 deletions.
14 changes: 14 additions & 0 deletions config/RTC_configs/roberta-mt5-trained.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
ocr:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"

translator:
specific_task: "translation_fr_to_en"
model: "ybanas/autotrain-fr-en-translate-51410121895"

classifier:
specific_task: "text-classification"
model: "../distilbert-topic-classifier/"
kwargs:
truncation : True
padding : True
13 changes: 0 additions & 13 deletions scripts/single_component_inference.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
"""
Steps:
- Load data
- Load pipeline/model
- Run inference on all test data
- Save outputs of specified model (on clean data)
- Calculate error of specified model (on clean data)
- Save results
- File structure:
- output/check_callibration/pipeline_name/run_[X]/[OUTPUT FILES HERE]
"""

import json
import os

Expand Down
2 changes: 1 addition & 1 deletion scripts/variational_RTC_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def main(rtc_pars):
rtc_variational_pipeline = RTCVariationalPipeline(rtc_pars, metadata_params)

# check dropout exists
rtc_variational_pipeline.check_dropout()
rtc_variational_pipeline.check_dropout(rtc_variational_pipeline.pipeline_map)

# perform variational inference
clean_output, var_output = rtc_variational_pipeline.variational_inference(test_row)
Expand Down
7 changes: 3 additions & 4 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
[
"clean_scores",
"mean_scores",
"hamming_accuracy",
"hamming_loss",
"mean_predicted_entropy",
],
)
Expand Down Expand Up @@ -126,12 +126,12 @@ def classification_results(
clean_scores: torch.Tensor = clean_output["classification"]["scores"]
preds = torch.round(mean_scores).tolist()
labels = self.multihot(test_row["labels"])
hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
hmng_loss = hamming_loss(y_pred=preds, y_true=labels)

return ClassificationResults(
mean_scores=mean_scores.detach().tolist(),
hamming_loss=hmng_loss,
clean_scores=clean_scores,
hamming_accuracy=hamming_acc,
mean_predicted_entropy=torch.mean(
var_output["classification"]["predicted_entropy"]
).item(),
Expand All @@ -152,5 +152,4 @@ def run_inference(
test_row=inp,
)
results.append({inp["celex_id"]: row_results_dict})
break
return results
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

import torch
import transformers
from transformers import pipeline

from arc_spice.variational_pipelines.RTC_variational_pipeline import (
Expand All @@ -10,6 +11,7 @@
CustomTranslationPipeline,
dropout_off,
dropout_on,
set_classifier,
set_dropout,
)

Expand Down Expand Up @@ -88,7 +90,7 @@ def __init__(
**kwargs,
):
self.set_device()
self.ocr = pipeline(
self.ocr: transformers.Pipeline = pipeline(
task=model_pars["ocr"]["specific_task"],
model=model_pars["ocr"]["model"],
device=self.device,
Expand Down Expand Up @@ -125,7 +127,7 @@ def __init__(
n_variational_runs=n_variational_runs,
translation_batch_size=translation_batch_size,
)
self.translator = pipeline(
self.translator: transformers.Pipeline = pipeline(
task=model_pars["translator"]["specific_task"],
model=model_pars["translator"]["model"],
max_length=512,
Expand All @@ -151,25 +153,24 @@ def __init__(
n_variational_runs=5,
**kwargs,
):
self.set_device()
if model_pars["classifier"]["specific_task"] == "zero-shot-classification":
zero_shot = True
else:
zero_shot = False
super().__init__(
step_name="classification",
input_key="target_text",
forward_function=self.classify_topic,
forward_function=self.classify_topic_zero_shot
if zero_shot
else self.classify_topic,
confidence_function=self.get_classification_confidence,
n_variational_runs=n_variational_runs,
**kwargs,
)
self.classifier = pipeline(
task=model_pars["classifier"]["specific_task"],
model=model_pars["classifier"]["model"],
multi_label=True,
device=self.device,
self.classifier: transformers.Pipeline = set_classifier(
model_pars["classifier"], self.device
)
self.model = self.classifier.model
# topic description labels for the classifier
self.topic_labels = [
class_names_dict["en"]
for class_names_dict in data_pars["class_descriptors"]
]
self.dataset_meta_data = data_pars
self._init_pipeline_map()
39 changes: 21 additions & 18 deletions src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RTCVariationalPipelineBase,
dropout_off,
dropout_on,
set_classifier,
set_dropout,
)

Expand Down Expand Up @@ -38,7 +39,12 @@ def __init__(
n_variational_runs=5,
translation_batch_size=16,
) -> None:
super().__init__(n_variational_runs, translation_batch_size)
# are we doing zero-shot-classification?
if model_pars["classifier"]["specific_task"] == "zero-shot-classification":
self.zero_shot = True
else:
self.zero_shot = False
super().__init__(self.zero_shot, n_variational_runs, translation_batch_size)
# defining the pipeline objects
self.ocr = pipeline(
task=model_pars["ocr"]["specific_task"],
Expand All @@ -52,18 +58,9 @@ def __init__(
pipeline_class=CustomTranslationPipeline,
device=self.device,
)
self.classifier = pipeline(
task=model_pars["classifier"]["specific_task"],
model=model_pars["classifier"]["model"],
multi_label=True,
device=self.device,
)
# topic description labels for the classifier
self.topic_labels = [
class_names_dict["en"]
for class_names_dict in data_pars["class_descriptors"]
]

self.classifier = set_classifier(model_pars["classifier"], self.device)
# topic meta_data for the classifier
self.dataset_meta_data = data_pars
self._init_semantic_density()
self._init_pipeline_map()

Expand All @@ -83,9 +80,15 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:
clean_output["translation"] = self.translate(
clean_output["recognition"]["outputs"]
)
clean_output["classification"] = self.classify_topic(
clean_output["translation"]["outputs"][0]
)
# we now need to pass the input to the correct forward method
if self.zero_shot:
clean_output["classification"] = self.classify_topic_zero_shot(
clean_output["translation"]["outputs"][0]
)
else:
clean_output["classification"] = self.classify_topic(
clean_output["translation"]["outputs"][0]
)
return clean_output

def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
Expand All @@ -110,15 +113,15 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
# for each model in pipeline
for model_key, pl in self.pipeline_map.items():
# turn on dropout for this model
set_dropout(model=pl.model, dropout_flag=True) # type: ignore[union-attr]
set_dropout(model=pl.model, dropout_flag=True) # type: ignore[union-attr,attr-defined]
torch.nn.functional.dropout = dropout_on
# do n runs of the inference
for run_idx in range(self.n_variational_runs):
var_output[model_key][run_idx] = self.func_map[model_key](
input_map[model_key]
)
# turn off dropout for this model
set_dropout(model=pl.model, dropout_flag=False) # type: ignore[union-attr]
set_dropout(model=pl.model, dropout_flag=False) # type: ignore[union-attr,attr-defined]
torch.nn.functional.dropout = dropout_off

# run metric helper functions
Expand Down
87 changes: 77 additions & 10 deletions src/arc_spice/variational_pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,52 @@
from typing import Any

import torch
import transformers
from torch.nn.functional import softmax
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
Pipeline,
TranslationPipeline,
pipeline,
)

logger = logging.Logger("RTC_variational_pipeline")

# Helper functions for the variational pipelines


def collate_scores(
scores: list[dict[str, float]], label_order
) -> dict[str, list | dict]:
# these need to be returned in original order
# return dict for to guarantee class predictions can be recovered
collated = {score["label"]: score["score"] for score in scores}
return {
"scores": [collated[label] for label in label_order],
"score_dict": collated,
}


def set_classifier(classifier_pars: dict, device: str) -> transformers.Pipeline:
    """
    Builds the classification pipeline described by the classifier config.
    Zero-shot classification pipelines are created with `multi_label=True`
    unless the config `kwargs` set it explicitly; every other task only
    receives the config `kwargs`.
    Args:
        classifier_pars: classifier config. `specific_task` and `model` are
            required, `kwargs` is an optional dict of extra pipeline arguments.
        device: device handed to the pipeline.
    Returns:
        The constructed pipeline.
    """
    # copy so the caller's config is never mutated; `or {}` also covers a
    # `kwargs:` key that is present in the config but left empty (None)
    pipeline_kwargs = dict(classifier_pars.get("kwargs") or {})
    if classifier_pars["specific_task"] == "zero-shot-classification":
        # a default rather than a hard-coded keyword: passing it explicitly
        # alongside **kwargs raises TypeError when the config also sets it
        pipeline_kwargs.setdefault("multi_label", True)
    return pipeline(
        task=classifier_pars["specific_task"],
        model=classifier_pars["model"],
        device=device,
        **pipeline_kwargs,
    )


def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
"""
Expand Down Expand Up @@ -104,7 +140,7 @@ def clean_inference(self, x):
def variational_inference(self, x):
pass

def __init__(self, n_variational_runs=5, translation_batch_size=8):
def __init__(self, zero_shot: bool, n_variational_runs=5, translation_batch_size=8):
# device for inference
self.set_device()
debug_msg_device = f"Loading pipeline on device: {self.device}"
Expand All @@ -113,7 +149,9 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
self.func_map = {
"recognition": self.recognise,
"translation": self.translate,
"classification": self.classify_topic,
"classification": self.classify_topic_zero_shot
if zero_shot
else self.classify_topic,
}
# the naive outputs of the pipeline stages calculated in self.clean_inference
self.naive_outputs = {
Expand All @@ -139,8 +177,10 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
self.classifier = None

# map pipeline names to their pipeline counterparts

self.topic_labels = None # This should be defined in subclass if needed
# to replace class descriptors, we now want class descriptors and the labels
self.dataset_meta_data: dict = {
None: None
} # This should be defined in subclass if needed

def _init_pipeline_map(self):
"""
Expand Down Expand Up @@ -193,15 +233,16 @@ def split_translate_inputs(text: str, split_key: str) -> list[str]:
split_rows = split_rows[:-1]
return [split + split_key for split in split_rows]

def check_dropout(self):
@staticmethod
def check_dropout(pipeline_map: transformers.Pipeline):
"""
Checks the existence of dropout layers in the models of the pipeline.
Raises:
ValueError: Raised when no dropout layers are found.
"""
logger.debug("\n\n------------------ Testing Dropout --------------------")
for model_key, pl in self.pipeline_map.items():
for model_key, pl in pipeline_map.items():
# only test models that exist
if pl is None:
pipeline_none_msg_key = (
Expand Down Expand Up @@ -288,15 +329,41 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
# {full translation, sentence translations, logits, semantic embeddings}
return outputs

def classify_topic(self, text: str) -> dict[str, str]:
def classify_topic(self, text: str) -> dict[str, list[float] | dict]:
"""
Runs the classification model
Returns:
Dictionary of classification outputs, namely the output scores.
Dictionary of classification outputs, namely the output scores and
label:score dictionary.
"""
forward = self.classifier(text, top_k=None) # type: ignore[misc]
return collate_scores(forward, self.dataset_meta_data["class_labels"]) # type: ignore[index]

def classify_topic_zero_shot(self, text: str) -> dict[str, list[float] | dict]:
"""
Runs the zero-shot classification model
Returns:
Dictionary of classification outputs, namely the output scores and
label:score dictionary.
"""
forward = self.classifier(text, self.topic_labels) # type: ignore[misc]
return {"scores": forward["scores"]}
labels = [
descriptors["en"]
for descriptors in self.dataset_meta_data["class_descriptors"] # type: ignore[index]
]
forward = self.classifier( # type: ignore[misc]
text, labels
)
return collate_scores(
[
{"label": label, "score": score}
for label, score in zip(
forward["labels"], forward["scores"], strict=True
)
],
label_order=labels,
)

def stack_translator_sentence_metrics(
self, all_sentence_metrics: list[dict[str, Any]]
Expand Down
Loading

0 comments on commit af48055

Please sign in to comment.