Skip to content

Commit

Permalink
Merge pull request #33 from Princeton-CDH/feature/rm-models
Browse files Browse the repository at this point in the history
Utility script to batch remove models from eScriptorium
  • Loading branch information
rlskoeser authored Dec 5, 2024
2 parents 23bbbd0 + 0fa7bc1 commit 7897950
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 14 deletions.
64 changes: 50 additions & 14 deletions src/htr2hpc/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from time import sleep
import pathlib
from typing import Optional, Any
from urllib.parse import urlparse, parse_qs
import json

import humanize
Expand All @@ -29,6 +30,7 @@ class NotAllowed(Exception):
class ResultsList:
"""API list response."""

api: "eScriptoriumAPIClient"
# all API list methods have the same structure,
# so use a dataclass but specify the result type

Expand All @@ -43,6 +45,18 @@ def __post_init__(self):
# specified result type name
self.results = [to_namedtuple(self.result_type, d) for d in self.results]

def next_page(self):
    """Fetch the next page of results, if there is one.

    The page number is read from the ``page`` query parameter of the
    ``next`` url, and the request is made by calling the API client
    method named ``<type>_list``, where ``<type>`` is derived from
    ``result_type`` (e.g. ``list_model`` becomes ``model_list``).

    Returns whatever that list method returns, or ``None`` when ``next``
    is not set (i.e., this is the last page).

    Raises ``KeyError`` if the next url has no ``page`` parameter, and
    ``AttributeError`` if the API client has no matching list method.
    """
    # no next url: nothing more to fetch
    if not self.next:
        return None

    # parse the url to get the page number for the next page
    # (parse_qs returns a list of values per key; use the first)
    next_params = parse_qs(urlparse(self.next).query)
    next_page_num = next_params["page"][0]

    # convert result type (e.g. "list_model") into api method (model_list)
    # (this may be brittle...)
    single_rtype = self.result_type.replace("list", "").strip("_")
    method_name = f"{single_rtype}_list"
    list_method = getattr(self.api, method_name, None)
    if list_method is None:
        # same exception type a plain getattr would raise, but say
        # which result type could not be mapped to a list method
        raise AttributeError(
            f"cannot get next page for result type {self.result_type!r}: "
            f"API client has no method {method_name!r}"
        )
    return list_method(page=next_page_num)


