Skip to content

Commit

Permalink
Move Digi-Leap OCR ensemble code to a separate repository
Browse files Browse the repository at this point in the history
  • Loading branch information
rafelafrance committed Oct 9, 2023
1 parent b74724e commit f9a323c
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 139 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
run: |
python -m pip install --upgrade pip setuptools wheel
pip install .
pip install git+https://github.com/rafelafrance/traiter.git@master#egg=traiter
- name: Test with unittest
run: |
python -m unittest discover
export MOCK_DATA=1; python -m unittest discover
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ test:
install: venv
$(PIP_INSTALL) -U pip setuptools wheel
$(PIP_INSTALL) .
$(PIP_INSTALL) git+https://github.com/rafelafrance/traiter.git@master#egg=traiter

dev: venv
source $(VENV)/bin/activate
$(PIP_INSTALL) -U pip setuptools wheel
$(PIP_INSTALL) -e .[dev]
$(PIP_INSTALL) -e ../../traiter/traiter --config-settings editable_mode=strict
pre-commit install

venv:
Expand Down
60 changes: 10 additions & 50 deletions ensemble/ocr_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import textwrap
from pathlib import Path

from pylib import const
from pylib.ocr import ocr_labels
from traiter.pylib import log

Expand All @@ -17,53 +16,27 @@ async def main():


def parse_args() -> argparse.Namespace:
# The current best ensemble
# [[, easyocr], [, tesseract], [deskew, easyocr], [deskew, tesseract],
# [binarize, tesseract], [denoise, tesseract], [pre_process], [post_process]]
description = """OCR images of labels. (Try this ensemble: -RrDdbnPp)"""

arg_parser = argparse.ArgumentParser(
description=textwrap.dedent(description), fromfile_prefix_chars="@"
fromfile_prefix_chars="@",
description=textwrap.dedent(
"""OCR images of labels. (Try this ensemble: -RrDdbnPp)"""
),
)

arg_parser.add_argument(
"--database",
required=True,
"--label-dir",
type=Path,
metavar="PATH",
help="""Path to a digi-leap database.""",
)

arg_parser.add_argument(
"--ocr-set",
required=True,
metavar="NAME",
help="""Name this OCR set.""",
help="""Directory containing the labels to OCR.""",
)

arg_parser.add_argument(
"--label-set",
"--text-dir",
type=Path,
metavar="PATH",
required=True,
metavar="NAME",
help="""Create this label set.""",
)

arg_parser.add_argument(
"--classes",
choices=const.CLASSES[1:],
default=["Typewritten"],
type=str,
nargs="*",
help="""Keep labels if they fall into any of these categories.
(default: %(default)s)""",
)

arg_parser.add_argument(
"--label-conf",
type=float,
default=0.25,
help="""Only OCR labels that have a confidence >= to this. Set it to 0.0 to
get all of the labels. (default: %(default)s)""",
help="""Output OCR text files to this directory.""",
)

