OCR recognition model #158

Open · wants to merge 7 commits into main
Conversation

@sokovninn (Member) commented Jan 23, 2025

New OCR recognition model, loss, metric and visualizer

The most important changes are summarized below:

Losses:

  • Introduced CTCLoss with optional focal loss weighting in luxonis_train/attached_modules/losses/ctc_loss.py and updated __init__.py to include CTCLoss.
  • Updated luxonis_train/attached_modules/losses/README.md to document CTCLoss.
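Not the PR's actual implementation, but for context: a common focal-CTC formulation weights each sample's CTC loss by (1 - p) ** gamma, where p = exp(-loss) approximates the probability the model assigns to the correct transcription. A minimal sketch in plain Python; the function name and default gamma are assumptions:

```python
import math


def focal_weighted_ctc(per_sample_losses, gamma=2.0):
    """Mean CTC loss with focal weighting.

    p = exp(-loss) is (approximately) the probability the model assigns
    to the correct transcription, so (1 - p) ** gamma down-weights easy,
    already-confident samples; gamma=0 recovers the plain mean.
    """
    weighted = [((1.0 - math.exp(-loss)) ** gamma) * loss for loss in per_sample_losses]
    return sum(weighted) / len(weighted)
```

With gamma=0 this reduces to the ordinary mean CTC loss, which makes the focal term easy to toggle from a config flag.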

Metrics:

  • Added OCRAccuracy metric for OCR tasks in luxonis_train/attached_modules/metrics/ocr_accuracy.py and updated __init__.py to include OCRAccuracy.
  • Updated luxonis_train/attached_modules/metrics/README.md to document OCRAccuracy.
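For intuition (this is a hedged sketch, not the code in ocr_accuracy.py): OCR accuracy metrics typically compare decoded strings against targets, either as exact matches or with a tolerance measured in edit distance. A self-contained plain-Python version, with made-up function names:

```python
def edit_distance(a: str, b: str) -> int:
    """Levenshtein distance via the classic two-row dynamic program."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]


def ocr_accuracy(preds, targets, max_errors=0):
    """Fraction of predictions within `max_errors` edits of their target."""
    hits = sum(edit_distance(p, t) <= max_errors for p, t in zip(preds, targets))
    return hits / len(preds)


preds, targets = ["hello", "w0rld", "ocr"], ["hello", "world", "ocr"]
print(ocr_accuracy(preds, targets))                # 2 exact matches out of 3
print(ocr_accuracy(preds, targets, max_errors=1))  # all within one edit
```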

Visualizers:

  • Introduced OCRVisualizer for visualizing OCR tasks in luxonis_train/attached_modules/visualizers/ocr_visualizer.py and updated __init__.py to include OCRVisualizer.
  • Updated luxonis_train/attached_modules/visualizers/README.md to document OCRVisualizer.

Predefined Models:

  • Added OCRRecognitionModel to luxonis_train/config/predefined_models/__init__.py and updated README.md to document its components and parameters.
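Not part of the diff: predefined models in luxonis_train are typically referenced from a training config roughly along these lines. The exact keys and the model name below are assumptions; the config README in the repo is authoritative:

```yaml
model:
  name: ocr_recognition_example
  predefined_model:
    name: OCRRecognitionModel
```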

Toy dataset creation example

import glob
import os

from tqdm import tqdm


def toy_ocr_generator():
    im_paths = glob.glob("*.png")
    labels = [os.path.splitext(os.path.basename(path))[0] for path in im_paths]
    for path, label in tqdm(zip(im_paths, labels)):
        if label:  # skip images with an empty label
            yield {
                "file": path,
                "annotation": {
                    "metadata": {"text": label, "text_length": len(label)},
                },
            }
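To sanity-check the record format without real images, the same generator logic can be exercised against placeholder files. This is a self-contained sketch: the directory and file names are made up, and tqdm is dropped for brevity:

```python
import glob
import os
import tempfile


def toy_ocr_generator(image_dir):
    # Each image file is named after the text it contains, e.g. "hello.png".
    im_paths = glob.glob(os.path.join(image_dir, "*.png"))
    labels = [os.path.splitext(os.path.basename(p))[0] for p in im_paths]
    for path, label in zip(im_paths, labels):
        if label:  # skip images with an empty label
            yield {
                "file": path,
                "annotation": {"metadata": {"text": label, "text_length": len(label)}},
            }


# Create a few empty placeholder files just to demonstrate the record format.
tmp = tempfile.mkdtemp()
for name in ("hello", "world"):
    open(os.path.join(tmp, f"{name}.png"), "wb").close()

records = sorted(toy_ocr_generator(tmp), key=lambda r: r["annotation"]["metadata"]["text"])
print(records[0]["annotation"]["metadata"])  # {'text': 'hello', 'text_length': 5}
```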

Examples from the overfitted model on the toy dataset

[Attached images: ocr_recognition-OCRCTCHead_OCRVisualizer_0 through _3]

@sokovninn sokovninn requested a review from a team as a code owner January 23, 2025 21:36
@sokovninn sokovninn requested review from kozlov721, klemen1999, tersekmatija and conorsim and removed request for a team January 23, 2025 21:36
@github-actions bot added the documentation and enhancement labels on Jan 23, 2025
@sokovninn (Member, Author) commented Jan 23, 2025

Possible improvements include:

  • Adding more advanced OCR metrics
  • Adding a temporal NRTR head together with NRTRLoss
  • Improving visualization
  • Adding a large variant
  • Improving the encoder to handle more edge cases
  • Adding a beam search decoder
  • Adding OCR-specific augmentations
  • Adding an OCR detection model (same backbone)
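On the beam-search item: the baseline a beam search would improve on is greedy CTC decoding, which takes the per-frame argmax, collapses repeated ids, and drops the blank token. A self-contained sketch; the alphabet, blank index, and id-to-character offset are assumptions:

```python
def ctc_greedy_decode(argmax_ids, alphabet, blank=0):
    """Collapse repeated ids, drop blanks, then map ids to characters."""
    out = []
    prev = None
    for i in argmax_ids:
        if i != prev and i != blank:
            out.append(alphabet[i - 1])  # ids are offset by 1 for the blank
        prev = i
    return "".join(out)


alphabet = "abcdefghijklmnopqrstuvwxyz"
# Per-frame argmax over the class dimension; 0 is the CTC blank.
print(ctc_greedy_decode([3, 3, 0, 1, 0, 20, 20, 0, 19], alphabet))  # -> "cats"
```

Note that a blank between two identical ids is what keeps genuine double letters (e.g. [1, 0, 1] decodes to "aa", while [1, 1] collapses to "a").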

@klemen1999 (Collaborator) left a comment:

Generally LGTM, left some comments. One thing we also want to make sure of is the integration with HubAI and depthai-nodes: the archived model should have correct archive data so that the parser can parse it.

target = [chr(int(char.item())) for char in target]
target = "".join(target)
target_strings.append(target)
print(target_strings)
Collaborator:
Probably not needed (also another print in the forward())


## `OCRRecognitionModel`

FPS of the `OCRRecognitionModel` on different devices with image size 48x320:
Collaborator:
Let's add this under a ### Performance Metrics section to be in line with the other models. Before that, add a very short description; for this model it would be valuable to note how the dataset needs to be structured, i.e. which annotations need to be present.

class OCRRecognitionModel(BasePredefinedModel):
"""A predefined model for OCR recognition tasks."""

def __init__(
Collaborator:
Let's add a variant, even if it is just "light" for now, as it is then simpler to integrate on the HubAI side (same parameter as all other predefined models).

return kernel * t, beta - running_mean * gamma / std # type: ignore


class SELayer(nn.Module):
Collaborator:
I believe this is the same block as SqueezeExciteBlock with approx_sigmoid=True. If so, can we remove it here and replace it with that one?

@sokovninn (Member, Author):
Good point! I didn't notice that they were the same.

logger = logging.getLogger(__name__)


NET_CONFIG_det = {
Collaborator:
IMO we can put this into variants.py (similarly to what we do with e.g. EfficientRep) so the code is a bit cleaner.

from luxonis_train.utils import OCRDecoder, OCREncoder


def get_para_bias_attr(l2_decay: float, k: int):
Collaborator:
Nitpick: move this to the bottom so the class is at the top of the file; IMO cleaner.

class OCRCTCHead(BaseHead[Tensor, Tensor]):
in_channels: int
tasks: list[TaskType] = [TaskType.CLASSIFICATION]

Collaborator:
Does the exported model look the same as PaddleOCR's? Can we use the same parser for it from depthai-nodes (this one)? Ideally we want to support the full integration with HubAI.

return x


class Block(nn.Module):
Collaborator:
Nitpick: Can we name it something more descriptive?

Labels: documentation, enhancement
2 participants