diff --git a/docs/source/index.rst b/docs/source/index.rst
index 980aa2e3a8..e0c3fc78fb 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -62,6 +62,7 @@ Supported datasets
 * IC03 from `ICDAR 2003 `_.
 * IC13 from `ICDAR 2013 `_.
 * IMGUR5K from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example" `_.
+* MJSynth from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.
 
 
 .. toctree::
diff --git a/docs/source/modules/datasets.rst b/docs/source/modules/datasets.rst
index e40b1c506a..d9b07df3e0 100644
--- a/docs/source/modules/datasets.rst
+++ b/docs/source/modules/datasets.rst
@@ -27,6 +27,7 @@ Public datasets
 .. autoclass:: IC03
 .. autoclass:: IC13
 .. autoclass:: IMGUR5K
+.. autoclass:: MJSynth
 
 docTR synthetic datasets
 ^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/doctr/datasets/__init__.py b/doctr/datasets/__init__.py
index cd187271b1..92e1ff6831 100644
--- a/doctr/datasets/__init__.py
+++ b/doctr/datasets/__init__.py
@@ -9,6 +9,7 @@
 from .ic13 import *
 from .iiit5k import *
 from .imgur5k import *
+from .mjsynth import *
 from .ocr import *
 from .recognition import *
 from .sroie import *
diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py
new file mode 100644
index 0000000000..820e06a7b5
--- /dev/null
+++ b/doctr/datasets/mjsynth.py
@@ -0,0 +1,72 @@
+# Copyright (C) 2021-2022, Mindee.
+
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, Dict, List, Tuple
+
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+
+__all__ = ["MJSynth"]
+
+
+class MJSynth(AbstractDataset):
+    """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition"
+    <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.
+
+    >>> # NOTE: This is a pure recognition dataset without bounding box labels.
+    >>> # NOTE: You need to download the dataset.
+    >>> from doctr.datasets import MJSynth
+    >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
+    >>>                     label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
+    >>>                     train=True)
+    >>> img, target = train_set[0]
+    >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
+    >>>                    label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
+    >>>                    train=False)
+    >>> img, target = test_set[0]
+
+    Args:
+        img_folder: folder with all the images of the dataset
+        label_path: path to the file with the labels
+        train: whether the subset should be the training one
+        **kwargs: keyword arguments from `AbstractDataset`.
+    """
+
+    def __init__(
+        self,
+        img_folder: str,
+        label_path: str,
+        train: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(img_folder, **kwargs)
+
+        # File existence check
+        if not os.path.exists(label_path) or not os.path.exists(img_folder):
+            raise FileNotFoundError(
+                f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
+
+        self.data: List[Tuple[str, Dict[str, Any]]] = []
+        self.train = train
+
+        with open(label_path) as f:
+            img_paths = f.readlines()
+
+        # Deterministic 90/10 split: the first 90% of the index file is the train set
+        train_samples = int(len(img_paths) * 0.9)
+        set_slice = slice(train_samples) if self.train else slice(train_samples, None)
+
+        for path in tqdm(iterable=img_paths[set_slice], desc='Unpacking MJSynth', total=len(img_paths[set_slice])):
+            # Filenames are formatted as <index>_<label>_<index>.<ext>: the label is the middle token
+            label = [path.split('_')[1]]
+            # Drop the leading "./" of the relative path and the trailing newline
+            img_path = os.path.join(img_folder, path[2:]).strip()
+
+            self.data.append((img_path, dict(labels=label)))
+
+    def extra_repr(self) -> str:
+        return f"train={self.train}"
diff --git a/tests/conftest.py b/tests/conftest.py
index 1d34a9566f..f6cbb42f7d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -599,3 +599,29 @@ def mock_ic03_dataset(tmpdir_factory, mock_image_stream):
     archive_path = root.join('ic03_train.zip')
     shutil.make_archive(root.join('ic03_train'), 'zip', str(ic03_root))
     return str(archive_path)
+
+
+@pytest.fixture(scope="session")
+def mock_mjsynth_dataset(tmpdir_factory, mock_image_stream):
+    # Mirror the MJSynth layout: one "./mjsynth/images/<idx>_<label>_<idx>.jpg" line per sample
+    root = tmpdir_factory.mktemp('datasets')
+    dataset_root = root.mkdir('mjsynth')
+    image_folder = dataset_root.mkdir("images")
+    label_file = dataset_root.join("imlist.txt")
+    words = [
+        'I',
+        'am',
+        'a',
+        'Jedi',
+        '!',
+    ]
+
+    with open(label_file, "w") as f:
+        f.writelines(f"./mjsynth/images/12_{word}_34.jpg\n" for word in words)
+
+    image_bytes = BytesIO(mock_image_stream).getbuffer()
+    for word in words:
+        image_path = image_folder.join(f"12_{word}_34.jpg")
+        with open(image_path, 'wb') as f:
+            f.write(image_bytes)
+    return str(root), str(label_file)
diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py
index a8f015a177..f4b96ae389 100644
--- a/tests/pytorch/test_datasets_pt.py
+++ b/tests/pytorch/test_datasets_pt.py
@@ -505,3 +505,17 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset):
         _validate_dataset_recognition_part(ds, input_size)
     else:
         _validate_dataset(ds, input_size, is_polygons=rotate)
+
+
+# NOTE: the following datasets are recognition-only
+
+def test_mjsynth_dataset(mock_mjsynth_dataset):
+    input_size = (32, 128)
+    ds = datasets.MJSynth(
+        *mock_mjsynth_dataset,
+        img_transforms=Resize(input_size, preserve_aspect_ratio=True),
+    )
+
+    assert len(ds) == 4  # int(5 * 0.9) of the 5 mock samples; real set: 7581382 train / 1337891 test
+    assert repr(ds) == "MJSynth(train=True)"
+    _validate_dataset_recognition_part(ds, input_size)
diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py
index 4d32ec7c50..c7f032713a 100644
--- a/tests/tensorflow/test_datasets_tf.py
+++ b/tests/tensorflow/test_datasets_tf.py
@@ -490,3 +490,17 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset):
         _validate_dataset_recognition_part(ds, input_size)
     else:
_validate_dataset(ds, input_size, is_polygons=rotate) + + +# NOTE: following datasets are only for recognition task + +def test_mjsynth_dataset(mock_mjsynth_dataset): + input_size = (32, 128) + ds = datasets.MJSynth( + *mock_mjsynth_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples + assert repr(ds) == f"MJSynth(train={True})" + _validate_dataset_recognition_part(ds, input_size)