arg_parser.add_argument(
Expand Down Expand Up @@ -148,19 +121,6 @@ def parse_args() -> argparse.Namespace:
""",
)

arg_parser.add_argument(
"--notes",
default="",
metavar="TEXT",
help="""Notes about this run. Enclose them in quotes.""",
)

arg_parser.add_argument(
"--limit",
type=int,
help="""Limit to this many labels.""",
)

args = arg_parser.parse_args()
return args

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
"""Build an expedition to determine the quality of OCR output."""
import argparse
import textwrap
from pathlib import Path
Expand All @@ -19,7 +18,7 @@ def main():


def parse_args() -> argparse.Namespace:
description = """Build the 'Is Correction Needed?' expedition"""
description = """Build the "Is Correction Needed?" expedition"""

arg_parser = argparse.ArgumentParser(
description=textwrap.dedent(description), fromfile_prefix_chars="@"
Expand Down
196 changes: 196 additions & 0 deletions ensemble/pylib/box_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""Common functions for bounding boxes.
This module mostly contains variants of bounding box non-maximum suppression (NMS).
"""
import numpy as np
import torch


def iou(box1, box2):
"""Calculate the intersection over union of a pair of boxes.
The boxes are expected to be in [x_min, y_min, x_max, y_max] (pascal) format.
Modified from Matlab code:
https://www.computervisionblog.com/2011/08/blazing-fast-nmsm-from-exemplar-svm.html
"""
# These are inner (overlapping) box dimensions, so we want
# the maximum of the minimums and the minimum of the maximums
x_min = max(box1[0], box2[0])
y_min = max(box1[1], box2[1])
x_max = min(box1[2], box2[2])
y_max = min(box1[3], box2[3])

inter = max(0, x_max - x_min) * max(0, y_max - y_min)
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
return inter / (area1 + area2 - inter)


def find_box_groups(boxes, threshold=0.3, scores=None):
"""Find overlapping sets of bounding boxes.
Groups are by abs() where the positive value indicates the "best" box in the
group and negative values indicate all other boxes in the group.
"""
if len(boxes) == 0:
return np.array([])

if boxes.dtype.kind == "i":
boxes = boxes.astype("float64")

# Simplify access to box components
x0, y0, x1, y1 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

area = np.maximum(0.0, x1 - x0) * np.maximum(0.0, y1 - y0)

idx = scores if scores else area
idx = idx.argsort()

overlapping = np.zeros_like(idx)
group = 0
while len(idx) > 0:
group += 1

# Pop the largest box
curr = idx[-1]
idx = idx[:-1]

overlapping[curr] = group

# Get interior (overlap) coordinates
xx0 = np.maximum(x0[curr], x0[idx])
yy0 = np.maximum(y0[curr], y0[idx])
xx1 = np.minimum(x1[curr], x1[idx])
yy1 = np.minimum(y1[curr], y1[idx])

# Get the intersection over the union (IOU) with the current box
iou_ = np.maximum(0.0, xx1 - xx0) * np.maximum(0.0, yy1 - yy0)
iou_ /= area[idx] + area[curr] - iou_

# Find IOUs larger than threshold & group them
iou_ = np.where(iou_ >= threshold)[0]
overlapping[idx[iou_]] = -group

# Remove all indices in an IOU group
idx = np.delete(idx, iou_)

return overlapping


def overlapping_boxes(boxes, threshold=0.3, scores=None):
"""Group overlapping boxes."""
groups = find_box_groups(boxes, threshold=threshold, scores=scores)
return np.abs(groups)


def nms(boxes, threshold=0.3, scores=None):
"""Remove overlapping boxes via non-maximum suppression."""
groups = find_box_groups(boxes, threshold=threshold, scores=scores)
reduced = np.argwhere(groups > 0).squeeze(1)
return boxes[reduced]


def small_box_overlap(boxes, threshold=0.5):
"""Get overlapping box groups using the intersection over area of the smaller box.
This is analogous to the classic "non-maximum suppression" algorithm except:
1. I return the groups of boxes rather than prune the overlapping ones.
2. The measure is: intersection over the area of the smaller box.
Also see: small_box_suppression().
"""
if len(boxes) == 0:
return np.array([])

if boxes.dtype.kind == "i":
boxes = boxes.astype("float64")

# Simplify access to box components
x0, y0, x1, y1 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

area = np.maximum(0.0, x1 - x0) * np.maximum(0.0, y1 - y0) + 1e-8

idx = area.argsort()

overlapping = np.zeros_like(idx)
group = 0
while len(idx) > 0:
group += 1

# Pop the largest box
curr = idx[-1]
idx = idx[:-1]

overlapping[curr] = group

# Get interior (overlap) coordinates
xx0 = np.maximum(x0[curr], x0[idx])
yy0 = np.maximum(y0[curr], y0[idx])
xx1 = np.minimum(x1[curr], x1[idx])
yy1 = np.minimum(y1[curr], y1[idx])

# Get the intersection as a fraction of the smaller box
inter = np.maximum(0.0, xx1 - xx0) * np.maximum(0.0, yy1 - yy0)
inter /= area[idx]

# Find overlaps larger than threshold & group them
inter = np.where(inter >= threshold)[0]
overlapping[idx[inter]] = group

# Remove all indices in an overlap group
idx = np.delete(idx, inter)

return overlapping


def small_box_suppression(boxes, threshold=0.9, eps=1e-8):
"""Remove overlapping small bounding boxes, analogous to non-maximum suppression.
I can't just remove all small boxes because there are genuinely small labels.
So I use the intersection of the boxes (over the area of the smaller box) to
weed out small boxes that are covered by a bigger bounding box. Just using
non-maximum suppression is not going to work because it uses the intersection
over union (IoU) as a filtering metric and a tiny box contained in a larger box
will not have an IoU over the threshold used for NMS. Using the intersection
over the area of the smaller box gets around the issue.
"""
if boxes.numel() == 0:
return torch.empty((0, 4), dtype=torch.float32)

# Simplify access to box components
x0, y0, x1, y1 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

area1 = torch.maximum(torch.tensor([0.0]), x1 - x0)
area2 = torch.maximum(torch.tensor([0.0]), y1 - y0)
area = area1 * area2 + eps

idx = area.argsort()

keep = torch.zeros(idx.numel(), dtype=torch.bool)

while idx.numel() > 0:
# Pop the largest box
curr = idx[-1]
idx = idx[:-1]

keep[curr] = True

# Get interior (overlap) coordinates
xx0 = torch.maximum(x0[curr], x0[idx])
yy0 = torch.maximum(y0[curr], y0[idx])
xx1 = torch.minimum(x1[curr], x1[idx])
yy1 = torch.minimum(y1[curr], y1[idx])

# Get the intersection as a fraction of the smaller box
inter1 = torch.maximum(torch.tensor([0.0]), xx1 - xx0)
inter2 = torch.maximum(torch.tensor([0.0]), yy1 - yy0)
inter = (inter1 * inter2) / area[idx]

# Find overlaps larger than threshold & delete them
inter = torch.where(inter >= threshold)[0]
mask = torch.ones(idx.numel(), dtype=torch.bool)
mask[inter] = False
idx = idx[mask]

return torch.where(keep)[0]
25 changes: 25 additions & 0 deletions ensemble/pylib/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
import sys
from os.path import basename, splitext


def setup_logger(level=logging.INFO):
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)


