Merge pull request #35 from Princeton-CDH/feature/recognition-data
Update htr2hpc-train to support recognition training
rlskoeser authored Dec 13, 2024
2 parents 7897950 + d97af10 commit 6dece38
Showing 5 changed files with 243 additions and 79 deletions.
24 changes: 13 additions & 11 deletions src/htr2hpc/api_client.py
@@ -48,14 +48,12 @@ def __post_init__(self):
def next_page(self):
# if there is a next page of results
if self.next:
# parse the url to get the page number for the next page
next_params = parse_qs(urlparse(self.next).query)
next_page_num = next_params["page"][0]
# convert result type (e.g. "list_model") into api method (model_list)
# (this may be brittle...)
single_rtype = self.result_type.replace("list", "").strip("_")
list_method = getattr(self.api, f"{single_rtype}_list")
return list_method(page=next_page_num)
# request the next page and return it as a
# ResultsList with the same result type and api client as this one
resp = self.api._make_request(self.next)
return ResultsList(
api=self.api, result_type=self.result_type, **resp.json()
)


@dataclass
@@ -100,6 +98,7 @@ def duration(self) -> datetime.timedelta | None:
class Workflow:
convert: Optional[str] = None
segment: Optional[str] = None
transcribe: Optional[str] = None
# workflow status values are only present once a workflow has run,
# so define a dataclass and make the fields optional
# to handle missing values
@@ -123,8 +122,6 @@ def to_namedtuple(name: str, data: Any):
logger.debug(f"Creating namedtuple with name {name}")
nt_class = namedtuple(name, data)
RESULTCLASS_REGISTRY[name] = nt_class
else:
logger.debug(f"Using existing result class for {name}: {nt_class}")
# once we have the class, initialize an instance with the given data
return nt_class(
# convert any nested objects to namedtuple classes
@@ -171,7 +168,12 @@ def _make_request(
Make a GET request with the configured session. Takes a url
relative to :attr:`api_root` and optional dictionary of parameters for the request.
"""
rqst_url = f"{self.api_root}/{url}"
# support absolute urls for retrieving paged results,
# but only urls within the configured eScriptorium instance
if url.startswith(self.api_root):
rqst_url = url
else:
rqst_url = f"{self.api_root}/{url}"
rqst_opts = {}
if params:
rqst_opts["params"] = params.copy()
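As an aside, a minimal sketch of how the reworked pagination can be consumed from calling code, assuming an eScriptoriumAPIClient instance named api; the variable names and the choice of model_list here are illustrative, not part of this change:

# hypothetical usage: collect results across every page
all_models = []
page = api.model_list()  # first page of results
while True:
    all_models.extend(page.results)
    if page.next:
        # next_page() now re-requests the absolute "next" url via _make_request
        page = page.next_page()
    else:
        break
print(f"retrieved {len(all_models)} models")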
144 changes: 113 additions & 31 deletions src/htr2hpc/train/data.py
@@ -1,11 +1,11 @@
import datetime
import logging
import pathlib
import subprocess
from typing import Optional
from collections import defaultdict
from tqdm import tqdm

from kraken.containers import BaselineLine, Region, Segmentation
from kraken.lib.arrow_dataset import build_binary_dataset
from tqdm import tqdm

# skip import, syntax error in current kraken
# from kraken.lib.arrow_dataset import build_binary_dataset
@@ -18,14 +18,40 @@
# get a document part from eS api and convert into kraken objects


def get_transcription_lines(api, document_id, part_id, transcription_id):
# The API could have multiple pages of transcription lines;
# loop until all pages of results are consumed
text_lines = {}
# get the first page of results
transcription_lines = api.document_part_transcription_list(
document_id, part_id, transcription_id
)
while True:
# gather lines of text from the current page
for text_line in transcription_lines.results:
# Each transcription line includes a line id,
# transcription id, and text content.
# Add to dict so we can lookup content by line id
text_lines[text_line.line] = text_line.content
# if there is another page of results, get them
if transcription_lines.next:
transcription_lines = transcription_lines.next_page()
# otherwise, we've hit the end; stop looping
else:
break

return text_lines


def get_segmentation_data(
api, document_details, part_id, image_dir
api, document_details, part_id, image_dir, transcription_id=None
) -> tuple[Segmentation, tuple]:
"""Get a single document part from the eScriptorium API and generate
a kraken segmentation object.
Returns a tuple of the segmentation object and the part details from the API,
which includes image size needed for serialization.
which includes image size needed for serialization. Includes transcription
text when a `transcription_id` is specified.
"""

# document details includes id (pk) and valid line and block types
@@ -37,15 +63,9 @@ def get_segmentation_data(
part = api.document_part_details(document_id, part_id)

# adapted from escriptorium.app.core.tasks.make_segmentation_training_data
# (additional logic in make_recognition_segmentation )

# TODO: for recognition task, we need to get transcription lines
# from api.document_part_transcription_list
# each one includes a line id, transcription id, and content
# ... need to match up based on line pk

# NOTE: eS celery task training prep only includes regions
# for segmentation, not recognition
# and make_recognition_segmentation
# NOTE: regions are not strictly needed for recognition training,
# but does not seem to hurt to include them

# gather regions in a dictionary keyed on type name for
# the segmentation object (name -> list of regions)
@@ -63,24 +83,34 @@
)
)

# gather base lines
# recognition training requires transcription text content
# if a transcription id is specified, retrieve transcription content
if transcription_id:
text_lines = get_transcription_lines(
api, document_id, part_id, transcription_id
)
else:
text_lines = {}

baselines = [
BaselineLine(
id=line.external_id,
baseline=line.baseline,
boundary=line.mask,
# eScriptorium api returns a single region pk
# eScriptorium api returns a single region pk;
# kraken takes a list of string ids
regions=[region_pk_to_id[line.region]],
# NOTE: eS celery task training prep only includes text
# when generating training data for recognition, not segmentation
# this mirrors the behavior from eS code for export:
# orphan lines have no region
regions=[region_pk_to_id[line.region]] if line.region else None,
# mark as default if type is not in the public list
# db includes more types but they are not marked as public
tags={"type": line_types.get(line.typology, "default")},
# get text transcription content for this line, if available
# (only possible when transcription id is specified)
text=text_lines.get(line.pk),
)
for line in part.lines
]

logger.info(f"Document {document_id} part {part_id}: {len(baselines)} baselines")

logger.info(
@@ -127,9 +157,15 @@ def serialize_segmentation(segmentation: Segmentation, part):
def compile_data(segmentations, output_dir):
"""Compile a list of kraken segmentation objects into a binary file for
recognition training."""
output_file = output_dir / "dataset.arrow"
# NOTE: kraken errors out at compile time if the image path is not valid.
# Image paths on created segments should be relative to the current
# working directory and must resolve, so that the kraken binary compile
# function can load the image files by path.
output_file = output_dir / "train.arrow"
build_binary_dataset(
files=segmentations, format_type=None, output_file=str(output_file)
files=segmentations,
format_type=None, # None = kraken Segmentation objects
output_file=str(output_file),
)
return output_file

@@ -147,26 +183,51 @@ def get_model(api, model_id, training_type, output_dir):
return api.download_file(model_info.file, output_dir)


def get_training_data(api, output_dir, document_id, part_ids=None):
def get_document_parts(api, document_id):
part_ids = []
# get first page of results
document_parts = api.document_parts_list(document_id)
while True:
# retrieve part ids from the current page and check for more
part_ids.extend([part.pk for part in document_parts.results])
# if there is another page of results, get it
if document_parts.next:
document_parts = document_parts.next_page()
# otherwise, stop looping
else:
break
return part_ids


def get_training_data(
api, output_dir, document_id, part_ids=None, transcription_id=None
):
# if part ids are not specified, get all parts
if part_ids is None:
doc_parts = api.document_parts_list(document_id)
part_ids = [part.pk for part in doc_parts.results]
part_ids = get_document_parts(api, document_id)

# document details includes line and block types
document_details = api.document_details(document_id)

# get segmentation data for each part of the document that is requested
segmentation_data = [
get_segmentation_data(api, document_details, part_id, output_dir)
get_segmentation_data(
api, document_details, part_id, output_dir, transcription_id
)
for part_id in part_ids
]
# if we're generating alto-xml (i.e., segmentation training data),
# serialize each of the parts we downloaded
[serialize_segmentation(seg, part) for (seg, part) in segmentation_data]

# NOTE: binary compiled data is only supported for train and not segtrain
# compiled_data = compile_data(segmentations, output_dir)
# if transcription id is specified, compile as binary dataset
# for recognition training
if transcription_id:
segmentations = [seg for seg, _ in segmentation_data]
compile_data(segmentations, output_dir)

# if no transcription id is specified, then serialize as
# alto-xml for segmentation training
else:
# serialize each of the parts that were downloaded
[serialize_segmentation(seg, part) for (seg, part) in segmentation_data]
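
For orientation, a hedged sketch of driving get_training_data for both training modes; the document, part, and transcription ids below are placeholders, and only the function names and keyword arguments come from this module:

import pathlib

output_dir = pathlib.Path("training_data")
output_dir.mkdir(exist_ok=True)

# segmentation training: no transcription id, so each downloaded part
# is serialized to ALTO-XML via serialize_segmentation
get_training_data(api, output_dir, document_id=123)

# recognition training: a transcription id pulls line text with
# get_transcription_lines and compiles a binary train.arrow dataset
get_training_data(
    api, output_dir, document_id=123, part_ids=[4, 5], transcription_id=6
)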


def get_best_model(model_dir: pathlib.Path) -> pathlib.Path | None:
Expand Down Expand Up @@ -201,5 +262,26 @@ def upload_models(
return uploaded


def upload_best_model(
api, model_dir: pathlib.Path, model_type: str
) -> Optional[pathlib.Path]:
"""Upload the best model in the specified model directory to eScriptorium
with the specified job type (Segment/Recognize). Returns pathlib.Path
for best model if found and successfully uploaded; otherwise returns None."""
best_model = get_best_model(model_dir)
if best_model:
created = api.model_create(
best_model,
job=model_type,
# strip off _best from file for model name in eScriptorium
model_name=best_model.stem.replace("_best", ""),
)
if created:
return best_model
# TODO: return something different here if model create failed?

return None


# use api.update_model with model id and pathlib.Path to model file
# to update existing model record with new file
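
Finally, a short sketch of pushing the trained result back to eScriptorium, assuming training wrote its checkpoints into a local models directory; the directory name and the "Recognize" job type are assumptions based on the docstring above:

model_dir = pathlib.Path("models")

# upload_best_model finds the *_best checkpoint via get_best_model and
# creates a new model record in eScriptorium via model_create
uploaded = upload_best_model(api, model_dir, model_type="Recognize")
if uploaded:
    print(f"uploaded {uploaded.name} as a recognition model")
else:
    print("no best model found, or the upload failed")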
21 changes: 12 additions & 9 deletions src/htr2hpc/train/rm_models.py
@@ -16,6 +16,7 @@

import os
import argparse
import sys

from tqdm import tqdm

@@ -50,18 +51,20 @@ def main():
args = parser.parse_args()

api = eScriptoriumAPIClient(args.base_url, api_token=api_token)
model_list = api.model_list()
rm_models = []
# filter by match on name
while model_list.next:
# handle one or more pages of results from model list
while True:
model_list = api.model_list()
rm_models.extend(
[
model
for model in model_list.results
if model.name.startswith(args.model_name)
]
[m for m in model_list.results if m.name.startswith(args.model_name)]
)
model_list = model_list.next_page()

# if there is another page of results, get them
if model_list.next:
model_list = model_list.next_page()
# otherwise, we've hit the end; stop looping
else:
break

if rm_models:
print(f"Removing {len(rm_models)} models")