@dataclass
class Task:
Expand Down Expand Up @@ -164,6 +178,8 @@ def _make_request(

if method == "GET":
session_request = self.session.get
elif method == "DELETE":
session_request = self.session.delete
elif method in ["POST", "PUT"]:
session_request = getattr(self.session, method.lower())
# add post data and files to the request if any are specified
Expand Down Expand Up @@ -209,11 +225,14 @@ def get_current_user(self):
api_url = "users/current/"
return to_namedtuple("user", self._make_request(api_url).json())

def model_list(self, page=None):
    """Paginated list of models.

    :param page: optional page number to request; when not set (or
        falsy) the request is made without any query parameters
    :returns: a ``ResultsList`` with result type ``list_model``, given a
        reference to this API client
    """
    api_url = "models/"
    # only send a page parameter when a page was asked for
    params = {"page": page} if page else None
    resp = self._make_request(api_url, params=params)
    return ResultsList(api=self, result_type="list_model", **resp.json())

def model_details(self, model_id):
"""details for a single models"""
Expand All @@ -238,6 +257,14 @@ def model_update(self, model_id: int, model_file: pathlib.Path):
# on successful update, returns the model object
return to_namedtuple("model", resp.json())

def model_delete(self, model_id: int):
    """Delete an existing model record from eScriptorium.

    :param model_id: id of the model to delete; used to build the
        ``models/<id>/`` API url
    """
    # a successful delete comes back as 204 No Content,
    # so that is the status the request is told to expect
    self._make_request(
        f"models/{model_id}/",
        method="DELETE",
        expected_status=requests.codes.no_content,
    )

def model_create(
self,
model_file: pathlib.Path,
Expand Down Expand Up @@ -283,24 +310,30 @@ def get_model_accuracy(self, model_file: pathlib.Path):
meta = json.loads(m.get_spec().description.metadata.userDefined["kraken_meta"])
return meta["accuracy"][-1][-1] * 100

def document_list(self, page=None):
    """Paginated list of documents.

    :param page: optional page number to request; when not set (or
        falsy) the request is made without any query parameters
    :returns: a ``ResultsList`` with result type ``document``, given a
        reference to this API client
    """
    api_url = "documents/"
    # only send a page parameter when a page was asked for
    params = {"page": page} if page else None
    resp = self._make_request(api_url, params=params)
    return ResultsList(api=self, result_type="document", **resp.json())

def document_details(self, document_id: int):
    """Retrieve details for a single document.

    :param document_id: id of the document; used to build the
        ``documents/<id>/`` API url
    :returns: the JSON response converted by ``to_namedtuple`` with the
        type name ``document``
    """
    response = self._make_request(f"documents/{document_id}/")
    return to_namedtuple("document", response.json())

def document_parts_list(self, document_id: int, page=None):
    """List of all the parts associated with a document.

    :param document_id: id of the document whose parts are listed
    :param page: optional page number to request; when not set (or
        falsy) the request is made without any query parameters
    :returns: a ``ResultsList`` with result type ``document_parts``,
        given a reference to this API client
    """
    api_url = f"documents/{document_id}/parts/"
    # only send a page parameter when a page was asked for
    params = {"page": page} if page else None
    resp = self._make_request(api_url, params=params)
    # document part listed here is different than full parts result
    # NOTE(review): ResultsList.next_page calls the list method with only
    # ``page``, so paging through these results that way would not supply
    # the required document_id -- confirm before relying on it.
    return ResultsList(api=self, result_type="document_parts", **resp.json())

def document_part_details(self, document_id: int, part_id: int):
"""details for one part of a document"""
Expand All @@ -319,7 +352,7 @@ def document_part_transcription_list(
if transcription_id is not None:
params["transcription"] = transcription_id
resp = self._make_request(api_url, params=params)
return ResultsList(result_type="transcription_line", **resp.json())
return ResultsList(api=self, result_type="transcription_line", **resp.json())

def document_export(
self, document_id: int, transcription_id: int, include_images: bool = False
Expand Down Expand Up @@ -355,7 +388,7 @@ def list_types(self, item):
raise ValueError(f"{item} is not a supported type for list types")
api_url = f"types/{item}/"
resp = self._make_request(api_url)
return ResultsList(result_type="type", **resp.json())
return ResultsList(api=self, result_type="type", **resp.json())

def export_file_url(
self,
Expand Down Expand Up @@ -474,11 +507,14 @@ def download_file(

return None

def task_list(self, page=None):
    """Paginated list of tasks.

    :param page: optional page number to request; when not set (or
        falsy) the request is made without any query parameters
    :returns: a ``ResultsList`` with result type ``task``, given a
        reference to this API client
    """
    api_url = "tasks/"
    # only send a page parameter when a page was asked for
    params = {"page": page} if page else None
    resp = self._make_request(api_url, params=params)
    return ResultsList(api=self, result_type="task", **resp.json())

def task_details(self, task_id: int):
"""details for a single task"""
Expand Down
75 changes: 75 additions & 0 deletions src/htr2hpc/train/rm_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Utility script to delete models from eScriptorium API, for easily
cleaning up models created and uploaded for testing.
Requires an eScriptorium API token set as an environment variable in
**ESCRIPTORIUM_API_TOKEN**.
Takes a base url for the eScriptorium instance and a model name prefix; will
delete all models that start with the specified model name prefix.
usage:
python src/htr2hpc/train/rm_models.py https://test-htr.lib.princeton.edu/ model_prefix
"""

import argparse
import os
import sys

from tqdm import tqdm

from htr2hpc.api_client import eScriptoriumAPIClient

# name of the environment variable the eScriptorium API token is read from
api_token_env_var = "ESCRIPTORIUM_API_TOKEN"


def main():
    """Delete all eScriptorium models whose name starts with a prefix.

    Takes two positional command-line arguments: the base url of the
    eScriptorium instance and the model name prefix. The API token is read
    from the environment variable named by ``api_token_env_var``.

    Exits with status 1 (message on stderr) if the token is not set;
    argument errors, including an empty prefix, are reported by argparse.
    """
    # parse arguments first so that --help works without an API token
    parser = argparse.ArgumentParser(description="Remove models from eScriptorium API")
    parser.add_argument(
        "base_url",
        metavar="BASE_URL",
        help="Base URL for eScriptorium instance (without /api/)",
        type=str,
    )
    parser.add_argument(
        "model_name",
        help="Model name prefix; all models with a name starting with "
        "this prefix are removed",
        type=str,
    )
    args = parser.parse_args()

    # every name starts with the empty string, so an empty prefix
    # would match (and delete) all models; refuse it
    if not args.model_name:
        parser.error("model name prefix must not be empty")

    try:
        api_token = os.environ[api_token_env_var]
    except KeyError:
        print(
            f"Error: eScriptorium API token must be set as environment variable {api_token_env_var}",
            file=sys.stderr,
        )
        sys.exit(1)

    api = eScriptoriumAPIClient(args.base_url, api_token=api_token)

    # collect matches from every page of results, including the last:
    # next_page() returns None once there is no next page, so looping on
    # the page itself (rather than on its `next` url) also covers the
    # single-page case
    rm_models = []
    model_list = api.model_list()
    while model_list is not None:
        rm_models.extend(
            model
            for model in model_list.results
            if model.name.startswith(args.model_name)
        )
        model_list = model_list.next_page()

    if not rm_models:
        print("No matching models")
        return

    print(f"Removing {len(rm_models)} models")
    for model in tqdm(rm_models, desc="Removing models: "):
        api.model_delete(model.pk)


if __name__ == "__main__":
main()

0 comments on commit 7897950

Please sign in to comment.