diff --git a/.gitignore b/.gitignore
index c64ea4da44..157ee2a791 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,8 @@ dmypy.json
 doctr/version.py
 logs/
 wandb/
+
+# Checkpoints
+*.pt
+*.pb
+*.index
diff --git a/README.md b/README.md
index 6572033cee..7afdb30e8a 100644
--- a/README.md
+++ b/README.md
@@ -114,9 +114,9 @@ We try to keep framework-specific dependencies to a minimum. You can install fra
 
 ```shell
 # for TensorFlow
-pip install python-doctr[tf]
+pip install "python-doctr[tf]"
 # for PyTorch
-pip install python-doctr[torch]
+pip install "python-doctr[torch]"
 ```
 
 ### Developer mode
diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 39163d4ef3..45b7e27b81 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -19,6 +19,7 @@ Here are all datasets that are available through docTR:
 .. autoclass:: OCRDataset
 .. autoclass:: CharacterGenerator
 .. autoclass:: DocArtefacts
+.. autoclass:: IIIT5K
 
 
 Data Loading
diff --git a/docs/source/installing.rst b/docs/source/installing.rst
index bb5a7a527f..8197df660d 100644
--- a/docs/source/installing.rst
+++ b/docs/source/installing.rst
@@ -40,9 +40,9 @@ We strive towards reducing framework-specific dependencies to a minimum, but som
 .. code:: bash
 
     # for TensorFlow
-    pip install python-doctr[tf]
+    pip install "python-doctr[tf]"
     # for PyTorch
-    pip install python-doctr[torch]
+    pip install "python-doctr[torch]"
 
 
 Via Git
diff --git a/doctr/datasets/__init__.py b/doctr/datasets/__init__.py
index 0c9bfc8b81..4cbac79cac 100644
--- a/doctr/datasets/__init__.py
+++ b/doctr/datasets/__init__.py
@@ -5,6 +5,7 @@
 from .detection import *
 from .doc_artefacts import *
 from .funsd import *
+from .iiit5k import *
 from .ocr import *
 from .recognition import *
 from .sroie import *
diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py
index f5a39e9992..f2a4a3cecb 100644
--- a/doctr/datasets/datasets/base.py
+++ b/doctr/datasets/datasets/base.py
@@ -6,7 +6,7 @@
 import os
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
-from zipfile import ZipFile
+import shutil
 
 from doctr.utils.data import download_from_url
 
@@ -100,7 +100,6 @@ def __init__(
             archive_path = Path(archive_path)
             dataset_path = archive_path.parent.joinpath(archive_path.stem)
             if not dataset_path.is_dir() or overwrite:
-                with ZipFile(archive_path, 'r') as f:
-                    f.extractall(path=dataset_path)
+                shutil.unpack_archive(archive_path, dataset_path)
 
         super().__init__(dataset_path if extract_archive else archive_path, fp16)
diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py
new file mode 100644
index 0000000000..4e9cb36902
--- /dev/null
+++ b/doctr/datasets/iiit5k.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2021, Mindee.
+
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
+
+import os
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import scipy.io as sio
+
+from .datasets import VisionDataset
+
+__all__ = ['IIIT5K']
+
+
+class IIIT5K(VisionDataset):
+    """IIIT-5K character-level localization dataset from
+    `"BMVC 2012 Scene Text Recognition using Higher Order Language Priors"
+    <https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/home/mishraBMVC12.pdf>`_.
+
+    Example::
+        >>> # NOTE: this dataset is for character-level localization
+        >>> from doctr.datasets import IIIT5K
+        >>> train_set = IIIT5K(train=True, download=True)
+        >>> img, target = train_set[0]
+
+    Args:
+        train: whether the subset should be the training one
+        sample_transforms: composable transformations that will be applied to each image
+        rotated_bbox: whether polygons should be considered as rotated bounding boxes (instead of straight ones)
+        **kwargs: keyword arguments from `VisionDataset`.
+    """
+
+    URL = 'https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz'
+    SHA256 = '7872c9efbec457eb23f3368855e7738f72ce10927f52a382deb4966ca0ffa38e'
+
+    def __init__(
+        self,
+        train: bool = True,
+        sample_transforms: Optional[Callable[[Any], Any]] = None,
+        rotated_bbox: bool = False,
+        **kwargs: Any,
+    ) -> None:
+
+        super().__init__(url=self.URL, file_name='IIIT5K-Word-V3.tar',
+                         file_hash=self.SHA256, extract_archive=True, **kwargs)
+        self.sample_transforms = sample_transforms
+        self.train = train
+
+        # Load mat data
+        tmp_root = os.path.join(self.root, 'IIIT5K')
+        mat_file = 'trainCharBound' if self.train else 'testCharBound'
+        mat_data = sio.loadmat(os.path.join(tmp_root, f'{mat_file}.mat'))[mat_file][0]
+
+        self.data: List[Tuple[Path, Dict[str, Any]]] = []
+        np_dtype = np.float16 if self.fp16 else np.float32
+
+        for img_path, label, box_targets in mat_data:
+            _raw_path = img_path[0]
+            _raw_label = label[0]
+
+            # File existence check
+            if not os.path.exists(os.path.join(tmp_root, _raw_path)):
+                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}")
+
+            if rotated_bbox:
+                # x_center, y_center, w, h, alpha = 0
+                box_targets = [[box[0] + box[2] / 2, box[1] + box[3] / 2, box[2], box[3], 0] for box in box_targets]
+            else:
+                # x, y, width, height -> xmin, ymin, xmax, ymax
+                box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]
+
+            # Labels are cast to a list of chars, each corresponding to one character box
+            self.data.append((_raw_path, dict(boxes=np.asarray(
+                box_targets, dtype=np_dtype), labels=list(_raw_label))))
+
+        self.root = tmp_root
+
+    def extra_repr(self) -> str:
+        return f"train={self.train}"
diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py
index ab73d529b4..9ea6d8a308 100644
--- a/tests/pytorch/test_datasets_pt.py
+++ b/tests/pytorch/test_datasets_pt.py
@@ -30,6 +30,8 @@ def test_visiondataset():
         ['CORD', False, [512, 512], 100, False],
         ['DocArtefacts', True, [512, 512], 2700, False],
         ['DocArtefacts', False, [512, 512], 300, True],
+        ['IIIT5K', True, [32, 128], 2000, True],
+        ['IIIT5K', False, [32, 128], 3000, False],
     ],
 )
 def test_dataset(dataset_name, train, input_size, size, rotate):
diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py
index 4c0b486c79..63e37cfab8 100644
--- a/tests/tensorflow/test_datasets_tf.py
+++ b/tests/tensorflow/test_datasets_tf.py
@@ -20,6 +20,8 @@
         ['CORD', False, [512, 512], 100, False],
         ['DocArtefacts', True, [512, 512], 2700, False],
         ['DocArtefacts', False, [512, 512], 300, True],
+        ['IIIT5K', True, [32, 128], 2000, True],
+        ['IIIT5K', False, [32, 128], 3000, False],
     ],
 )
 def test_dataset(dataset_name, train, input_size, size, rotate):
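
As a quick sanity check for reviewers, here is a minimal usage sketch of the new class (the sample counts come from the test expectations above; `download=True` is forwarded to `VisionDataset`):

```python
from doctr.datasets import IIIT5K

# Character-level localization: one box per character of each word crop
train_set = IIIT5K(train=True, download=True)
img, target = train_set[0]

print(len(train_set))         # 2000 train samples (the test split has 3000)
print(target['boxes'].shape)  # (num_chars, 4): xmin, ymin, xmax, ymax by default
print(target['labels'])       # list of single characters, aligned with the boxes
```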
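
The box-target conversion in `IIIT5K.__init__` can also be checked in isolation. The raw `.mat` annotations store each character box as `(x, y, width, height)`; running a made-up box through both branches:

```python
# Mirrors the two list comprehensions in iiit5k.py on a hypothetical box
box = [10, 20, 30, 40]  # x, y, w, h, as stored in trainCharBound/testCharBound

straight = [box[0], box[1], box[0] + box[2], box[1] + box[3]]
assert straight == [10, 20, 40, 60]  # xmin, ymin, xmax, ymax

rotated = [box[0] + box[2] / 2, box[1] + box[3] / 2, box[2], box[3], 0]
assert rotated == [25.0, 40.0, 30, 40, 0]  # x_center, y_center, w, h, alpha=0
```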
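
Finally, the `base.py` change is what lets this dataset reuse the shared extraction path: `ZipFile` only reads zip archives, whereas `shutil.unpack_archive` dispatches on the filename and `tarfile` decompresses gzip transparently, so the IIIT5K tarball and the existing zip-based datasets go through the same call. A standalone sketch with hypothetical paths:

```python
import shutil

# Format is inferred from the extension; no zip-specific code needed
shutil.unpack_archive("funsd.zip", "datasets/funsd")            # zip, as before
shutil.unpack_archive("IIIT5K-Word-V3.tar", "datasets/IIIT5K")  # tarball, new
```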