From 29b5d1c41a6ada8659e6f000c3192e84a747ccc3 Mon Sep 17 00:00:00 2001 From: fg-mindee Date: Sun, 26 Dec 2021 18:48:21 +0100 Subject: [PATCH 1/2] feat: Added text detection evaluation scripts --- references/detection/evaluate_pytorch.py | 160 ++++++++++++++++++++ references/detection/evaluate_tensorflow.py | 138 +++++++++++++++++ 2 files changed, 298 insertions(+) create mode 100644 references/detection/evaluate_pytorch.py create mode 100644 references/detection/evaluate_tensorflow.py diff --git a/references/detection/evaluate_pytorch.py b/references/detection/evaluate_pytorch.py new file mode 100644 index 0000000000..acf6779c9e --- /dev/null +++ b/references/detection/evaluate_pytorch.py @@ -0,0 +1,160 @@ +# Copyright (C) 2021, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TORCH'] = '1' + +import logging +import multiprocessing as mp +import time +from pathlib import Path + +import torch +from torch.utils.data import DataLoader, SequentialSampler +from torchvision.transforms import Normalize +from tqdm import tqdm + +from doctr import datasets +from doctr import transforms as T +from doctr.models import detection +from doctr.utils.metrics import LocalizationConfusion + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + targets = [t['boxes'] for t in targets] + if amp: + with torch.cuda.amp.autocast(): + out = model(images, targets, return_boxes=True) + else: + out = model(images, targets, return_boxes=True) + # Compute metric + loc_preds = out['preds'] + for boxes_gt, boxes_pred in zip(targets, loc_preds): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + + val_loss += out['loss'].item() + batch_cnt += 1 + + val_loss /= batch_cnt + recall, precision, mean_iou = val_metric.summary() + return val_loss, recall, precision, mean_iou + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Load docTR model + model = detection.__dict__[args.arch]( + pretrained=not isinstance(args.resume, str), + assume_straight_pages=not args.rotation + ).eval() + + if isinstance(args.size, int): + input_shape = (args.size, args.size) + else: + input_shape = model.cfg['input_shape'][-2:] + mean, std = model.cfg['mean'], model.cfg['std'] + + st = time.time() + ds = datasets.__dict__[args.dataset]( + train=True, + download=True, + rotated_bbox=args.rotation, + sample_transforms=T.Resize(input_shape), + ) + # Monkeypatch + subfolder = ds.root.split("/")[-2:] + ds.root = str(Path(ds.root).parent.parent) + ds.data = [(os.path.join(*subfolder, name), target) for name, target in ds.data] + _ds = datasets.__dict__[args.dataset](train=False, rotated_bbox=args.rotation) + subfolder = _ds.root.split("/")[-2:] + ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) + + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(ds), + pin_memory=torch.cuda.is_available(), + collate_fn=ds.collate_fn, + ) + print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in " + f"{len(test_loader)} batches)") + + batch_transforms = Normalize(mean=mean, std=std) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, targe device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + metric = LocalizationConfusion(rotated_bbox=args.rotation, mask_shape=input_shape) + + print("Running evaluation") + val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp) + print(f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " + f"Mean IoU: {mean_iou:.2%})") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='docTR evaluation script for text detection (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('arch', type=str, help='text-detection model to evaluate') + parser.add_argument('--dataset', type=str, default="FUNSD", help='Dataset to evaluate on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for evaluation') + parser.add_argument('--device', default=None, type=int, help='device') + parser.add_argument('--size', type=int, default=None, help='model input size, H = W') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--rotation', dest='rotation', action='store_true', + help='inference with rotated bbox') + parser.add_argument('--resume', type=str, default=None, help='Checkpoint to resume') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py new file mode 100644 index 0000000000..095c61a726 --- /dev/null +++ b/references/detection/evaluate_tensorflow.py @@ -0,0 +1,138 @@ +# Copyright (C) 2021, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import multiprocessing as mp +import time +from pathlib import Path + +import tensorflow as tf +from tensorflow.keras import mixed_precision +from tqdm import tqdm + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr import datasets +from doctr import transforms as T +from doctr.datasets import DataLoader +from doctr.models import detection +from doctr.utils.metrics import LocalizationConfusion + + +def evaluate(model, val_loader, batch_transforms, val_metric): + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + images = batch_transforms(images) + targets = [t['boxes'] for t in targets] + out = model(images, targets, training=False, return_boxes=True) + # Compute metric + loc_preds = out['preds'] + for boxes_gt, boxes_pred in zip(targets, loc_preds): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + + val_loss += out['loss'].numpy() + batch_cnt += 1 + + val_loss /= batch_cnt + recall, precision, mean_iou = val_metric.summary() + return val_loss, recall, precision, mean_iou + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + # AMP + if args.amp: + mixed_precision.set_global_policy('mixed_float16') + + input_shape = (args.size, args.size, 3) if isinstance(args.size, int) else None + + # Load docTR model + model = detection.__dict__[args.arch]( + pretrained=isinstance(args.resume, str), + assume_straight_pages=not args.rotation, + input_shape=input_shape, + ) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + model.load_weights(args.resume).expect_partial() + + input_shape = model.cfg['input_shape'] if input_shape is None else input_shape + mean, std = model.cfg['mean'], model.cfg['std'] + + st = time.time() + ds = datasets.__dict__[args.dataset]( + train=True, + download=True, + rotated_bbox=args.rotation, + sample_transforms=T.Resize(input_shape[:2]), + ) + # Monkeypatch + subfolder = ds.root.split("/")[-2:] + ds.root = str(Path(ds.root).parent.parent) + ds.data = [(os.path.join(*subfolder, name), target) for name, target in ds.data] + _ds = datasets.__dict__[args.dataset](train=False, rotated_bbox=args.rotation) + subfolder = _ds.root.split("/")[-2:] + ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) + + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + shuffle=False, + ) + print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in " + f"{len(test_loader)} batches)") + + batch_transforms = T.Normalize(mean=mean, std=std) + + # Metrics + metric = LocalizationConfusion(rotated_bbox=args.rotation, mask_shape=input_shape[:2]) + + print("Running evaluation") + val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric) + print(f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " + f"Mean IoU: {mean_iou:.2%})") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='docTR evaluation script for text detection (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('arch', type=str, help='text-detection model to evaluate') + parser.add_argument('--dataset', type=str, default="FUNSD", help='Dataset to evaluate on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for evaluation') + parser.add_argument('--size', type=int, default=None, help='model input size, H = W') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--rotation', dest='rotation', action='store_true', + help='inference with rotated bbox') + parser.add_argument('--resume', type=str, default=None, help='Checkpoint to resume') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) From e1b4f16cca9a6c49d1467e1ad2845f5fa5215787 Mon Sep 17 00:00:00 2001 From: fg-mindee Date: Sun, 26 Dec 2021 18:48:32 +0100 Subject: [PATCH 2/2] docs: Added text detection perf results --- references/detection/results.csv | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 references/detection/results.csv diff --git a/references/detection/results.csv b/references/detection/results.csv new file mode 100644 index 0000000000..3eb6cd2da4 --- /dev/null +++ b/references/detection/results.csv @@ -0,0 +1,9 @@ +architecture,input_shape,framework,test_set,recall,precision,mean_iou +db_resnet50,"(1024, 1024)",tensorflow,funsd,0.8121,0.8665,0.6681 +db_resnet50,"(1024, 1024)",tensorflow,cord,0.9245,0.8962,0.7457 +db_mobilenet_v3_large,"(1024, 1024)",tensorflow,funsd,0.783,0.828,0.6396 +db_mobilenet_v3_large,"(1024, 1024)",tensorflow,cord,0.8098,0.6657,0.5978 +db_resnet50,"(1024, 1024)",pytorch,funsd,0.7917,0.863,0.6652 +db_mobilenet_v3_large,"(1024, 1024)",pytorch,funsd,0.8006,0.841,0.6476 +db_resnet50,"(1024, 1024)",pytorch,cord,0.9296,0.9123,0.7654 +db_mobilenet_v3_large,"(1024, 1024)",pytorch,cord,0.8053,0.6653,0.5976