Merge pull request #34 from alan-turing-institute/33-fix-bugs-for-inference-on-baskerville

33 fix bugs for inference on baskerville

eddableheath authored Dec 18, 2024
2 parents 3f6c10d + 4b950ff commit 91b574e
Showing 24 changed files with 174 additions and 91 deletions.
2 changes: 1 addition & 1 deletion config/RTC_configs/roberta-mt5-trained.yaml
@@ -1,6 +1,6 @@
 ocr:
   specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"

 translator:
   specific_task: "translation_fr_to_en"
3 changes: 1 addition & 2 deletions config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,6 +1,5 @@
 ocr:
   specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"

 translator:
   specific_task: "translation_fr_to_en"
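
Both RTC configs swap the OCR checkpoint from microsoft/trocr-base-handwritten to the smaller microsoft/trocr-small-printed. As a rough sketch of how a config of this shape drives a Hugging Face pipeline (the loading code below is assumed for illustration, not taken from the repository):

    import yaml
    from transformers import pipeline

    with open("config/RTC_configs/roberta-mt5-zero-shot.yaml") as f:
        cfg = yaml.safe_load(f)

    # an "image-to-text" pipeline accepts PIL images (or paths/URLs)
    # and returns generated text
    ocr = pipeline(cfg["ocr"]["specific_task"], model=cfg["ocr"]["model"])
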
2 changes: 2 additions & 0 deletions config/data_configs/l1_fr_to_en.yaml
@@ -7,3 +7,5 @@ lang_pair:
   target: "en"

 drop_length: 1000
+
+load_ocr_data: True
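
The new load_ocr_data flag travels to the loader by keyword expansion, so the data config alone decides whether synthetic OCR images are built. A minimal sketch of the call site, mirroring scripts/single_component_inference.py below (imports from the repo omitted):

    data_config = open_yaml_path("config/data_configs/l1_fr_to_en.yaml")
    # load_ocr_data: True flows through as a keyword argument
    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
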
15 changes: 15 additions & 0 deletions config/experiment/finalised_pipeline_zs.yaml
@@ -0,0 +1,15 @@
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
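
For context, the bask block carries Baskerville (SLURM) job settings; the walltime string uses SLURM's D-H:M:S format, so '0-24:0:0' requests 24 hours. The three seeds suggest one repeat run per seed. A hedged sketch of how a launcher might expand this config (submit_run is a stand-in, not a function from the repo):

    import yaml

    def submit_run(**job):  # stand-in for a real SLURM submission helper
        print(job)

    with open("config/experiment/finalised_pipeline_zs.yaml") as f:
        exp = yaml.safe_load(f)

    for seed in exp["seed"]:  # 42, 43, 44 -> one job per seed
        submit_run(
            data_config=exp["data_config"],
            pipeline_config=exp["pipeline_config"],
            seed=seed,
            **exp["bask"],
        )
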
8 changes: 6 additions & 2 deletions scripts/single_component_inference.py
@@ -46,15 +46,19 @@ def main(
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
+
+    if model_key != "ocr":
+        data_config["load_ocr_data"] = False
+
     data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "translator":
         rtc_single_component_pipeline = TranslationVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "classifier":
         rtc_single_component_pipeline = ClassificationVariationalPipeline(
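
With load_ocr_data exposed on the data config, single-component runs for the translator and classifier now switch off synthetic OCR image generation, and the recognition and translation pipelines stop receiving data_pars (the classifier branch is truncated in this view). The guard in isolation, as a trivial runnable sketch:

    data_config = {"load_ocr_data": True}     # as parsed from the data YAML
    model_key = "translator"
    if model_key != "ocr":
        data_config["load_ocr_data"] = False  # only OCR pays for image generation
    assert data_config["load_ocr_data"] is False
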
20 changes: 13 additions & 7 deletions src/arc_spice/data/multieurlex_utils.py
@@ -67,14 +67,17 @@ def extract_articles(

 def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
     text_split = text.split()
-    text_split = [text for text in text_split if text not in ("", " ", None)]
+    text_split = [text for text in text_split if text not in ("", " ")]
     generator = GeneratorFromStrings(text_split, count=len(text_split))
     return list(generator)


-def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
-    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
-    return {"ocr_images": images, "ocr_targets": targets}
+def make_ocr_data(item: LazyRow) -> dict:
+    try:
+        images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    except ValueError:
+        return {"ocr_data": {"ocr_images": None, "ocr_targets": None}}
+    return {"ocr_data": {"ocr_images": images, "ocr_targets": targets}}


 class TranslationPreProcesser:

@@ -229,11 +232,14 @@ def load_multieurlex_for_pipeline(
                 make_ocr_data,
                 features=datasets.Features(
                     {
-                        "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
-                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                        "ocr_data": {
+                            "ocr_images": datasets.Sequence(
+                                datasets.Image(decode=True)
+                            ),
+                            "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                        },
                         **feats,
                     }
                 ),
             )

     return dataset_dict, meta_data
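
Two fixes here: str.split() never yields empty strings or None, so the None check in the filter was dead code; and a document whose source text contains no words hands zip() nothing to unpack, raising an uncaught ValueError during the dataset map. The new except clause absorbs exactly that case, as this standalone sketch shows:

    pairs = []                       # what _make_ocr_data yields for empty text
    try:
        images, targets = zip(*pairs, strict=True)
    except ValueError:               # "not enough values to unpack"
        images, targets = None, None
    print(images, targets)           # None None
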
64 changes: 52 additions & 12 deletions src/arc_spice/eval/inference_utils.py
@@ -17,7 +17,15 @@
     RTCVariationalPipeline,
 )

-RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
+RecognitionResults = namedtuple(
+    "RecognitionResults",
+    [
+        "mean_entropy",
+        "character_error_rate",
+        "full_output",
+        "max_scores",
+    ],
+)

 TranslationResults = namedtuple(
     "TranslationResults",

@@ -71,14 +79,19 @@ def get_results(

     def recognition_results(
         self,
-        clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
-        var_output: dict[str, dict],
+        clean_output: dict,
+        var_output: dict,
         **kwargs,
     ):
         # ### RECOGNITION ###
-        charerror = ocr_error(clean_output)
+        charerror = ocr_error(clean_output["recognition"])
         confidence = var_output["recognition"]["mean_entropy"]
-        return RecognitionResults(confidence=confidence, accuracy=charerror)
+        return RecognitionResults(
+            mean_entropy=confidence,
+            character_error_rate=charerror,
+            max_scores=clean_output["recognition"]["outputs"]["max_scores"],
+            full_output=clean_output["recognition"]["full_output"],
+        )

     def translation_results(
         self,

@@ -150,13 +163,40 @@ def run_inference(
     pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
     results_getter: ResultsGetter,
 ):
+    type_errors = []
+    oom_errors = []
     results = []
     for _, inp in enumerate(tqdm(dataloader)):
-        clean_out, var_out = pipeline.variational_inference(inp)
-        row_results_dict = results_getter.get_results(
-            clean_output=clean_out,
-            var_output=var_out,
-            test_row=inp,
-        )
-        results.append({inp["celex_id"]: row_results_dict})
+        # TEMPORARY FIX
+        try:
+            clean_out, var_out = pipeline.variational_inference(inp)
+            row_results_dict = results_getter.get_results(
+                clean_output=clean_out,
+                var_output=var_out,
+                test_row=inp,
+            )
+            results.append({inp["celex_id"]: row_results_dict})
+        # TEMPORARY FIX ->
+        except TypeError:
+            type_errors.append(inp["celex_id"])
+            continue
+
+        except torch.cuda.OutOfMemoryError:
+            oom_errors.append(inp["celex_id"])
+            continue
+
+        except torch.OutOfMemoryError:
+            oom_errors.append(inp["celex_id"])
+            continue
+
+    print("Skipped following CELEX IDs due to TypeError:")
+    print(
+        '"TypeError: Incorrect format used for image. Should be an url linking to'
+        ' an image, a base64 string, a local path, or a PIL image."'
+    )
+    print(type_errors)
+
+    print("Skipped following CELEX IDs due to torch.cuda.OutOfMemoryError:")
+    print(oom_errors)
+
     return results
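
One caveat on the duplicated OOM handlers: newer PyTorch releases expose torch.OutOfMemoryError as an alias of torch.cuda.OutOfMemoryError, while older releases define only the cuda-scoped name. Python evaluates except clauses lazily, so the second handler is either unreachable (both names are the same class) or, on builds without the alias, would itself raise AttributeError if consulted. A version-safe variant, as a sketch rather than the repository's code:

    import torch

    # Build the catch tuple once; fall back to the cuda-scoped class alone on
    # PyTorch builds that predate torch.OutOfMemoryError.
    OOM_ERRORS = tuple({
        torch.cuda.OutOfMemoryError,
        getattr(torch, "OutOfMemoryError", torch.cuda.OutOfMemoryError),
    })

    def forward_or_skip(fn, *args):
        try:
            return fn(*args)
        except OOM_ERRORS:
            return None  # caller logs the row's ID and moves on
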
9 changes: 5 additions & 4 deletions src/arc_spice/eval/ocr_error.py
@@ -6,6 +6,8 @@

 from torchmetrics.text import CharErrorRate

+cer = CharErrorRate()
+

 def ocr_error(ocr_output: dict[Any, Any]) -> float:
     """

@@ -30,7 +32,6 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float:
     Returns:
         Character error rate across entire output of OCR (float)
     """
-    preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
-    targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
-    cer = CharErrorRate()
-    return cer(preds, targs).item()
+    preds = [itm["generated_text"].lower() for itm in ocr_output["outputs"]]
+    targs = [itm["target"].lower() for itm in ocr_output["outputs"]]
+    return cer(preds, targs).detach().item()
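
Hoisting cer to module level avoids rebuilding the metric on every call. Note that calling a torchmetrics metric both returns the batch-level value and accumulates internal state, so a long-lived instance grows its aggregate; a cer.reset() would clear it if the global value were ever needed. A quick check of the metric's behaviour (values computed by hand):

    from torchmetrics.text import CharErrorRate

    cer = CharErrorRate()
    # CER = edit distance / reference length: "kitten" -> "sitting" needs
    # 3 edits against a 7-character reference, so 3/7 ≈ 0.4286
    print(cer(["kitten"], ["sitting"]).detach().item())
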
src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -89,26 +89,23 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         ocr_batch_size=64,
-        **kwargs,
     ):
         self.set_device()
+        super().__init__(
+            step_name="recognition",
+            input_key="ocr_data",
+            forward_function=self.recognise,
+            confidence_function=self.get_ocr_confidence,
+            n_variational_runs=n_variational_runs,
+        )
         self.ocr: transformers.Pipeline = pipeline(
             model=model_pars["ocr"]["model"],
             device=self.device,
             pipeline_class=CustomOCRPipeline,
             max_new_tokens=20,
             batch_size=ocr_batch_size,
-            **kwargs,
         )
         self.model = self.ocr.model
-        super().__init__(
-            step_name="recognition",
-            input_key="ocr_data",
-            forward_function=self.recognise,
-            confidence_function=self.get_ocr_confidence,
-            n_variational_runs=n_variational_runs,
-            **kwargs,
-        )
         self._init_pipeline_map()


@@ -118,7 +115,6 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         translation_batch_size=4,
-        **kwargs,
     ):
         self.set_device()
         # need to initialise the NLI models in this case
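
The reorder runs the base-class __init__ before the heavyweight Hugging Face pipeline is built, and both constructors stop accepting **kwargs; this pairs with the scripts/single_component_inference.py change above, which stops passing data_pars=meta_data to these constructors, so stray keywords now fail loudly instead of being silently forwarded. A toy illustration of the ordering pattern (class names are illustrative, not the repo's):

    class Base:
        def __init__(self, step_name: str):
            self.step_name = step_name

    class Recognition(Base):
        def __init__(self, n_variational_runs: int = 5):
            self.device = "cpu"                         # set_device() analogue
            super().__init__(step_name="recognition")   # base setup first
            self.model = object()                       # heavy members built last

    print(Recognition().step_name)                      # recognition
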
13 changes: 7 additions & 6 deletions src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -79,19 +79,19 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:

         # run the functions
         # UNTIL THE OCR DATA IS AVAILABLE
-        clean_output["recognition"] = self.recognise(x)
+        clean_output["recognition"] = self.recognise(x["ocr_data"])

         clean_output["translation"] = self.translate(
-            clean_output["recognition"]["outputs"]
+            clean_output["recognition"]["full_output"]
         )
         # we now need to pass the input correct to the correct forward method
         if self.zero_shot:
             clean_output["classification"] = self.classify_topic_zero_shot(
-                clean_output["translation"]["outputs"][0]
+                clean_output["translation"]["full_output"]
             )
         else:
             clean_output["classification"] = self.classify_topic(
-                clean_output["translation"]["outputs"][0]
+                clean_output["translation"]["full_output"]
             )
         return clean_output

@@ -109,8 +109,8 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
         }
         # define the input map for brevity in forward pass
         input_map = {
-            "recognition": x,
-            "translation": clean_output["recognition"]["outputs"],
+            "recognition": x["ocr_data"],
+            "translation": clean_output["recognition"]["full_output"],
             "classification": clean_output["translation"]["full_output"],
         }

@@ -130,6 +130,7 @@

         # run metric helper functions
         var_output = self.stack_variational_outputs(var_output)
+        var_output = self.get_ocr_confidence(var_output)
         var_output = self.translation_semantic_density(
             clean_output=clean_output, var_output=var_output
         )
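
Taken together, these changes standardise the inter-step interface: each step reads its input either from the nested "ocr_data" field of the dataset row or from the previous step's "full_output", and OCR confidence is now computed alongside the other variational metrics. A sketch of the row shape the pipeline now expects (field names come from the diffs above; the values are placeholders):

    x = {
        "celex_id": "...",                  # document identifier
        "source_text": "...",
        "ocr_data": {"ocr_images": [...], "ocr_targets": [...]},
    }
    ocr_input = x["ocr_data"]               # what recognise() now receives
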