fix: Fixed PyTorch tensor cropping and extended script support (#458)
* feat: Added support of PyTorch to the evaluation script

* fix: Fixed cropping for PyTorch image tensors

* fix: Fixed evaluation when metric is undefined

* chore: Added CI job to run evaluation with PyTorch

* fix: Fixed edge case of double framework availability

* feat: Added support of PyTorch to the analysis script

* chore: Refactored script CI jobs

* feat: Avoids doing inference on zero-sized crops

* fix: Fixed issue when no crop is detected
fg-mindee authored Sep 6, 2021
1 parent 3e7e9de commit ad74838
Showing 5 changed files with 140 additions and 38 deletions.
49 changes: 43 additions & 6 deletions .github/workflows/scripts.yml
@@ -14,6 +14,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest]
python: [3.7, 3.8]
framework: [tensorflow, pytorch]
steps:
- if: matrix.os == 'macos-latest'
name: Install MacOS prerequisites
@@ -24,7 +25,8 @@
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
@@ -34,10 +36,27 @@
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- name: Install dependencies
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('**/*.py') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- if: matrix.framework == 'tensorflow'
name: Install package (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
- name: Run analysis script
run: |
@@ -51,6 +70,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest]
python: [3.7, 3.8]
framework: [tensorflow, pytorch]
steps:
- if: matrix.os == 'macos-latest'
name: Install MacOS prerequisites
@@ -61,7 +81,8 @@
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
@@ -71,13 +92,29 @@
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- name: Install dependencies
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('**/*.py') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- if: matrix.framework == 'tensorflow'
name: Install package (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- name: Run evaluation script
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
- name: Run evaluation script
run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10

test-collectenv:
runs-on: ${{ matrix.os }}
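The matrix now runs each script job once per framework, installing either the `[tf]` or `[torch]` extra. As a quick local sanity check (a sketch using doctr's own helpers, not part of the workflow), one can confirm which backend an editable install exposes before running the scripts:

from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the CI split: each job is expected to expose exactly one backend
print("TensorFlow available:", is_tf_available())
print("PyTorch available:", is_torch_available())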
21 changes: 16 additions & 5 deletions doctr/models/_utils.py
@@ -12,13 +12,14 @@
__all__ = ['estimate_orientation', 'extract_crops', 'extract_rcrops', 'get_bitmap_angle']


def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]:
"""Created cropped images from list of bounding boxes
Args:
img: input image
boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
coordinates (xmin, ymin, xmax, ymax)
channels_last: whether the channel dimension is the last one (instead of the first one)
Returns:
list of cropped images
@@ -36,16 +37,26 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
_boxes = _boxes.round().astype(int)
# Add last index
_boxes[2:] += 1
return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
if channels_last:
return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
else:
return [img[:, box[1]: box[3], box[0]: box[2]] for box in _boxes]


def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List[np.ndarray]:
def extract_rcrops(
img: np.ndarray,
boxes: np.ndarray,
dtype=np.float32,
channels_last: bool = True
) -> List[np.ndarray]:
"""Created cropped images from list of rotated bounding boxes
Args:
img: input image
boxes: bounding boxes of shape (N, 5) where N is the number of boxes, and the relative
coordinates (x, y, w, h, alpha)
dtype: target data type of bounding boxes
channels_last: whether the channel dimension is the last one (instead of the first one)
Returns:
list of cropped images
@@ -80,9 +91,9 @@ def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List:
M = cv2.getAffineTransform(src_pts, dst_pts)
# Warp the rotated rectangle
if clockwise:
crop = cv2.warpAffine(img, M, (int(w), int(h)))
crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(w), int(h)))
else:
crop = cv2.warpAffine(img, M, (int(h), int(w)))
crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(h), int(w)))
crops.append(crop)

return crops
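The layout handling is easiest to see on a toy array; the following sketch (illustrative only, shapes are arbitrary and not part of the diff) shows the two indexing paths the new `channels_last` flag switches between:

import numpy as np

xmin, ymin, xmax, ymax = 20, 10, 80, 60            # absolute pixel box, for clarity
img_hwc = np.zeros((128, 256, 3), dtype=np.uint8)  # channels-last (H, W, C), as with TensorFlow
img_chw = img_hwc.transpose(2, 0, 1)               # channels-first (C, H, W), as with PyTorch

crop_hwc = img_hwc[ymin:ymax, xmin:xmax]           # shape (50, 60, 3)
crop_chw = img_chw[:, ymin:ymax, xmin:xmax]        # shape (3, 50, 60)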
28 changes: 22 additions & 6 deletions doctr/models/core.py
@@ -10,6 +10,7 @@
from .detection import DetectionPredictor
from .recognition import RecognitionPredictor
from ._utils import extract_crops, extract_rcrops
from doctr.file_utils import is_torch_available
from doctr.io.elements import Word, Line, Block, Page, Document
from doctr.utils.repr import NestedObject
from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes, rotate_image
@@ -51,20 +52,35 @@ def __call__(

# Localize text elements
boxes = self.det_predictor(pages, **kwargs)
# Check whether crop mode should be switched to channels first
crop_kwargs = {}
if len(pages) > 0 and not isinstance(pages[0], np.ndarray) and is_torch_available():
crop_kwargs['channels_last'] = False
# Crop images, rotate page if necessary
if self.doc_builder.rotated_bbox:
crops = [crop for page, (_boxes, angle) in zip(pages, boxes) for crop in
self.extract_crops_fn(rotate_image(page, -angle, False), _boxes[:, :-1])] # type: ignore[operator]
crops = [
crop for page, (_boxes, angle) in zip(pages, boxes) for crop in
self.extract_crops_fn( # type: ignore[operator]
rotate_image(page, -angle, False),
_boxes[:, :-1],
**crop_kwargs
)
]
else:
crops = [crop for page, (_boxes, _) in zip(pages, boxes) for crop in
self.extract_crops_fn(page, _boxes[:, :-1])] # type: ignore[operator]
self.extract_crops_fn(page, _boxes[:, :-1], **crop_kwargs)] # type: ignore[operator]
# Avoid sending zero-sized crops
is_kept = [all(s > 0 for s in crop.shape) for crop in crops]
crops = [crop for crop, _kept in zip(crops, is_kept) if _kept]
boxes = [box for box, _kept in zip(boxes, is_kept) if _kept]
# Identify character sequences
word_preds = self.reco_predictor(crops, **kwargs)

# Rotate back boxes if necessary
boxes, angles = zip(*boxes)
if self.doc_builder.rotated_bbox:
boxes = [rotate_boxes(boxes_page, angle) for boxes_page, angle in zip(boxes, angles)]
if len(boxes) > 0:
boxes, angles = zip(*boxes)
if self.doc_builder.rotated_bbox:
boxes = [rotate_boxes(boxes_page, angle) for boxes_page, angle in zip(boxes, angles)]
out = self.doc_builder(boxes, word_preds, [page.shape[:2] for page in pages])
return out

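The zero-sized crop filter keeps the detected boxes and the recognition inputs aligned while dropping degenerate regions before inference; a toy illustration of the filtering step (placeholder arrays, not doctr outputs):

import numpy as np

crops = [np.ones((32, 0, 3)), np.ones((32, 128, 3))]          # first crop has zero width
is_kept = [all(s > 0 for s in crop.shape) for crop in crops]  # [False, True]
crops = [crop for crop, kept in zip(crops, is_kept) if kept]  # only the valid crop reaches the recognizer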
29 changes: 22 additions & 7 deletions scripts/analyze.py
@@ -8,25 +8,40 @@

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from doctr.models import ocr_predictor
from doctr.io import DocumentFile
from doctr.file_utils import is_tf_available

# Enable GPU growth if using TF
if is_tf_available():
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
import torch


def main(args):

model = ocr_predictor(args.detection, args.recognition, pretrained=True)

if not is_tf_available():
model.det_predictor.pre_processor = model.det_predictor.pre_processor.eval()
model.det_predictor.model = model.det_predictor.model.eval()
model.reco_predictor.pre_processor = model.reco_predictor.pre_processor.eval()
model.reco_predictor.model = model.reco_predictor.model.eval()

if args.path.endswith(".pdf"):
doc = DocumentFile.from_pdf(args.path).as_images()
else:
doc = DocumentFile.from_images(args.path)

out = model(doc, training=False)
if is_tf_available():
out = model(doc, training=False)
else:
with torch.no_grad():
out = model(doc)

for page, img in zip(out.pages, doc):
page.show(img, block=not args.noblock, interactive=not args.static)
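For the PyTorch path, switching the predictors to `.eval()` and wrapping inference in `torch.no_grad()` matters for both correctness and memory: the former freezes dropout and batch-norm behaviour, the latter avoids building an autograd graph. A standalone sketch of the pattern (toy module, unrelated to doctr's models):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
net = net.eval()                  # dropout becomes a no-op, batch-norm would use running stats
with torch.no_grad():             # no gradients are tracked during inference
    out = net(torch.zeros(1, 4))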
51 changes: 37 additions & 14 deletions scripts/evaluate.py
@@ -4,26 +4,41 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os
import numpy as np
from tqdm import tqdm

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
import numpy as np
from tqdm import tqdm

from doctr.utils.metrics import LocalizationConfusion, TextMatch, OCRMetric
from doctr import datasets
from doctr.models import ocr_predictor, extract_crops
from doctr.file_utils import is_tf_available

# Enable GPU growth if using TF
if is_tf_available():
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
import torch


def _pct(val):
return "N/A" if val is None else f"{val:.2%}"


def main(args):

predictor = ocr_predictor(args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size)

if not is_tf_available():
predictor.det_predictor.pre_processor = predictor.det_predictor.pre_processor.eval()
predictor.det_predictor.model = predictor.det_predictor.model.eval()
predictor.reco_predictor.pre_processor = predictor.reco_predictor.pre_processor.eval()
predictor.reco_predictor.model = predictor.reco_predictor.model.eval()

if args.img_folder and args.label_file:
testset = datasets.OCRDataset(
img_folder=args.img_folder,
@@ -60,9 +75,17 @@ def main(args):
gt_labels = target['labels']

# Forward
out = predictor(page[None, ...], training=False)
crops = extract_crops(page, gt_boxes)
reco_out = predictor.reco_predictor(crops, training=False)
if is_tf_available():
out = predictor(page[None, ...], training=False)
crops = extract_crops(page, gt_boxes)
reco_out = predictor.reco_predictor(crops, training=False)
else:
with torch.no_grad():
out = predictor(page[None, ...])
# We directly crop on PyTorch tensors, which are in channels_first
crops = extract_crops(page, gt_boxes, channels_last=False)
reco_out = predictor.reco_predictor(crops)

if len(reco_out):
reco_words, _ = zip(*reco_out)
else:
@@ -111,12 +134,12 @@ def main(args):
print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
f"dataset={'OCRDataset' if args.img_folder else args.dataset})")
recall, precision, mean_iou = det_metric.summary()
print(f"Text Detection - Recall: {recall:.2%}, Precision: {precision:.2%}, Mean IoU: {mean_iou:.2%}")
print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
acc = reco_metric.summary()
print(f"Text Recognition - Accuracy: {acc['raw']:.2%} (unicase: {acc['unicase']:.2%})")
print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
recall, precision, mean_iou = e2e_metric.summary()
print(f"OCR - Recall: {recall['raw']:.2%} (unicase: {recall['unicase']:.2%}), "
f"Precision: {precision['raw']:.2%} (unicase: {precision['unicase']:.2%}), Mean IoU: {mean_iou:.2%}")
print(f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}")


def parse_args():
