From 0f8d1896a94fea477d88e351ec39be20537c5c80 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 16 Jan 2025 09:22:26 -0800 Subject: [PATCH] add docstrings and lint enforcement (#512) --- OCR/benchmark_main.py | 7 +- OCR/ocr/api.py | 74 ++++++++- .../backends/four_point_transform.py | 65 +++++++- .../alignment/backends/image_homography.py | 107 ++++++++++--- .../backends/random_perspective_transform.py | 68 +++++++-- OCR/ocr/services/alignment/image_alignment.py | 28 +++- OCR/ocr/services/batch_metrics.py | 81 ++++++++-- OCR/ocr/services/batch_segmentation.py | 60 ++++++-- OCR/ocr/services/image_ocr.py | 103 +++++++++++-- OCR/ocr/services/image_segmenter.py | 96 +++++++++++- OCR/ocr/services/metrics_analysis.py | 144 +++++++++++++++--- OCR/ocr/services/phdc_converter/builder.py | 100 ++++-------- OCR/ocr/services/phdc_converter/models.py | 34 ++--- .../services/phdc_converter/phdc_converter.py | 8 +- OCR/ocr/services/tesseract_ocr.py | 57 +++++-- OCR/pyproject.toml | 20 +++ OCR/tests/assets/ocr_metrics_test.py | 2 +- 17 files changed, 817 insertions(+), 237 deletions(-) diff --git a/OCR/benchmark_main.py b/OCR/benchmark_main.py index 203d151f..a5b245fe 100644 --- a/OCR/benchmark_main.py +++ b/OCR/benchmark_main.py @@ -46,11 +46,9 @@ def main(): def run_segmentation_and_ocr(args): - """ - Runs segmentation and OCR processing. + """Runs segmentation and OCR processing. Returns OCR results with processing time. """ - model = None if args.model == "tesseract": @@ -70,8 +68,7 @@ def run_segmentation_and_ocr(args): def run_metrics_analysis(args, ocr_results): - """ - Runs metrics analysis based on OCR output and ground truth. + """Runs metrics analysis based on OCR output and ground truth. Uses OCR results to capture time values if available. """ metrics_analysis = BatchMetricsAnalysis(args.output_folder, args.ground_truth_folder, args.csv_output_folder) diff --git a/OCR/ocr/api.py b/OCR/ocr/api.py index 1e9bfb33..8ea4b936 100644 --- a/OCR/ocr/api.py +++ b/OCR/ocr/api.py @@ -1,3 +1,5 @@ +"""Module for FastAPI interface functions.""" + import base64 import uvicorn @@ -34,7 +36,18 @@ ocr = TesseractOCR() -def data_uri_to_image(data_uri: str): +def data_uri_to_image(data_uri: str) -> np.ndarray: + """Converts a base64 encoded data URI to an image, represented as a NumPy array. + + Args: + data_uri (str): The base64 encoded image in data URI format. + + Returns: + np.ndarray: The decoded image in NumPy array format. + + Raises: + HTTPException: If the image decoding fails. + """ try: base64_data = data_uri.split(",")[1] image_data = base64.b64decode(base64_data) @@ -48,18 +61,40 @@ def data_uri_to_image(data_uri: str): ) -def image_to_data_uri(image: np.ndarray): +def image_to_data_uri(image: np.ndarray) -> bytes: + """Converts an image to a base64 encoded data URI. + + Args: + image (np.ndarray): The input image in NumPy array format. + + Returns: + bytes: The Base64 encoded data URI representation of the image. + """ _, encoded = cv.imencode(".png", image) return b"data:image/png;base64," + base64.b64encode(encoded) @app.get("/") async def health_check(): + """Health check endpoint to verify the API is running. + + Returns: + dict: A dictionary with the status of the service. + """ return {"status": "UP"} @app.post("/image_alignment/") -async def image_alignment(source_image: str = Form(), segmentation_template: str = Form()): +async def image_alignment(source_image: str = Form(), segmentation_template: str = Form()) -> dict: + """Aligns a source image to a segmentation template. + + Args: + source_image (str): The base64 encoded source image. + segmentation_template (str): The baSe64 encoded segmentation template image. + + Returns: + dict: A dictionary containing the aligned image as a base64 encoded data URI. + """ source_image_img = data_uri_to_image(source_image) segmentation_template_img = data_uri_to_image(segmentation_template) @@ -70,6 +105,19 @@ async def image_alignment(source_image: str = Form(), segmentation_template: str @app.post("/image_file_to_text/") async def image_file_to_text(source_image: UploadFile, segmentation_template: UploadFile, labels: str = Form()): + """Extracts text from an image file based on a segmentation template, using OCR. + + Args: + source_image (UploadFile): The uploaded source image file. + segmentation_template (UploadFile): The uploaded segmentation template file. + labels (str): The JSON-encoded labels defining segmentation templates. + + Returns: + dict: A dictionary containing the OCR results for the segmented regions. + + Raises: + HTTPException: If there are issues with file decoding, parsing, segmentation, or OCR. + """ try: source_image_np = np.frombuffer(await source_image.read(), np.uint8) source_image_img = cv.imdecode(source_image_np, cv.IMREAD_COLOR) @@ -117,7 +165,20 @@ async def image_file_to_text(source_image: UploadFile, segmentation_template: Up @app.post("/image_to_text") async def image_to_text( source_image: str = Form(...), segmentation_template: str = Form(...), labels: str = Form(...) -): +) -> dict: + """Extracts text from an image based on a segmentation template, using OCR. + + Args: + source_image (str): The base64-encoded source image. + segmentation_template (str): The base64-encoded segmentation template. + labels (str): The JSON-encoded labels defining segmentation templates. + + Returns: + dict: A dictionary containing the OCR results for the segmented regions. + + Raises: + HTTPException: If there are issues with file decoding, parsing, segmentation, or OCR. + """ try: source_image_img = data_uri_to_image(source_image) segmentation_template_img = data_uri_to_image(segmentation_template) @@ -149,5 +210,8 @@ async def image_to_text( def start(): - """Launched with `poetry run start` at root level""" + """Starts the FastAPI app. + + Launched with `poetry run start` at root level. + """ uvicorn.run(app, host="0.0.0.0", port=8000, reload=False) diff --git a/OCR/ocr/services/alignment/backends/four_point_transform.py b/OCR/ocr/services/alignment/backends/four_point_transform.py index cb0954b4..e3eeb13c 100644 --- a/OCR/ocr/services/alignment/backends/four_point_transform.py +++ b/OCR/ocr/services/alignment/backends/four_point_transform.py @@ -1,6 +1,4 @@ -""" -Uses quadrilaterial edge detection and executes a four-point perspective transform on a source image. -""" +"""Uses quadrilaterial edge detection and executes a four-point perspective transform on a source image.""" from pathlib import Path import functools @@ -10,19 +8,51 @@ class FourPointTransform: + """A class to perform a four-point perspective transformation on an image. + + This involves detecting the largest quadrilateral in the image and transforming it + to a standard rectangular form using a perspective warp. + + Attributes: + image (np.ndarray): The input image as a NumPy array. + """ + def __init__(self, image: Path | np.ndarray): + """Initializes the FourPointTransform object with an image. + + The image can either be provided as a file path (Path) or a NumPy array. + + Args: + image (Path | np.ndarray): The input image, either as a path to a file or as a NumPy array. + """ if isinstance(image, np.ndarray): self.image = image else: self.image = cv.imread(str(image)) @classmethod - def align(self, source_image, template_image): + def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: + """Aligns a source image to a template image using the four-point transform. + + Args: + source_image (np.ndarray): The source image to be aligned. + template_image (np.ndarray): The template image to align to. + + Returns: + np.ndarray: The transformed image. + """ return FourPointTransform(source_image).dewarp() @staticmethod def _order_points(quadrilateral: np.ndarray) -> np.ndarray: - "Reorder points from a 4x2 input array representing the vertices of a quadrilateral, such that the coordinates of each vertex are arranged in order from top left, top right, bottom right, and bottom left." + """Reorders the points of a quadrilateral from an unordered 4x2 array to a specific order of top-left, top-right, bottom-right, and bottom-left. + + Args: + quadrilateral (np.ndarray): A 4x2 array representing the vertices of a quadrilateral. + + Returns: + np.ndarray: A 4x2 array with the points ordered as [top-left, top-right, bottom-right, bottom-left]. + """ quadrilateral = quadrilateral.reshape(4, 2) output_quad = np.zeros([4, 2]).astype(np.float32) s = quadrilateral.sum(axis=1) @@ -33,19 +63,38 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray: output_quad[3] = quadrilateral[np.argmax(diff)] return output_quad - def find_largest_contour(self): - """Compute contours for an image and find the biggest one by area.""" + def find_largest_contour(self) -> np.ndarray: + """Finds the largest contour in the image by computing the contours and selecting the one with the greatest area. + + Returns: + np.ndarray: The largest contour found in the image. + """ contours, _ = cv.findContours( cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE ) return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours) def simplify_polygon(self, contour): - """Simplify to a polygon with (hopefully four) vertices.""" + """Simplifies a given contour to a polygon with a reduced number of vertices, ideally four. + + Args: + contour (np.ndarray): The contour to simplify. + + Returns: + np.ndarray: The simplified polygon. + """ perimeter = cv.arcLength(contour, True) return cv.approxPolyDP(contour, 0.01 * perimeter, True) def dewarp(self) -> np.ndarray: + """Performs a four-point perspective transform to "dewarp" the image. + + This involves detecting the largest quadrilateral, simplifying it to a polygon, and + applying a perspective warp to straighten the image into a rectangle. + + Returns: + np.ndarray: The perspective-transformed (dewarped) image. + """ biggest_contour = self.find_largest_contour() simplified = self.simplify_polygon(biggest_contour) diff --git a/OCR/ocr/services/alignment/backends/image_homography.py b/OCR/ocr/services/alignment/backends/image_homography.py index 5ab1ddfb..b001083b 100644 --- a/OCR/ocr/services/alignment/backends/image_homography.py +++ b/OCR/ocr/services/alignment/backends/image_homography.py @@ -1,3 +1,5 @@ +"""Aligns two images using image homography algorithms.""" + from pathlib import Path import numpy as np @@ -5,8 +7,30 @@ class ImageHomography: + """A class to align two images using homography techniques. + + Uses Scale-Invariant Feature Transform (SIFT) algorithm to detect keypoints and + compute descriptors for image matching, and then estimates a homography + transformation matrix to align the source image with a template image. + + Attributes: + template (np.ndarray): The template image to align against, either as a path or a NumPy array. + match_ratio (float): The ratio used for Lowe's ratio test to filter good matches. + _sift (cv.SIFT): The SIFT detector used to find keypoints and descriptors. + """ + def __init__(self, template: Path | np.ndarray, match_ratio=0.3): - """Initialize the image homography pipeline with a `template` image.""" + """Initializes the ImageHomography object with a template image. + + Optionally include a match ratio for filtering descriptor matches; this must be between 0 and 1. + + Args: + template (Path | np.ndarray): The template image, either as a file path or a NumPy array. + match_ratio (float, optional): The ratio threshold for Lowe's ratio test. Default is 0.3. + + Raises: + ValueError: If `match_ratio` is not between 0 and 1. + """ if match_ratio >= 1 or match_ratio <= 0: raise ValueError("`match_ratio` must be between 0 and 1") @@ -18,24 +42,64 @@ def __init__(self, template: Path | np.ndarray, match_ratio=0.3): self._sift = cv.SIFT_create() @classmethod - def align(self, source_image, template_image): + def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: + """Aligns a source image to a template image. + + Args: + source_image (np.ndarray): The source image to align. + template_image (np.ndarray): The template image to align to. + + Returns: + np.ndarray: The aligned source image. + """ return ImageHomography(template_image).transform_homography(source_image) def estimate_self_similarity(self): - """Calibrate `match_ratio` using a self-similarity metric.""" + """Calibrates the match ratio using a self-similarity metric (not implemented). + + Raises: + NotImplementedError: Since this method is not implemented. + """ raise NotImplementedError - def compute_descriptors(self, img): - """Compute SIFT descriptors for a target `img`.""" + def compute_descriptors(self, img: np.ndarray): + """Computes the SIFT descriptors for a given image. + + These descriptors represent distinctive features in the image that can be used for matching. + + Args: + img (np.ndarray): The image for which to compute descriptors. + + Returns: + tuple: A 2-element tuple containing the keypoints and their corresponding descriptors. + """ return self._sift.detectAndCompute(img, None) def knn_match(self, descriptor_template, descriptor_query): - """Return k-nearest neighbors match (k=2) between descriptors generated from a template and query image.""" + """Performs k-nearest neighbors matching (k=2) between descriptors to find best homography matches. + + Args: + descriptor_template (np.ndarray): The SIFT descriptors from the template image. + descriptor_query (np.ndarray): The SIFT descriptors from the query image. + + Returns: + list: A list of k-nearest neighbor matches between the template and query descriptors. + """ matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED) return matcher.knnMatch(descriptor_template, descriptor_query, 2) - def estimate_transform_matrix(self, other): - "Estimate the transformation matrix based on homography." + def estimate_transform_matrix(self, other: np.ndarray) -> np.ndarray: + """Estimates the transformation matrix between the template image and another image. + + This function detects keypoints and descriptors, matches them using k-nearest neighbors, + and applies Lowe's ratio test to filter for quality matches. + + Args: + other (np.ndarray): The image to estimate the transformation matrix against. + + Returns: + np.ndarray: The homography matrix that transforms the other image to align with the template image. + """ # find the keypoints and descriptors with SIFT kp1, descriptors1 = self.compute_descriptors(self.template) kp2, descriptors2 = self.compute_descriptors(other) @@ -55,17 +119,26 @@ def estimate_transform_matrix(self, other): M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0) return M - def transform_homography(self, other, min_axis=100, matrix=None): - """ - Run the image homography pipeline against a query image. + def transform_homography(self, other: np.ndarray, min_axis=100, matrix=None) -> np.ndarray: + """Run the full image homography pipeline against a query image. - Parameters: - min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform. - If the input image is under the axis limits, return the original input image unchanged. - matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be - estimated with `estimate_transform_matrix`. - """ + If the size of the `other` image is smaller than the minimum axis length `min_axis`, + the image is returned unchanged. + If a transformation matrix is provided, it is used directly; otherwise, the matrix is + estimated using `estimate_transform_matrix`. + + Args: + other (np.ndarray): The image to be transformed. + min_axis (int, optional): The minimum axis length (in pixels) to attempt the homography transform. + If the image is smaller, it will be returned unchanged. Default is 100. + matrix (np.ndarray, optional): The homography transformation matrix to apply. If not provided, + it will be estimated. + + Returns: + np.ndarray: The transformed image if homography was applied, or the original image if it is + smaller than the minimum axis size. + """ if other.shape[0] < min_axis and other.shape[1] < min_axis: return other diff --git a/OCR/ocr/services/alignment/backends/random_perspective_transform.py b/OCR/ocr/services/alignment/backends/random_perspective_transform.py index 3697d1df..34a35b38 100644 --- a/OCR/ocr/services/alignment/backends/random_perspective_transform.py +++ b/OCR/ocr/services/alignment/backends/random_perspective_transform.py @@ -1,6 +1,4 @@ -""" -Perspective transforms a base image between 10% and 90% distortion. -""" +"""Perspective transforms a base image between 10% and 90% distortion.""" from pathlib import Path @@ -10,15 +8,40 @@ class RandomPerspectiveTransform: - """Generate a random perspective transform based on a template `image`.""" + """Class to generates a random perspective transform based on a template image. - def __init__(self, image: Path): - self.image = Image.open(image) + This class allows you to apply random distortions to an image by computing + a perspective transformation matrix and warping the image accordingly. + + Attributes: + image (PIL.Image): The input image to which the perspective transform will be applied. + """ + + def __init__(self, image: Path) -> None: + """Initializes the RandomPerspectiveTransform instance with the given image path. - def make_transform(self, distortion_scale: float) -> object: + Args: + image (Path): Path to the image to be used for the transformation. """ - Create a transformation matrix for a random perspective transform. + self.image = Image.open(image) + + def make_transform(self, distortion_scale: float) -> np.ndarray: + """Create a transformation matrix for a random perspective transform. + + Args: + distortion_scale (float): A scale factor that controls the amount of distortion. + It should be between 0 (no distortion) and 1 (maximum distortion). + + Returns: + np.ndarray: The transformation matrix (from cv.PerspectiveTransform) that can be applied to this image. + + Raises: + ValueError: If `distortion_scale` is outside the range [0, 1). """ + if distortion_scale < 0 or distortion_scale >= 1: + raise ValueError("`distortion_scale` must be between 0 and 1") + + # We delay import until this is called due to torch's long initialization time. import torch # From torchvision. BSD 3-clause @@ -48,12 +71,33 @@ def make_transform(self, distortion_scale: float) -> object: np.array(endpoints, dtype=np.float32), np.array(startpoints, dtype=np.float32) ) - def transform(self, transformer) -> object: - """Warp the template image with a specified transform matrix.""" + def transform(self, transformer: np.ndarray) -> np.ndarray: + """Warp the template image with a specified transform matrix. + + Args: + transformer (np.ndarray): A perspective transformation matrix to apply. + + Returns: + np.ndarray: The warped image after applying the transformation. + """ return cv.warpPerspective(np.array(self.image), transformer, (self.image.width, self.image.height)) - def random_transform(self, distortion_scale: float) -> object: - """Warp the template image with specified `distortion_scale`.""" + def random_transform(self, distortion_scale: float) -> np.ndarray: + """Warp the template image with specified distortion_scale. + + This method internally calls `make_transform` to generate the transformation matrix + and applies it using `transform`. + + Args: + distortion_scale (float): A scale factor that controls the amount of distortion. + It should be between 0 (no distortion) and 1 (maximum distortion). + + Returns: + np.ndarray: The warped image after applying the random perspective transformation. + + Raises: + ValueError: If `distortion_scale` is outside the range [0, 1). + """ if distortion_scale < 0 or distortion_scale >= 1: raise ValueError("`distortion_scale` must be between 0 and 1") diff --git a/OCR/ocr/services/alignment/image_alignment.py b/OCR/ocr/services/alignment/image_alignment.py index d52472f4..155b2b73 100644 --- a/OCR/ocr/services/alignment/image_alignment.py +++ b/OCR/ocr/services/alignment/image_alignment.py @@ -1,19 +1,37 @@ +"""Module for aligning images using a specified image alignment backend.""" + import numpy as np from ocr.services.alignment.backends import ImageHomography class ImageAligner: + """Class for aligning images using a specified image alignment backend. + + Attributes: + aligner: An alignment backend class or instance that provides an `align` method. + Default is the ImageHomography backend from the ocr.services.alignment module. + """ + def __init__(self, aligner=ImageHomography): + """Initializes an ImageAligner instance with the specified image alignment backend. + + Args: + aligner (type): A class or instance of an alignment backend. Default is ImageHomography. + """ self.aligner = aligner def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: - """ - Aligns an image using the specified image alignment backend. + """Aligns the source image with the template image using the specified alignment backend. + + Args: + source_image (np.ndarray): The image to be aligned, represented as a NumPy array. + template_image (np.ndarray): The image that `source_image` will be aligned against, + represented as a NumPy array. May not be used for all + image alignment backends. - source_image: the image to be aligned, as a numpy ndarray. - template_image: the image that `source_image` will be aligned against, as a numpy ndarray. - May not be used for all image alignment backends. + Returns: + np.ndarray: The aligned image as a NumPy array. """ aligned_image = self.aligner.align(source_image, template_image) return aligned_image diff --git a/OCR/ocr/services/batch_metrics.py b/OCR/ocr/services/batch_metrics.py index ef5b0fb4..cec24e61 100644 --- a/OCR/ocr/services/batch_metrics.py +++ b/OCR/ocr/services/batch_metrics.py @@ -1,20 +1,49 @@ +"""Module for batch processing OCR and ground truth files and calculating metrics.""" + from ocr.services.metrics_analysis import OCRMetrics import os import csv class BatchMetricsAnalysis: - def __init__(self, ocr_folder, ground_truth_folder, csv_output_folder): + """Class for batch processing OCR and ground truth files, calculating OCR metrics, and saving the results in CSV files. + + Compare OCR results with ground truth data, calculates various OCR performance metrics (e.g., Levenshtein distance), + and saves individual CSV files for each pair of OCR and ground truth data. Problematic segments with a Levenshtein + distance greater than 1 (i.e., OCR results were incorrect) are identified and stored separately. + + Attributes: + ocr_folder (str): Path to the folder containing OCR result files. + ground_truth_folder (str): Path to the folder containing ground truth files. + csv_output_folder (str): Path to the folder where CSV output files will be saved. + """ + + def __init__(self, ocr_folder: str, ground_truth_folder: str, csv_output_folder: str) -> None: + """Initializes the BatchMetricsAnalysis class with paths to OCR, ground truth, and output folders. + + Creates the output folder if it doesn't exist. + + Args: + ocr_folder (str): Path to the folder containing OCR result files. + ground_truth_folder (str): Path to the folder containing ground truth files. + csv_output_folder (str): Path to the folder where CSV output files will be saved. + """ self.ocr_folder = ocr_folder self.ground_truth_folder = ground_truth_folder self.csv_output_folder = csv_output_folder os.makedirs(self.csv_output_folder, exist_ok=True) - def calculate_batch_metrics(self, ocr_results=None): - """ - Processes OCR and ground truth files and saves individual CSVs. + def calculate_batch_metrics(self, ocr_results=None) -> dict: + """Processes OCR and ground truth files and saves individual CSVs. + Ensures only matching files are processed. + + Args: + ocr_results (dict, optional): A dictionary of OCR results. + + Returns: + dict: A summary of total metrics for each processed OCR file. """ print(f"Loading OCR files from: {self.ocr_folder}") print(f"Loading ground truth files from: {self.ground_truth_folder}") @@ -68,9 +97,13 @@ def calculate_batch_metrics(self, ocr_results=None): return total_metrics_summary @staticmethod - def save_metrics_to_csv(metrics, total_metrics, file_path): - """ - Saves individual and total metrics to a CSV file, including time taken. + def save_metrics_to_csv(metrics: list, total_metrics: dict, file_path: str) -> None: + """Saves individual and total metrics to a CSV file, including time taken. + + Args: + metrics (list): A list of dictionaries containing individual metrics for each file. + total_metrics (dict): A dictionary containing the overall metrics for the batch. + file_path (str): Path to the CSV file where the metrics will be saved. """ print(metrics) metric_keys = list(metrics[0].keys()) @@ -92,9 +125,12 @@ def save_metrics_to_csv(metrics, total_metrics, file_path): print(f"Metrics saved to {file_path}") @staticmethod - def save_problematic_segments_to_csv(segments, file_path): - """ - Saves problematic segments (Levenshtein distance >= 1) to a CSV file. + def save_problematic_segments_to_csv(segments: list, file_path: str): + """Saves problematic segments (Levenshtein distance >= 1) to a CSV file. + + Args: + segments (list): A list of problematic segments, each represented as a dictionary. + file_path (str): Path to the CSV file where the problematic segments will be saved. """ if not segments: print("No problematic segments found.") @@ -109,9 +145,13 @@ def save_problematic_segments_to_csv(segments, file_path): print(f"Problematic segments saved to {file_path}") - def extract_problematic_segments(self, metrics, ocr_file, problematic_segments): - """ - Extracts segments with Levenshtein distance >= 1 and stores them. + def extract_problematic_segments(self, metrics: list, ocr_file: str, problematic_segments: list) -> None: + """Extracts segments with Levenshtein distance >= 1 and stores them. + + Args: + metrics (list): A list of dictionaries containing individual OCR metrics. + ocr_file (str): The OCR file name being processed. + problematic_segments (list): A list to store problematic segments. """ for metric in metrics: if metric["levenshtein_distance"] >= 1: @@ -127,10 +167,19 @@ def extract_problematic_segments(self, metrics, ocr_file, problematic_segments): ) @staticmethod - def get_files_in_directory(directory): - """ - Returns a sorted list of files in the specified directory. + def get_files_in_directory(directory: str) -> list: + """Returns a sorted list of files in the specified directory. + Assumes that files are named consistently for OCR and ground truth. + + Args: + directory (str): Path to the directory containing the files. + + Returns: + list: A sorted list of file names in the directory. + + Raises: + FileNotFoundError: If the directory is not found. """ try: files = sorted( diff --git a/OCR/ocr/services/batch_segmentation.py b/OCR/ocr/services/batch_segmentation.py index d2d80f34..ff37edbe 100644 --- a/OCR/ocr/services/batch_segmentation.py +++ b/OCR/ocr/services/batch_segmentation.py @@ -1,3 +1,5 @@ +"""Process and segment a batch of images and perform OCR on the results.""" + import os import json import time @@ -8,7 +10,28 @@ class BatchSegmentationOCR: - def __init__(self, image_folder, segmentation_template, labels_path, output_folder, model=None): + """Class that processes a batch of images by segmenting them and performing OCR on the segments. + + Attributes: + image_folder (str): Path to the folder containing images to process. + segmentation_template (str): Path to the segmentation template to guide image segmentation. + labels_path (str): Path to the file containing label data used for segmentation. + output_folder (str): Path to the folder where OCR results and timing information will be saved. + model (ImageOCR): An optional pre-defined OCR model; if None, a default instance of ImageOCR is used. + """ + + def __init__( + self, image_folder: str, segmentation_template: str, labels_path: str, output_folder: str, model=None + ) -> None: + """Initializes the BatchSegmentationOCR instance with the specified paths and an optional OCR model. + + Args: + image_folder (str): Path to the folder containing images to process. + segmentation_template (str): Path to the segmentation template to guide image segmentation. + labels_path (str): Path to the file containing label data used for segmentation. + output_folder (str): Path to the folder where OCR results and timing information will be saved. + model (ImageOCR, optional): An optional pre-defined OCR model. Defaults to `ImageOCR`. + """ self.image_folder = image_folder self.segmentation_template = segmentation_template self.labels_path = labels_path @@ -18,9 +41,11 @@ def __init__(self, image_folder, segmentation_template, labels_path, output_fold self.model = ImageOCR() os.makedirs(self.output_folder, exist_ok=True) - def process_images(self): - """ - Processes all images and returns OCR results with time taken. + def process_images(self) -> list[dict]: + """Processes all images and returns OCR results with time taken. + + Returns: + list[dict]: A list of dictionaries containing the OCR results and time taken for each image. """ segmenter = ImageSegmenter() ocr = self.model @@ -55,9 +80,21 @@ def process_images(self): print("Processing complete.") return results - def segment_ocr_image(self, segmenter, ocr, image_path, image_file): - """ - Segments the image and runs OCR, returning results and time taken. + def segment_ocr_image( + self, segmenter: ImageSegmenter, ocr, image_path: str, image_file: str + ) -> tuple[dict[str, tuple[str, float]], float]: + """Segments the image and runs OCR, returning results and time taken. + + Args: + segmenter (ImageSegmenter): An instance of the ImageSegmenter used to segment the image. + ocr (ImageOCR): An instance of the OCR model used to extract text from the segments. + image_path (str): Path to the image file to be processed. + image_file (str): The name of the image file. + + Returns: + tuple: A tuple containing: + - dict: The OCR results, where the key is the label and the value is a tuple of (text, confidence). + - float: The time taken to segment and process the image. """ start_time = time.time() @@ -77,9 +114,12 @@ def segment_ocr_image(self, segmenter, ocr, image_path, image_file): time_taken = time.time() - start_time return ocr_result, time_taken - def write_times_to_csv(self, time_dict, csv_output_path): - """ - Writes the time taken for each file to a CSV. + def write_times_to_csv(self, time_dict: dict[str, float], csv_output_path) -> None: + """Writes the time taken for each file to a CSV. + + Args: + time_dict (dict): A dictionary where the key is the image file name and the value is the time taken (in seconds). + csv_output_path (str): Path to the folder where the CSV file will be saved. """ csv_file_path = os.path.join(csv_output_path, "time_taken.csv") diff --git a/OCR/ocr/services/image_ocr.py b/OCR/ocr/services/image_ocr.py index c8d91f48..b90441d4 100644 --- a/OCR/ocr/services/image_ocr.py +++ b/OCR/ocr/services/image_ocr.py @@ -1,3 +1,5 @@ +"""Module for OCR using a transformers-based OCR model.""" + from collections.abc import Iterator from transformers import TrOCRProcessor, VisionEncoderDecoderModel @@ -7,14 +9,39 @@ class ImageOCR: + """A class for OCR using the transformers-based models. + + Defaults to using the Microsoft TrOCR model from Hugging Face's transformers library. + + Attributes: + processor (TrOCRProcessor): Processor for TrOCR model that prepares images for OCR. + model (VisionEncoderDecoderModel): Pre-trained TrOCR model for extracting text from images. + """ + def __init__(self, model="microsoft/trocr-large-printed"): + """Initializes the ImageOCR class with the specified OCR model. + + Args: + model (str, optional): The name of the pre-trained model to use. Default is "microsoft/trocr-large-printed". + + See Also: + * https://huggingface.co/microsoft/trocr-large-printed + """ self.processor = TrOCRProcessor.from_pretrained(model) self.model = VisionEncoderDecoderModel.from_pretrained(model) @staticmethod def compute_line_angle(lines: list) -> Iterator[float]: - """ - Takes the output of cv.HoughLinesP (in x1, y1, x2, y2 format) and computes the angle in degrees based on these endpoints. + """Computes the angle in degrees of the lines detected by the Hough transform, based on their endpoints. + + This method processes the output of `cv.HoughLinesP` (lines in (x1, y1, x2, y2) format) and computes the angle + between each line and the horizontal axis. + + Args: + lines (list): A list of lines represented as a list or tuple of endpoints [x1, y1, x2, y2]. + + Yields: + float: The angle (in degrees) of each line with respect to the horizontal axis. """ for line in lines: start = line[0][0:2] @@ -24,8 +51,16 @@ def compute_line_angle(lines: list) -> Iterator[float]: @staticmethod def merge_bounding_boxes(boxes: list) -> Iterator[list]: - """ - Merges overlapping boxes, passed in (x, y, w, h) format. + """Merges overlapping bounding boxes into a single bounding box. + + Given a list of bounding boxes in (x, y, w, h) format, this function merges overlapping boxes + into one larger box. + + Args: + boxes (list): A list of bounding boxes, where each box is represented as a list or tuple [x, y, w, h]. + + Yields: + list: Merged bounding boxes, represented in [x, y, w, h] format. """ if not boxes: return [] @@ -53,9 +88,17 @@ def merge_bounding_boxes(boxes: list) -> Iterator[list]: yield [current[0], current[1], current[2] - current[0], current[3] - current[1]] def identify_blocks(self, input_image: np.ndarray, kernel: np.ndarray): - """ - Given an input image and a morphological operation kernel, returns unique (non-overlapping) - bounding boxes of potential text regions. + """Identifies potential text blocks in an image by applying morphological operations. + + The function uses the input image, applies thresholding and dilation, and then finds contours to identify + potential text blocks. It then merges overlapping bounding boxes into larger ones. + + Args: + input_image (np.ndarray): The input image to process. + kernel (np.ndarray): The kernel used for morphological operations (dilation). + + Returns: + Iterator[list]: An iterator of merged bounding boxes, each represented as [x, y, w, h]. """ # Invert threshold `input_image` and dilate using `kernel` to "expand" the size of text blocks _, thresh = cv.threshold(cv.cvtColor(input_image, cv.COLOR_BGR2GRAY), 128, 255, cv.THRESH_BINARY_INV) @@ -68,14 +111,17 @@ def identify_blocks(self, input_image: np.ndarray, kernel: np.ndarray): return self.merge_bounding_boxes([cv.boundingRect(contour) for contour in contours]) def deskew_image_text(self, image: np.ndarray, line_length_prop=0.5, max_skew_angle=10) -> np.ndarray: - """ - Deskew an image using Hough transforms to detect lines. + """Deskew an image using Hough transforms to detect lines and rotating the image to correct any skew. Since even small-angled skews can compromise the line segmentation algorithm, this is needed as a preprocessing step. - line_length_prop: typical line length as a fraction of the horizontal size of an image. - max_skew_angle: maximum angle in degrees that a putative line can be skewed before it is removed from consideration - for being too skewed. + Args: + image (np.ndarray): The image to be deskewed. + line_length_prop (float, optional): Proportion of the image's width used to determine line length. Default is 0.5. + max_skew_angle (float, optional): Maximum allowed skew angle for valid lines (in degrees). Default is 10. + + Returns: + np.ndarray: The deskewed image. """ line_length = image.shape[1] * line_length_prop # Flatten image to grayscale for edge detection @@ -95,10 +141,17 @@ def deskew_image_text(self, image: np.ndarray, line_length_prop=0.5, max_skew_an return cv.warpAffine(np.array(image, dtype=np.uint8), rotation_mat, (image.shape[1], image.shape[0])) def split_text_blocks(self, image: np.ndarray, line_length_prop=0.5) -> list[np.ndarray]: - """ - Splits an image with text in it into possibly multiple images, one for each line. + """Splits an image with text in it into (possibly) multiple images, one for each line. + + The function first deskews the image, then uses morphological operations to identify potential lines and words. + It then separates the image into individual text blocks (lines and words). + + Args: + image (np.ndarray): The image to split into text blocks. + line_length_prop (float, optional): Proportion of the image's width used to determine the typical line length. Default is 0.5. - line_length_prop: typical line length as a fraction of the horizontal size of an image. + Returns: + list[np.ndarray]: A list of images representing individual text blocks. """ line_length = image.shape[1] * line_length_prop rotated = self.deskew_image_text(image, line_length_prop) @@ -130,6 +183,18 @@ def split_text_blocks(self, image: np.ndarray, line_length_prop=0.5) -> list[np. return acc def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, float]]: + """Converts image segments into text using Transformers OCR. + + For each segment, it extracts the text and the average confidence score. + + Args: + segments (dict[str, np.ndarray]): A dictionary where keys are segment labels (e.g., 'name', 'date'), + and values are NumPy arrays representing the corresponding image segments. + + Returns: + dict[str, tuple[str, float]]: A dictionary where each key corresponds to a segment label, and each value is + a tuple containing the OCR result (string) and the confidence score (float). + """ digitized: dict[str, tuple[str, float]] = {} for label, image in segments.items(): if image is None: @@ -155,6 +220,14 @@ def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, return digitized def calculate_confidence(self, outputs): + """Calculates the confidence level of the OCR output. + + Args: + outputs: The output of the model, containing prediction scores. + + Returns: + float: The confidence percentage of the OCR output. + """ probs = torch.softmax(outputs.scores[0], dim=-1) max_probs = torch.max(probs, dim=-1).values diff --git a/OCR/ocr/services/image_segmenter.py b/OCR/ocr/services/image_segmenter.py index 6378727b..b1eb8b1c 100644 --- a/OCR/ocr/services/image_segmenter.py +++ b/OCR/ocr/services/image_segmenter.py @@ -1,10 +1,21 @@ +"""Module to segment images based on a segmentation template and a set of labels.""" + import cv2 as cv import numpy as np import json import os -def crop_zeros(image): +def crop_zeros(image: np.ndarray) -> np.ndarray: + """Crops the given image to remove all-zero (black) regions. + + Args: + image (np.ndarray): The input image represented as a NumPy array, + where zero values represent black areas to be cropped. + + Returns: + np.ndarray: A cropped version of the image with zero regions removed. + """ # argwhere will give you the coordinates of every non-zero point true_points = np.argwhere(image) @@ -21,7 +32,21 @@ def crop_zeros(image): ] # inclusive -def segment_by_mask_then_crop(raw_image, segmentation_template, labels, debug) -> dict[str, np.ndarray]: +def segment_by_mask_then_crop( + raw_image: np.ndarray, segmentation_template: np.ndarray, labels: list[dict[str, str]], debug: bool +) -> dict[str, np.ndarray]: + """Segments a raw image based on a color mask in the segmentation template, and then crops the resulting regions to remove zero (black) areas. + + Args: + raw_image (np.ndarray): The input image to be segmented, as a NumPy array. + segmentation_template (np.ndarray): A template image used for segmenting the raw image with color masks. + labels (list[dict[str, str]]): A list of dicts containing 'label' and 'color' keys, where 'label' is the segment label + and 'color' is the color to match in the segmentation template. + debug (bool): If `True`, saves debug images and prints additional information. + + Returns: + dict[str, np.ndarray]: A dictionary where keys are segment labels and values are the cropped segmented images. + """ segments = {} # iterate over the labels @@ -53,7 +78,21 @@ def segment_by_mask_then_crop(raw_image, segmentation_template, labels, debug) - return segments -def segment_by_color_bounding_box(raw_image, segmentation_template, labels, debug) -> dict[str, np.ndarray]: +def segment_by_color_bounding_box( + raw_image: np.ndarray, segmentation_template: np.ndarray, labels: list[dict[str, str]], debug: bool +) -> dict[str, np.ndarray]: + """Segments a raw image by detecting colored boundary boxes in the segmentation template. + + Args: + raw_image (np.ndarray): The input image to be segmented, as a NumPy array. + segmentation_template (np.ndarray): A template image used for segmenting the raw image with colored boxes. + labels (list[dict[str, str]]): A list of dicts containing 'label' and 'color' keys, where 'label' is the segment label + and 'color' is the color to match in the segmentation template. + debug (bool): If `True`, saves debug images and prints additional information. + + Returns: + dict[str, np.ndarray]: A dictionary where keys are segment labels and values are the cropped segmented images. + """ segments = {} # iterate over the labels @@ -76,23 +115,66 @@ def segment_by_color_bounding_box(raw_image, segmentation_template, labels, debu class ImageSegmenter: + """A class for segmenting images based on a segmentation template and labels. + + Supports different segmentation strategies by passing in functions to `segmentation_function`. + + Attributes: + segmentation_function (function): A function used for segmenting the image, such as + `segment_by_mask_then_crop` or `segment_by_color_bounding_box`. + debug (bool): If `True`, saves debug images and prints additional information. + """ + def __init__( self, segmentation_function=segment_by_mask_then_crop, debug=False, ): + """Initializes the ImageSegmenter with the specified segmentation function. + + Args: + segmentation_function (function): The segmentation function to use. Default is `segment_by_mask_then_crop`. + debug (bool): If `True`, saves debug images and prints additional information. + """ self.segmentation_function = segmentation_function self.debug = debug def segment( self, - raw_image, - segmentation_template, - labels, + raw_image: np.ndarray, + segmentation_template: np.ndarray, + labels: list[dict[str, str]], ) -> dict[str, np.ndarray]: + """Segments a raw image using the class instance's segmentation function. + + Args: + raw_image (np.ndarray): The input image to be segmented, as a NumPy array. + segmentation_template (np.ndarray): A template image used for segmenting the raw image. + labels (list[dict[str, str]]): A list of dicts containing 'label' and 'color' keys, where 'label' is the segment label + and 'color' is the color to match in the segmentation template. + + Returns: + dict[str, np.ndarray]: A dictionary where keys are segment labels and values are the cropped segmented images. + """ return self.segmentation_function(raw_image, segmentation_template, labels, self.debug) - def load_and_segment(self, raw_image_path, segmentation_template_path, labels_path): + def load_and_segment( + self, raw_image_path: str, segmentation_template_path: str, labels_path: str + ) -> dict[str, np.ndarray]: + """Loads image files and labels from specified paths, and then segments the image. + + Args: + raw_image_path (str): Path to the raw image file. + segmentation_template_path (str): Path to the segmentation template image. + labels_path (str): Path to the JSON file containing the segment labels and colors. + + Returns: + dict[str, np.ndarray]: A dictionary where keys are segment labels and values are the cropped segmented images. + + Raises: + FileNotFoundError: If any of the input files do not exist. + ValueError: If an image file cannot be opened. + """ if ( not os.path.isfile(raw_image_path) or not os.path.isfile(segmentation_template_path) diff --git a/OCR/ocr/services/metrics_analysis.py b/OCR/ocr/services/metrics_analysis.py index a6c7109c..1370c49c 100644 --- a/OCR/ocr/services/metrics_analysis.py +++ b/OCR/ocr/services/metrics_analysis.py @@ -1,3 +1,5 @@ +"""Module to calculate OCR metrics and compare results to ground truth data.""" + import json import csv import Levenshtein @@ -5,34 +7,65 @@ class OCRMetrics: - """ - A class to calculate and manage OCR metrics. + """Class to calculate and manage OCR metrics, comparing OCR results with ground truth data. + + This class computes various OCR performance metrics such as raw distance, Hamming distance, and Levenshtein distance + between the OCR output and the ground truth. It also calculates overall accuracy and confidence metrics. + Attributes: + ocr_json (dict): A dict from JSON data containing OCR results. + ground_truth_json (dict): A dict from JSON data containing the ground truth values. """ - def __init__( - self, ocr_json_path=None, ground_truth_json_path=None, ocr_json=None, ground_truth_json=None, testMode=False - ): + def __init__(self, ocr_json_path=None, ground_truth_json_path=None, ocr_json=None, ground_truth_json=None): + """Initializes the OCRMetrics object with OCR and ground truth data, either loaded from files or provided as dictionaries. + + Args: + ocr_json_path (str, optional): Path to the OCR JSON file. + ground_truth_json_path (str, optional): Path to the ground truth JSON file. + ocr_json (dict, optional): The OCR results as a dictionary. + ground_truth_json (dict, optional): The ground truth values as a dictionary. + + Raises: + ValueError: If both file paths and dictionaries are provided. """ - Parameters: - ocr_json (dict): The JSON data extracted from OCR. - ground_truth_json (dict): The JSON data containing ground truth. + if ocr_json and ocr_json_path: + raise ValueError("Cannot specify both OCR results dict and JSON path!") + if ground_truth_json and ground_truth_json_path: + raise ValueError("Cannot specify both OCR ground truth dict and JSON path!") + + if ocr_json_path: + ocr_json = self.load_json_file(ocr_json_path) + if ground_truth_json_path: + ground_truth_json = self.load_json_file(ground_truth_json_path) + + self.ocr_json = ocr_json + self.ground_truth_json = ground_truth_json + + def load_json_file(self, file_path: str) -> dict | None: + """Loads JSON data from a file. + + Args: + file_path (str): Path to the JSON file. + + Returns: + dict: Parsed JSON data as a dictionary, or None if the file path was not passed. """ - if testMode: - self.ocr_json = ocr_json - self.ground_truth_json = ground_truth_json - else: - self.ocr_json = self.load_json_file(ocr_json_path) - self.ground_truth_json = self.load_json_file(ground_truth_json_path) - - def load_json_file(self, file_path): if file_path: with open(file_path, "r") as file: data = json.load(file) return data @staticmethod - def normalize(text): + def normalize(text: str | None) -> str: + """Normalizes text by stripping whitespace, converting to lowercase, and collapsing multiple spaces. + + Args: + text (str or None): The input text to normalize. + + Returns: + str: The normalized text. + """ if text is None: return "" @@ -41,20 +74,58 @@ def normalize(text): return " ".join(text.strip().lower().split()) @staticmethod - def raw_distance(ocr_text, ground_truth): + def raw_distance(ocr_text: str, ground_truth: str) -> int: + """Calculates a raw distance between text strings by the difference in length. + + Args: + ocr_text (str): The OCR-generated text. + ground_truth (str): The ground truth text. + + Returns: + int: The difference in length between the OCR text and ground truth. + """ return len(ground_truth) - len(ocr_text) @staticmethod - def hamming_distance(ocr_text, ground_truth): + def hamming_distance(ocr_text: str, ground_truth: str) -> int: + """Calculates the Hamming distance between two strings, assuming they have the same length. + + Args: + ocr_text (str): The OCR-generated text. + ground_truth (str): The ground truth text. + + Returns: + int: The number of positions where the characters differ. + + Raises: + ValueError: If the strings are not of the same length. + """ if len(ocr_text) != len(ground_truth): raise ValueError("Strings must be of the same length to calculate Hamming distance.") return Levenshtein.hamming(ocr_text.upper(), ground_truth.upper()) @staticmethod - def levenshtein_distance(ocr_text, ground_truth): + def levenshtein_distance(ocr_text: str, ground_truth: str) -> int: + """Calculates the Levenshtein distance between two strings. + + Args: + ocr_text (str): The OCR-generated text. + ground_truth (str): The ground truth text. + + Returns: + int: The minimum number of single-character edits required to change one string into the other. + """ return Levenshtein.distance(ocr_text.upper(), ground_truth.upper()) - def extract_values_from_json(self, json_data): + def extract_values_from_json(self, json_data: dict) -> dict: + """Extracts and normalizes the values from JSON data. + + Args: + json_data (dict): The input JSON data. + + Returns: + dict: A dictionary containing normalized data. + """ if json_data is None: return {} extracted_values = {} @@ -74,7 +145,17 @@ def extract_values_from_json(self, json_data): return extracted_values - def calculate_metrics(self): + def calculate_metrics(self) -> list[dict[str, str | float | int]]: + """Calculates OCR performance metrics for each key-value pair in the ground truth data. + + This compares the OCR output to ground-truth data and calculates: + * Raw distance + * Hamming distance + * Levenshtein distance + + Returns: + list: A list of dictionaries containing, for each key, the OCR output, ground truth data, confidence, and distance metrics. + """ ocr_values = self.extract_values_from_json(self.ocr_json) ground_truth_values = self.extract_values_from_json(self.ground_truth_json) metrics = [] @@ -105,7 +186,15 @@ def calculate_metrics(self): return metrics @staticmethod - def total_metrics(metrics): + def total_metrics(metrics: dict) -> dict[str, int | float]: + """Summarizes many OCR metrics and calculates total distances and accuracy. + + Args: + metrics (list): A list of dictionaries containing individual metrics for each key. + + Returns: + dict: A dictionary containing summary metrics. + """ total_raw_distance = sum(item["raw_distance"] for item in metrics if isinstance(item["raw_distance"], int)) total_levenshtein_distance = sum( item["levenshtein_distance"] for item in metrics if isinstance(item["levenshtein_distance"], int) @@ -134,7 +223,14 @@ def total_metrics(metrics): } @staticmethod - def save_metrics_to_csv(metrics, total_metrics, file_path): + def save_metrics_to_csv(metrics: list, total_metrics: dict, file_path: str) -> None: + """Saves OCR metrics and summarized metrics to a CSV file. + + Args: + metrics (list): A list of dictionaries containing individual metrics. + total_metrics (dict): A dictionary containing the summary metrics. + file_path (str): The path where metrics will be saved as a CSV file. + """ metric_keys = metrics[0].keys() total_metric_keys = total_metrics.keys() diff --git a/OCR/ocr/services/phdc_converter/builder.py b/OCR/ocr/services/phdc_converter/builder.py index c2aeba85..a78def7f 100644 --- a/OCR/ocr/services/phdc_converter/builder.py +++ b/OCR/ocr/services/phdc_converter/builder.py @@ -14,22 +14,19 @@ class PHDC: - """ - A class to represent a Public Health Data Container (PHDC) document given a + """A class to represent a Public Health Data Container (PHDC) document given a PHDCBuilder. """ def __init__(self, data: ET.ElementTree = None): - """ - Initializes the PHDC class with a PHDCBuilder. + """Initializes the PHDC class with a PHDCBuilder. :param builder: The PHDCBuilder to use to build the PHDC. """ self.data = data def to_xml_string(self) -> bytes: - """ - Return a string representation of the PHDC XML document as serialized bytes. + """Return a string representation of the PHDC XML document as serialized bytes. :return: The PHDC XML document as serialized bytes. """ @@ -43,8 +40,7 @@ def to_xml_string(self) -> bytes: ).decode() def to_element_tree(self) -> ET.ElementTree: - """ - Return the PHDC XML document as an ElementTree. + """Return the PHDC XML document as an ElementTree. :return: The PHDC XML document as an ElementTree. """ @@ -54,30 +50,22 @@ def to_element_tree(self) -> ET.ElementTree: class PHDCBuilder: - """ - A builder class for creating PHDC documents. - """ + """A builder class for creating PHDC documents.""" def __init__(self): - """ - Initializes the PHDCBuilder class and create and empty PHDC. - """ - + """Initializes the PHDCBuilder class and create and empty PHDC.""" self.input_data: PHDCInputData = None self.phdc = self._build_base_phdc() def set_input_data(self, input_data: PHDCInputData): - """ - Given a PHDCInputData object, set the input data for the PHDCBuilder. + """Given a PHDCInputData object, set the input data for the PHDCBuilder. :param input_data: The PHDCInputData object to use as input data. """ - self.input_data = input_data def _build_base_phdc(self) -> ET.ElementTree: - """ - Create the base PHDC XML document. + """Create the base PHDC XML document. :return: The base PHDC XML document. """ @@ -107,8 +95,7 @@ def _build_base_phdc(self) -> ET.ElementTree: return clinical_document def _get_type_id(self) -> ET.Element: - """ - Creates the type ID element of the PHDC header. + """Creates the type ID element of the PHDC header. :return: XML element of . """ @@ -118,8 +105,7 @@ def _get_type_id(self) -> ET.Element: return type_id def _get_id(self) -> ET.Element: - """ - Creates the ID element of the PHDC header. + """Creates the ID element of the PHDC header. :return: XML element of . """ @@ -129,8 +115,7 @@ def _get_id(self) -> ET.Element: return id def _get_effective_time(self) -> ET.Element: - """ - Creates the effectiveTime element of the PHDC header. + """Creates the effectiveTime element of the PHDC header. :return: XML element of . """ @@ -141,8 +126,7 @@ def _get_effective_time(self) -> ET.Element: def _get_confidentiality_code( self, confidentiality: Literal["normal", "restricted", "very restricted"] ) -> ET.Element: - """ - Creates the confidentialityCode element of the PHDC header. + """Creates the confidentialityCode element of the PHDC header. :param confidentiality: The confidentiality code to use. :return: XML element of . @@ -159,19 +143,16 @@ def _get_confidentiality_code( return confidentiality_code def _get_realmCode(self) -> ET.Element: - """ - Creates the realmCode element of the PHDC header. + """Creates the realmCode element of the PHDC header. :return: XML element of . """ - realmCode = ET.Element("realmCode") realmCode.set("code", "US") return realmCode def _get_clinical_info_code(self) -> ET.Element: - """ - Creates the code element of the header for a PHDC case report. + """Creates the code element of the header for a PHDC case report. :return: XML element of . """ @@ -183,8 +164,7 @@ def _get_clinical_info_code(self) -> ET.Element: return code def _get_title(self) -> ET.Element: - """ - Creates the title element of the PHDC header. + """Creates the title element of the PHDC header. :return: XML element of . """ @@ -193,8 +173,7 @@ def _get_title(self) -> ET.Element: return title def _get_setId(self) -> ET.Element: - """ - Creates the setId element of the PHDC header. + """Creates the setId element of the PHDC header. :return: XML element of <setId>. """ @@ -204,8 +183,7 @@ def _get_setId(self) -> ET.Element: return setid def _get_version_number(self) -> ET.Element: - """ - Returns the versionNumber element of the PHDC header. + """Returns the versionNumber element of the PHDC header. :return: XML element of <versionNumber>. """ @@ -217,9 +195,7 @@ def _get_version_number(self) -> ET.Element: return version_number def build_header(self): - """ - Builds the header of the PHDC document. - """ + """Builds the header of the PHDC document.""" root = self.phdc.getroot() root.append(self._get_realmCode()) root.append(self._get_type_id()) @@ -266,8 +242,7 @@ def _add_observations_to_section( section: ET.Element, data: ET.Element, ) -> ET.Element: - """ - Adds Clinical Observation and Social History Information observations to the + """Adds Clinical Observation and Social History Information observations to the appropriate section. :param section: Section XML element. @@ -283,8 +258,7 @@ def _add_observations_to_section( return section def _build_clinical_info(self) -> ET.Element: - """ - Builds the `ClinicalInformation` XML element, including all hardcoded aspects + """Builds the `ClinicalInformation` XML element, including all hardcoded aspects required to initialize the section. :param observation_data: List of clinical-relevant Observation data. @@ -316,8 +290,7 @@ def _build_clinical_info(self) -> ET.Element: return component def _build_social_history_info(self) -> ET.Element: - """ - Builds the Social History Information XML section, including all hardcoded + """Builds the Social History Information XML section, including all hardcoded aspects required to initialize the section. :return: XML element of SocialHistory data. """ @@ -351,8 +324,7 @@ def _build_social_history_info(self) -> ET.Element: return component def _build_repeating_questions(self) -> ET.Element: - """ - Builds the Repeating Questions XML section, including all hardcoded + """Builds the Repeating Questions XML section, including all hardcoded aspects required to initialize the section. :return: XML element of Repeating Questions data. """ @@ -413,8 +385,7 @@ def _build_repeating_questions(self) -> ET.Element: return component_section def _build_telecom(self, telecom: Telecom) -> ET.Element: - """ - Builds a `telecom` XML element for phone data including phone number (as + """Builds a `telecom` XML element for phone data including phone number (as `value`) and use, if available. There are three types of phone uses: 'HP' for home phone, 'WP' for work phone, and 'MC' for mobile phone. @@ -442,8 +413,7 @@ def _build_telecom(self, telecom: Telecom) -> ET.Element: return telecom_data def _add_field(self, parent_element: ET.Element, data: str, field_name: str): - """ - Adds a child element to a parent element given the data and field name. + """Adds a child element to a parent element given the data and field name. :param parent_element: The parent element to add the child element to. :param data: The data to add to the child element. @@ -455,8 +425,7 @@ def _add_field(self, parent_element: ET.Element, data: str, field_name: str): parent_element.append(e) def _build_observation(self, observation: Observation) -> ET.Element: - """ - Creates Entry XML element for observation data. + """Creates Entry XML element for observation data. :param observation: The data for building the observation element as an Entry object. @@ -496,8 +465,7 @@ def _build_observation(self, observation: Observation) -> ET.Element: return observation_data def _set_value_xsi_type(self, observation: Observation) -> Observation: - """ - Ensure that observation elements with a value child element use + """Ensure that observation elements with a value child element use the correct namespace based on the data. :param observation: The observation data being used in _build_observation @@ -540,8 +508,7 @@ def _build_addr( self, address: Address, ) -> ET.Element: - """ - Builds an `addr` XML element for address data. There are two types of address + """Builds an `addr` XML element for address data. There are two types of address uses: 'H' for home address and 'WP' for workplace address. :param address: The data for building the address element as an Address object. @@ -563,13 +530,11 @@ def _build_addr( return address_data def _build_name(self, name: Name) -> ET.Element: - """ - Builds a `name` XML element for name data. + """Builds a `name` XML element for name data. :param name: The data for constructing the name element as a Name object. :return: XML element of name data. """ - name_data = ET.Element("name") if name.type is not None: @@ -591,8 +556,7 @@ def _build_name(self, name: Name) -> ET.Element: return name_data def _build_patient(self, patient: Patient) -> ET.Element: - """ - Given a Patient object, build the patient element of the PHDC. + """Given a Patient object, build the patient element of the PHDC. :param patient: The Patient object to use for building the patient element. :return: XML element of patient data. @@ -671,8 +635,7 @@ def _build_recordTarget( address_data: Optional[List[Address]] = None, patient_data: Optional[Patient] = None, ) -> ET.Element: - """ - Builds a `recordTarget` XML element for recordTarget data, which refers to + """Builds a `recordTarget` XML element for recordTarget data, which refers to the medical record of the patient. :param id: recordTarget identifier @@ -731,8 +694,7 @@ def _build_recordTarget( return recordTarget_data def build(self) -> PHDC: - """ - Constructs a PHDC document by building the header and body components. + """Constructs a PHDC document by building the header and body components. :return: A PHDC document as an instance of the PHDC class. """ diff --git a/OCR/ocr/services/phdc_converter/models.py b/OCR/ocr/services/phdc_converter/models.py index 792c7855..4069d956 100644 --- a/OCR/ocr/services/phdc_converter/models.py +++ b/OCR/ocr/services/phdc_converter/models.py @@ -9,9 +9,7 @@ @dataclass class Telecom: - """ - A class containing all of the data elements for a telecom element. - """ + """A class containing all of the data elements for a telecom element.""" value: Optional[str] = None type: Optional[str] = None @@ -21,9 +19,7 @@ class Telecom: @dataclass class Address: - """ - A class containing all of the data elements for an address element. - """ + """A class containing all of the data elements for an address element.""" street_address_line_1: Optional[str] = None street_address_line_2: Optional[str] = None @@ -39,9 +35,7 @@ class Address: @dataclass class Name: - """ - A class containing all of the data elements for a name element. - """ + """A class containing all of the data elements for a name element.""" prefix: Optional[str] = None first: Optional[str] = None @@ -55,9 +49,7 @@ class Name: @dataclass class Patient: - """ - A class containing all of the data elements for a patient element. - """ + """A class containing all of the data elements for a patient element.""" name: List[Name] = None address: List[Address] = None @@ -70,9 +62,7 @@ class Patient: @dataclass class Organization: - """ - A class containing all of the data elements for an organization element. - """ + """A class containing all of the data elements for an organization element.""" id: str = None name: str = None @@ -82,9 +72,7 @@ class Organization: @dataclass class CodedElement: - """ - A class containing all of the data elements for a coded element. - """ + """A class containing all of the data elements for a coded element.""" xsi_type: Optional[str] = None code: Optional[str] = None @@ -95,8 +83,7 @@ class CodedElement: text: Optional[Union[str, int]] = None def to_attributes(self) -> Dict[str, str]: - """ - Given a standard CodedElements return a dictionary that can be iterated over to + """Given a standard CodedElements return a dictionary that can be iterated over to produce the corresponding XML element. :return: A dictionary of the CodedElement's attributes @@ -116,9 +103,7 @@ def to_attributes(self) -> Dict[str, str]: @dataclass class Observation: - """ - A class containing all of the data elements for an observation element. - """ + """A class containing all of the data elements for an observation element.""" obs_type: str = "laboratory" type_code: Optional[str] = None @@ -154,8 +139,7 @@ class Observation: @dataclass class PHDCInputData: - """ - A class containing all of the data to construct a PHDC document when passed to the + """A class containing all of the data to construct a PHDC document when passed to the PHDCBuilder. """ diff --git a/OCR/ocr/services/phdc_converter/phdc_converter.py b/OCR/ocr/services/phdc_converter/phdc_converter.py index 7a51d7b4..8d92735d 100644 --- a/OCR/ocr/services/phdc_converter/phdc_converter.py +++ b/OCR/ocr/services/phdc_converter/phdc_converter.py @@ -3,9 +3,7 @@ class PHDCConverter: - """ - Parse the OCR data converted to json to create an instance of the Patient data class. - """ + """Parse the OCR data converted to json to create an instance of the Patient data class.""" def parse_patient_data(self, json_data): name = Name(first=json_data.get("patient_first_name", ""), family=json_data.get("patient_last_name", "")) @@ -31,9 +29,7 @@ def parse_patient_data(self, json_data): return patient def generate_phdc_document(self, json_data): - """ - Generate the PHDC document using parsed OCR data. - """ + """Generate the PHDC document using parsed OCR data.""" patient = self.parse_patient_data(json_data) phdc_input = PHDCInputData(patient=patient, type="case_report") diff --git a/OCR/ocr/services/tesseract_ocr.py b/OCR/ocr/services/tesseract_ocr.py index c796a065..933007e5 100644 --- a/OCR/ocr/services/tesseract_ocr.py +++ b/OCR/ocr/services/tesseract_ocr.py @@ -1,3 +1,5 @@ +"""Module for OCR services using a Tesseract backend.""" + import os import tesserocr @@ -7,30 +9,46 @@ class TesseractOCR: + """A class to provide OCR services using Tesseract as the backend. + + This class supports configuring Tesseract's page segmentation modes and customizing its behavior + through internal variables. + + Attributes: + psm (int): The page segmentation mode for Tesseract, specifying how Tesseract interprets the structure of the document. + variables (dict): A dictionary of variables to customize Tesseract's behavior. + + See Also: + * https://github.com/sirfz/tesserocr/blob/bbe0fb8edabdcc990f1e6fa9334c0747c2ac76ee/tesserocr/__init__.pyi#L47 + * https://tesseract-ocr.github.io/tessdoc/tess3/ControlParams.html + """ + def __init__(self, psm=PSM.AUTO, variables=dict()): - """ - Initialize the tesseract OCR model. + """Initializes the TesseractOCR object with the specified page segmentation mode and internal variables. - `psm` (int): an enum (from `PSM`) that defines tesseract's page segmentation mode. Default is `AUTO`. - `variables` (dict): a dict to customize tesseract's behavior with internal variables + Args: + psm (int, optional): The page segmentation mode (from `tesserocr.PSM`). Default is `PSM.AUTO`. + variables (dict, optional): A dictionary of variables to customize Tesseract's behavior. Default is an empty dictionary. """ self.psm = psm self.variables = variables @staticmethod def _guess_tessdata_path(wanted_lang="eng") -> bytes: - """ - Attempts to guess potential locations for the `tessdata` folder. + """Attempts to guess potential locations for the `tessdata` folder. The `tessdata` folder is needed to use pre-trained Tesseract OCR data, though the automatic detection - provided in `tesserocr` may not be reliable. Instead iterate over common paths on various systems (e.g., - Red Hat, Ubuntu, macOS) and check for the presence of a `tessdata` folder. + provided in `tesserocr` may not be reliable. - If `TESSDATA_PREFIX` is available in the environment, this function will check that location first. - If all guessed locations do not exist, fall back to automatic detection provided by `tesserocr` and - the tesseract API. + The function first checks the path defined by the environment variable `TESSDATA_PREFIX` (if available), + and then falls back to searching several default candidate paths on various systems (e.g., Red Hat, Ubuntu, + macOS). If no valid path is found, it uses the automatic detection provided by the Tesseract API, which may fail. - `wanted_lang` (str): a desired language to search for. Defaults to English `eng`. + Args: + wanted_lang (str, optional): The desired language to search for in the `tessdata` folder. Default is 'eng' (English). + + Returns: + bytes: The path to the `tessdata` directory containing the OCR language files. """ candidate_paths = [ "/usr/local/share/tesseract/tessdata", @@ -63,6 +81,21 @@ def _guess_tessdata_path(wanted_lang="eng") -> bytes: return tesserocr.get_languages()[0] def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, float]]: + """Converts image segments into text using Tesseract OCR. + + The function processes a dictionary of image segments, where each key corresponds to a segment label, + and each value is a NumPy array representing an image segment. + + For each segment, it extracts the text and the average confidence score returned from the Tesseract API. + + Args: + segments (dict[str, np.ndarray]): A dictionary where keys are segment labels (e.g., 'name', 'date'), + and values are NumPy arrays representing the corresponding image segments. + + Returns: + dict[str, tuple[str, float]]: A dictionary where each key corresponds to a segment label, and each value is + a tuple containing the OCR result (string) and the confidence score (float). + """ digitized: dict[str, tuple[str, float]] = {} with tesserocr.PyTessBaseAPI(psm=self.psm, variables=self.variables, path=self._guess_tessdata_path()) as api: for label, image in segments.items(): diff --git a/OCR/pyproject.toml b/OCR/pyproject.toml index 1380120a..0eefa3c1 100644 --- a/OCR/pyproject.toml +++ b/OCR/pyproject.toml @@ -38,3 +38,23 @@ build = "ocr.pyinstaller:install" [tool.ruff] line-length = 118 target-version = "py310" + +[tool.ruff.lint] +select = ["D"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.per-file-ignores] +# Ignore test directories +"tests/**" = ["D"] + +# Ignore module-level init.py +"__init__.py" = ["D"] + +# Ignore CLI entry points +"main.py" = ["D"] +"**_main.py" = ["D"] + +# Ignore phdc_converter (already documented in a different style) +"ocr/services/phdc_converter/**" = ["D"] diff --git a/OCR/tests/assets/ocr_metrics_test.py b/OCR/tests/assets/ocr_metrics_test.py index d5d2f3a1..7b03cc7b 100644 --- a/OCR/tests/assets/ocr_metrics_test.py +++ b/OCR/tests/assets/ocr_metrics_test.py @@ -9,7 +9,7 @@ def ocr_metrics(): {"key": "Name", "value": "John Doe"}, {"key": "Date of Birth", "value": "1990-01-01"}, ] - return OCRMetrics(ocr_json=ocr_data, ground_truth_json=ground_truth_data, testMode=True) + return OCRMetrics(ocr_json=ocr_data, ground_truth_json=ground_truth_data) def test_calculate_metrics(ocr_metrics):