Skip to content

Commit

Permalink
Use AWS Textract for OCR (#935)
Browse files Browse the repository at this point in the history
* Use AWS Textract for OCR

* Use the boto3 Textract client directly

* Add boto3 requirement

* Remove pytesseract references
  • Loading branch information
DhruvaBansal00 authored Nov 14, 2024
1 parent 905845e commit 9a3edbb
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 70 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ cohere = [
"cohere>=4.11.2"
]
minimal = [
"boto3==1.33.3",
"openai==1.45.0",
"langchain==0.2.16",
"langchain-anthropic==0.1.23",
Expand All @@ -108,6 +109,7 @@ minimal = [
]
all = [
"black",
"boto3==1.33.3",
"bumpver",
"pip-tools",
"pytest",
Expand Down
100 changes: 76 additions & 24 deletions src/autolabel/transforms/image.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,91 @@
from typing import Dict, Any, List
"""Extract text from images using OCR."""

from __future__ import annotations

from typing import Any, ClassVar

from autolabel.transforms.schema import TransformType
from autolabel.transforms import BaseTransform
from autolabel.cache import BaseCache
from autolabel.transforms import BaseTransform
from autolabel.transforms.schema import TransformType


class ImageTransform(BaseTransform):
"""This class is used to extract text from images using OCR. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'
"""Extract text from images using OCR.
This transform supports the following image formats: PNG, JPEG, TIFF, JPEG 2000, GIF, WebP, BMP, and PNM
The output columns dictionary for this class should include the keys 'content_column'
and 'metadata_column'.
This transform supports the following image formats: PNG, JPEG, TIFF, JPEG 2000, GIF,
WebP, BMP, and PNM.
"""

COLUMN_NAMES = [
COLUMN_NAMES: ClassVar[list[str]] = [
"content_column",
"metadata_column",
]

def __init__(
self,
cache: BaseCache,
output_columns: Dict[str, Any],
output_columns: dict[str, Any],
file_path_column: str,
lang: str = None,
lang: str | None = None,
) -> None:
"""Initialize the ImageTransform.
Args:
cache: Cache instance to use
output_columns: Dictionary mapping output column names
file_path_column: Column containing image file paths
lang: Optional language for OCR
"""
super().__init__(cache, output_columns)
self.file_path_column = file_path_column
self.lang = lang

try:
from PIL import Image
import pytesseract
from PIL import Image

self.Image = Image
self.pytesseract = pytesseract
self.pytesseract.get_tesseract_version()
except ImportError:
raise ImportError(
"pillow and pytesseract are required to use the image transform with ocr. Please install pillow and pytesseract with the following command: pip install pillow pytesseract"
msg = (
"pillow and pytesseract required to use the image transform with ocr"
"Please install pillow and pytesseract with the following command: "
"pip install pillow pytesseract"
)
except EnvironmentError:
raise EnvironmentError(
"The tesseract engine is required to use the image transform with ocr. Please see https://tesseract-ocr.github.io/tessdoc/Installation.html for installation instructions."
raise ImportError(msg) from None
except OSError:
msg = (
"The tesseract engine is required to use the image transform with ocr. "
"Please see https://tesseract-ocr.github.io/tessdoc/Installation.html "
"for installation instructions."
)
raise OSError(msg) from None

@staticmethod
def name() -> str:
"""Get transform name.
Returns:
Transform type name
"""
return TransformType.IMAGE

def get_image_metadata(self, file_path: str):
def get_image_metadata(self, file_path: str) -> dict[str, Any]:
"""Get metadata from image file.
Args:
file_path: Path to image file
Returns:
Dictionary of image metadata
"""
try:
image = self.Image.open(file_path)
metadata = {
Expand All @@ -59,20 +97,22 @@ def get_image_metadata(self, file_path: str):
"exif": image._getexif(), # Exif metadata
}
return metadata
except Exception as e:
return {"error": str(e)}
except Exception as exc:
return {"error": str(exc)}

async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""This function transforms an image into text using OCR.
async def _apply(self, row: dict[str, Any]) -> dict[str, Any]:
"""Transform an image into text using OCR.
Args:
row (Dict[str, Any]): The row of data to be transformed.
row: The row of data to be transformed
Returns:
Dict[str, Any]: The dict of output columns.
Dictionary of output columns
"""
content = self.pytesseract.image_to_string(
row[self.file_path_column], lang=self.lang
row[self.file_path_column],
lang=self.lang,
)
metadata = self.get_image_metadata(row[self.file_path_column])
transformed_row = {
Expand All @@ -81,12 +121,24 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
}
return transformed_row

def params(self) -> Dict[str, Any]:
def params(self) -> dict[str, Any]:
"""Get transform parameters.
Returns:
Dictionary of parameters
"""
return {
"output_columns": self.output_columns,
"file_path_column": self.file_path_column,
"lang": self.lang,
}

def input_columns(self) -> List[str]:
def input_columns(self) -> list[str]:
"""Get required input columns.
Returns:
List of input column names
"""
return [self.file_path_column]
Loading

0 comments on commit 9a3edbb

Please sign in to comment.