Commit
add docstrings and lint enforcement (#512)
jonchang authored Jan 16, 2025
1 parent 50849d7 commit 0f8d189
Showing 17 changed files with 817 additions and 237 deletions.
7 changes: 2 additions & 5 deletions OCR/benchmark_main.py
@@ -46,11 +46,9 @@ def main():


 def run_segmentation_and_ocr(args):
-    """
-    Runs segmentation and OCR processing.
+    """Runs segmentation and OCR processing.

     Returns OCR results with processing time.
     """
-
     model = None

     if args.model == "tesseract":
@@ -70,8 +68,7 @@ def run_segmentation_and_ocr(args):


 def run_metrics_analysis(args, ocr_results):
-    """
-    Runs metrics analysis based on OCR output and ground truth.
+    """Runs metrics analysis based on OCR output and ground truth.

     Uses OCR results to capture time values if available.
     """
     metrics_analysis = BatchMetricsAnalysis(args.output_folder, args.ground_truth_folder, args.csv_output_folder)
74 changes: 69 additions & 5 deletions OCR/ocr/api.py
@@ -1,3 +1,5 @@
"""Module for FastAPI interface functions."""

import base64

import uvicorn
@@ -34,7 +36,18 @@
 ocr = TesseractOCR()


-def data_uri_to_image(data_uri: str):
+def data_uri_to_image(data_uri: str) -> np.ndarray:
+    """Converts a base64 encoded data URI to an image, represented as a NumPy array.
+
+    Args:
+        data_uri (str): The base64 encoded image in data URI format.
+
+    Returns:
+        np.ndarray: The decoded image in NumPy array format.
+
+    Raises:
+        HTTPException: If the image decoding fails.
+    """
     try:
         base64_data = data_uri.split(",")[1]
         image_data = base64.b64decode(base64_data)
@@ -48,18 +61,40 @@ def data_uri_to_image(data_uri: str):
         )


-def image_to_data_uri(image: np.ndarray):
+def image_to_data_uri(image: np.ndarray) -> bytes:
+    """Converts an image to a base64 encoded data URI.
+
+    Args:
+        image (np.ndarray): The input image in NumPy array format.
+
+    Returns:
+        bytes: The Base64 encoded data URI representation of the image.
+    """
     _, encoded = cv.imencode(".png", image)
     return b"data:image/png;base64," + base64.b64encode(encoded)


@app.get("/")
async def health_check():
"""Health check endpoint to verify the API is running.
Returns:
dict: A dictionary with the status of the service.
"""
return {"status": "UP"}


@app.post("/image_alignment/")
async def image_alignment(source_image: str = Form(), segmentation_template: str = Form()):
async def image_alignment(source_image: str = Form(), segmentation_template: str = Form()) -> dict:
"""Aligns a source image to a segmentation template.
Args:
source_image (str): The base64 encoded source image.
segmentation_template (str): The baSe64 encoded segmentation template image.
Returns:
dict: A dictionary containing the aligned image as a base64 encoded data URI.
"""
source_image_img = data_uri_to_image(source_image)
segmentation_template_img = data_uri_to_image(segmentation_template)

@@ -70,6 +105,19 @@ async def image_alignment(source_image: str = Form(), segmentation_template: str

@app.post("/image_file_to_text/")
async def image_file_to_text(source_image: UploadFile, segmentation_template: UploadFile, labels: str = Form()):
"""Extracts text from an image file based on a segmentation template, using OCR.
Args:
source_image (UploadFile): The uploaded source image file.
segmentation_template (UploadFile): The uploaded segmentation template file.
labels (str): The JSON-encoded labels defining segmentation templates.
Returns:
dict: A dictionary containing the OCR results for the segmented regions.
Raises:
HTTPException: If there are issues with file decoding, parsing, segmentation, or OCR.
"""
try:
source_image_np = np.frombuffer(await source_image.read(), np.uint8)
source_image_img = cv.imdecode(source_image_np, cv.IMREAD_COLOR)
@@ -117,7 +165,20 @@ async def image_file_to_text(source_image: UploadFile, segmentation_template: Up
@app.post("/image_to_text")
async def image_to_text(
source_image: str = Form(...), segmentation_template: str = Form(...), labels: str = Form(...)
):
) -> dict:
"""Extracts text from an image based on a segmentation template, using OCR.
Args:
source_image (str): The base64-encoded source image.
segmentation_template (str): The base64-encoded segmentation template.
labels (str): The JSON-encoded labels defining segmentation templates.
Returns:
dict: A dictionary containing the OCR results for the segmented regions.
Raises:
HTTPException: If there are issues with file decoding, parsing, segmentation, or OCR.
"""
try:
source_image_img = data_uri_to_image(source_image)
segmentation_template_img = data_uri_to_image(segmentation_template)
@@ -149,5 +210,8 @@ async def image_to_text(


 def start():
-    """Launched with `poetry run start` at root level"""
+    """Starts the FastAPI app.
+
+    Launched with `poetry run start` at root level.
+    """
     uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
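For context, a hedged sketch of exercising the endpoints above from a client, assuming the app is running locally via `poetry run start`; the file names and the `requests` client are illustrative, not part of this commit:

import base64

import requests

def to_data_uri(path: str) -> str:
    # hypothetical helper: file bytes -> base64 data URI
    with open(path, "rb") as f:
        return "data:image/png;base64," + base64.b64encode(f.read()).decode("ascii")

assert requests.get("http://localhost:8000/").json() == {"status": "UP"}

resp = requests.post(
    "http://localhost:8000/image_alignment/",
    data={
        "source_image": to_data_uri("filled_form.png"),
        "segmentation_template": to_data_uri("template.png"),
    },
)
aligned = resp.json()  # aligned image as a base64 data URI, per the docstring above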
65 changes: 57 additions & 8 deletions OCR/ocr/services/alignment/backends/four_point_transform.py
@@ -1,6 +1,4 @@
"""
Uses quadrilaterial edge detection and executes a four-point perspective transform on a source image.
"""
"""Uses quadrilaterial edge detection and executes a four-point perspective transform on a source image."""

from pathlib import Path
import functools
@@ -10,19 +8,51 @@


 class FourPointTransform:
+    """A class to perform a four-point perspective transformation on an image.
+
+    This involves detecting the largest quadrilateral in the image and transforming it
+    to a standard rectangular form using a perspective warp.
+
+    Attributes:
+        image (np.ndarray): The input image as a NumPy array.
+    """
+
     def __init__(self, image: Path | np.ndarray):
+        """Initializes the FourPointTransform object with an image.
+
+        The image can either be provided as a file path (Path) or a NumPy array.
+
+        Args:
+            image (Path | np.ndarray): The input image, either as a path to a file or as a NumPy array.
+        """
         if isinstance(image, np.ndarray):
             self.image = image
         else:
             self.image = cv.imread(str(image))

     @classmethod
-    def align(self, source_image, template_image):
+    def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
+        """Aligns a source image to a template image using the four-point transform.
+
+        Args:
+            source_image (np.ndarray): The source image to be aligned.
+            template_image (np.ndarray): The template image to align to.
+
+        Returns:
+            np.ndarray: The transformed image.
+        """
         return FourPointTransform(source_image).dewarp()

     @staticmethod
     def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
-        "Reorder points from a 4x2 input array representing the vertices of a quadrilateral, such that the coordinates of each vertex are arranged in order from top left, top right, bottom right, and bottom left."
+        """Reorders the points of a quadrilateral from an unordered 4x2 array to a specific order of top-left, top-right, bottom-right, and bottom-left.
+
+        Args:
+            quadrilateral (np.ndarray): A 4x2 array representing the vertices of a quadrilateral.
+
+        Returns:
+            np.ndarray: A 4x2 array with the points ordered as [top-left, top-right, bottom-right, bottom-left].
+        """
         quadrilateral = quadrilateral.reshape(4, 2)
         output_quad = np.zeros([4, 2]).astype(np.float32)
         s = quadrilateral.sum(axis=1)
@@ -33,19 +63,38 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
         output_quad[3] = quadrilateral[np.argmax(diff)]
         return output_quad
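The ordering relies on a sum/difference trick: among the four corners, x + y is smallest at the top-left and largest at the bottom-right, while y − x separates top-right from bottom-left. A small worked example on a hypothetical quadrilateral (illustrative only):

import numpy as np

pts = np.array([[10, 90], [90, 90], [90, 10], [10, 10]])  # shuffled corners
s = pts.sum(axis=1)       # x + y
d = np.diff(pts, axis=1)  # y - x
print(pts[np.argmin(s)])  # [10 10] -> top-left
print(pts[np.argmax(s)])  # [90 90] -> bottom-right
print(pts[np.argmin(d)])  # [90 10] -> top-right
print(pts[np.argmax(d)])  # [10 90] -> bottom-left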

-    def find_largest_contour(self):
-        """Compute contours for an image and find the biggest one by area."""
+    def find_largest_contour(self) -> np.ndarray:
+        """Finds the largest contour in the image by computing the contours and selecting the one with the greatest area.
+
+        Returns:
+            np.ndarray: The largest contour found in the image.
+        """
         contours, _ = cv.findContours(
             cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE
         )
         return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours)
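The `functools.reduce` call keeps whichever contour has the larger area; an equivalent, perhaps more familiar idiom (shown for illustration, not what the commit uses):

import cv2 as cv

def largest_contour(contours):
    # same result as the reduce above: the contour with maximum area
    return max(contours, key=cv.contourArea)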

     def simplify_polygon(self, contour):
-        """Simplify to a polygon with (hopefully four) vertices."""
+        """Simplifies a given contour to a polygon with a reduced number of vertices, ideally four.
+
+        Args:
+            contour (np.ndarray): The contour to simplify.
+
+        Returns:
+            np.ndarray: The simplified polygon.
+        """
         perimeter = cv.arcLength(contour, True)
         return cv.approxPolyDP(contour, 0.01 * perimeter, True)

     def dewarp(self) -> np.ndarray:
+        """Performs a four-point perspective transform to "dewarp" the image.
+
+        This involves detecting the largest quadrilateral, simplifying it to a polygon, and
+        applying a perspective warp to straighten the image into a rectangle.
+
+        Returns:
+            np.ndarray: The perspective-transformed (dewarped) image.
+        """
         biggest_contour = self.find_largest_contour()
         simplified = self.simplify_polygon(biggest_contour)

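Taken together, a minimal usage sketch for this class; the input file name is hypothetical, and the import path follows the file path in this diff:

import cv2 as cv

from ocr.services.alignment.backends.four_point_transform import FourPointTransform

page = cv.imread("scanned_form.png")  # hypothetical input scan
dewarped = FourPointTransform(page).dewarp()
cv.imwrite("dewarped.png", dewarped)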
107 changes: 90 additions & 17 deletions OCR/ocr/services/alignment/backends/image_homography.py
@@ -1,12 +1,36 @@
"""Aligns two images using image homography algorithms."""

from pathlib import Path

import numpy as np
import cv2 as cv


 class ImageHomography:
+    """A class to align two images using homography techniques.
+
+    Uses the Scale-Invariant Feature Transform (SIFT) algorithm to detect keypoints and
+    compute descriptors for image matching, and then estimates a homography
+    transformation matrix to align the source image with a template image.
+
+    Attributes:
+        template (np.ndarray): The template image to align against, either as a path or a NumPy array.
+        match_ratio (float): The ratio used for Lowe's ratio test to filter good matches.
+        _sift (cv.SIFT): The SIFT detector used to find keypoints and descriptors.
+    """
+
     def __init__(self, template: Path | np.ndarray, match_ratio=0.3):
-        """Initialize the image homography pipeline with a `template` image."""
+        """Initializes the ImageHomography object with a template image.
+
+        Optionally include a match ratio for filtering descriptor matches; this must be between 0 and 1.
+
+        Args:
+            template (Path | np.ndarray): The template image, either as a file path or a NumPy array.
+            match_ratio (float, optional): The ratio threshold for Lowe's ratio test. Default is 0.3.
+
+        Raises:
+            ValueError: If `match_ratio` is not between 0 and 1.
+        """
         if match_ratio >= 1 or match_ratio <= 0:
             raise ValueError("`match_ratio` must be between 0 and 1")

@@ -18,24 +42,64 @@ def __init__(self, template: Path | np.ndarray, match_ratio=0.3):
         self._sift = cv.SIFT_create()

     @classmethod
-    def align(self, source_image, template_image):
+    def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
+        """Aligns a source image to a template image.
+
+        Args:
+            source_image (np.ndarray): The source image to align.
+            template_image (np.ndarray): The template image to align to.
+
+        Returns:
+            np.ndarray: The aligned source image.
+        """
         return ImageHomography(template_image).transform_homography(source_image)

     def estimate_self_similarity(self):
-        """Calibrate `match_ratio` using a self-similarity metric."""
+        """Calibrates the match ratio using a self-similarity metric (not implemented).
+
+        Raises:
+            NotImplementedError: Since this method is not implemented.
+        """
         raise NotImplementedError

-    def compute_descriptors(self, img):
-        """Compute SIFT descriptors for a target `img`."""
+    def compute_descriptors(self, img: np.ndarray):
+        """Computes the SIFT descriptors for a given image.
+
+        These descriptors represent distinctive features in the image that can be used for matching.
+
+        Args:
+            img (np.ndarray): The image for which to compute descriptors.
+
+        Returns:
+            tuple: A 2-element tuple containing the keypoints and their corresponding descriptors.
+        """
         return self._sift.detectAndCompute(img, None)
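For context, a sketch of what `compute_descriptors` returns, per standard OpenCV SIFT behavior (the file name is hypothetical):

import cv2 as cv

img = cv.imread("template.png", cv.IMREAD_GRAYSCALE)  # hypothetical image
keypoints, descriptors = cv.SIFT_create().detectAndCompute(img, None)
# descriptors is an (N, 128) float32 array, one 128-dim vector per keypoint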

     def knn_match(self, descriptor_template, descriptor_query):
-        """Return k-nearest neighbors match (k=2) between descriptors generated from a template and query image."""
+        """Performs k-nearest neighbors matching (k=2) between descriptors to find the best homography matches.
+
+        Args:
+            descriptor_template (np.ndarray): The SIFT descriptors from the template image.
+            descriptor_query (np.ndarray): The SIFT descriptors from the query image.
+
+        Returns:
+            list: A list of k-nearest neighbor matches between the template and query descriptors.
+        """
         matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED)
         return matcher.knnMatch(descriptor_template, descriptor_query, 2)
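Since the filtering step itself is collapsed in this diff, here is a sketch of how Lowe's ratio test would consume these k=2 matches; the use of the threshold is an assumption based on the documented `match_ratio` default:

def filter_matches(knn_matches, ratio=0.3):
    # keep a match only when it is much closer than its runner-up
    return [m for m, n in knn_matches if m.distance < ratio * n.distance]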

-    def estimate_transform_matrix(self, other):
-        "Estimate the transformation matrix based on homography."
+    def estimate_transform_matrix(self, other: np.ndarray) -> np.ndarray:
+        """Estimates the transformation matrix between the template image and another image.
+
+        This function detects keypoints and descriptors, matches them using k-nearest neighbors,
+        and applies Lowe's ratio test to filter for quality matches.
+
+        Args:
+            other (np.ndarray): The image to estimate the transformation matrix against.
+
+        Returns:
+            np.ndarray: The homography matrix that transforms the other image to align with the template image.
+        """
         # find the keypoints and descriptors with SIFT
         kp1, descriptors1 = self.compute_descriptors(self.template)
         kp2, descriptors2 = self.compute_descriptors(other)
@@ -55,17 +119,26 @@ def estimate_transform_matrix(self, other):
         M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)
         return M

-    def transform_homography(self, other, min_axis=100, matrix=None):
-        """
-        Run the image homography pipeline against a query image.
-
-        Parameters:
-            min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform.
-              If the input image is under the axis limits, return the original input image unchanged.
-            matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be
-              estimated with `estimate_transform_matrix`.
-        """
+    def transform_homography(self, other: np.ndarray, min_axis=100, matrix=None) -> np.ndarray:
+        """Run the full image homography pipeline against a query image.
+
+        If the size of the `other` image is smaller than the minimum axis length `min_axis`,
+        the image is returned unchanged.
+
+        If a transformation matrix is provided, it is used directly; otherwise, the matrix is
+        estimated using `estimate_transform_matrix`.
+
+        Args:
+            other (np.ndarray): The image to be transformed.
+            min_axis (int, optional): The minimum axis length (in pixels) to attempt the homography transform.
+                If the image is smaller, it will be returned unchanged. Default is 100.
+            matrix (np.ndarray, optional): The homography transformation matrix to apply. If not provided,
+                it will be estimated.
+
+        Returns:
+            np.ndarray: The transformed image if homography was applied, or the original image if it is
+                smaller than the minimum axis size.
+        """
         if other.shape[0] < min_axis and other.shape[1] < min_axis:
             return other

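And a usage sketch for the full homography pipeline; the file names are hypothetical, and the import path follows the file path in this diff:

import cv2 as cv

from ocr.services.alignment.backends.image_homography import ImageHomography

template = cv.imread("segmentation_template.png")  # hypothetical files
scan = cv.imread("filled_form.png")
aligned = ImageHomography(template, match_ratio=0.3).transform_homography(scan)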