def module_name() -> str:
return splitext(basename(sys.argv[0]))[0]


def started() -> None:
setup_logger()
logging.info("=" * 80)
logging.info("%s started", module_name())


def finished() -> None:
logging.info("%s finished", module_name())
5 changes: 2 additions & 3 deletions ensemble/pylib/ocr/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from traiter.pylib.spell_well import SpellWell

from ensemble.pylib.builder import label_builder
from ensemble.pylib.builder.line_align import line_align_py

from ..builder import label_builder
from ..builder.line_align import char_sub_matrix as subs
from ..builder.line_align import line_align_py # noqa
from . import label_transformer as lt
from . import ocr_runner

Expand Down
12 changes: 5 additions & 7 deletions ensemble/pylib/ocr/ocr_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@
from tqdm import tqdm
from traiter.pylib.spell_well import SpellWell

from ensemble import const
from ensemble.pylib.builder import label_builder, line_align_py

from .. import db
from .. import const, db
from ..builder import label_builder, line_align_py
from ..builder.line_align import char_sub_matrix as subs
from . import label_transformer, ocr_runner

IMAGE_TRANSFORMS = ["", "deskew_full", "binarize_full", "denoise_full"]


async def ocr(gold_std):
def ocr(gold_std):
golden = []
for gold in tqdm(gold_std, desc="ocr"):
gold["gold_text"] = " ".join(gold["gold_text"].split())
Expand All @@ -30,10 +28,10 @@ async def ocr(gold_std):
for transform in IMAGE_TRANSFORMS:
image = transform_image(original, transform)

text = await ocr_runner.easy_text(image)
text = ocr_runner.easy_text(image)
gold["pipe_text"][f"[{transform}, easyocr]"] = " ".join(text.split())

text = await ocr_runner.tess_text(image)
text = ocr_runner.tess_text(image)
gold["pipe_text"][f"[{transform}, tesseract]"] = " ".join(text.split())

golden.append(gold)
Expand Down
Loading

0 comments on commit f9a323c

Please sign in to comment.