Skip to content

Commit

Permalink
[datasets] Add MJSynth (Synth90K) (#827)
Browse files Browse the repository at this point in the history
* backup

* add mjsynth loader

* apply changes

* rename

* update

* update

* fix tests
  • Loading branch information
felixdittrich92 authored Apr 28, 2022
1 parent 56b914c commit f9a1912
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Supported datasets
* IC03 from `ICDAR 2003 <http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions>`_.
* IC13 from `ICDAR 2013 <http://dagdata.cvc.uab.es/icdar2013competition/>`_.
* IMGUR5K from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example" <https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset>`_.
* MJSynth from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.


.. toctree::
Expand Down
1 change: 1 addition & 0 deletions docs/source/modules/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Public datasets
.. autoclass:: IC03
.. autoclass:: IC13
.. autoclass:: IMGUR5K
.. autoclass:: MJSynth

docTR synthetic datasets
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .ic13 import *
from .iiit5k import *
from .imgur5k import *
from .mjsynth import *
from .ocr import *
from .recognition import *
from .sroie import *
Expand Down
69 changes: 69 additions & 0 deletions doctr/datasets/mjsynth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (C) 2021-2022, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os
from typing import Any, Dict, List, Tuple

from tqdm import tqdm

from .datasets import AbstractDataset

__all__ = ["MJSynth"]


class MJSynth(AbstractDataset):
    """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition"
    <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.

    >>> # NOTE: This is a pure recognition dataset without bounding box labels.
    >>> # NOTE: You need to download the dataset.
    >>> from doctr.datasets import MJSynth
    >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
    >>>                     label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
    >>>                     train=True)
    >>> img, target = train_set[0]
    >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
    >>>                    label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
    >>>                    train=False)
    >>> img, target = test_set[0]

    The label file is read as one image path per line, relative to `img_folder`
    (an optional leading ``./`` is dropped). The word label is the second
    ``_``-separated field of the image file name, e.g. ``12_Jedi_34.jpg`` -> ``Jedi``.
    The first 90% of the listed images form the training subset, the rest the test subset.

    Args:
        img_folder: folder with all the images of the dataset
        label_path: path to the file with the labels
        train: whether the subset should be the training one
        **kwargs: keyword arguments from `AbstractDataset`.

    Raises:
        FileNotFoundError: if `label_path` or `img_folder` does not exist
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(
                f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        self.data: List[Tuple[str, Dict[str, Any]]] = []
        self.train = train

        # Strip line endings up front and ignore blank lines (e.g. a trailing empty
        # line), which would otherwise fail when the label is parsed below.
        with open(label_path, encoding="utf-8") as f:
            img_paths = [entry for entry in (line.strip() for line in f) if entry]

        # First 90% of the listed samples -> train, remaining 10% -> test
        train_samples = int(len(img_paths) * 0.9)
        set_slice = slice(train_samples) if self.train else slice(train_samples, None)
        # Slice once: the full label file lists millions of entries, so avoid a second copy
        subset = img_paths[set_slice]

        for path in tqdm(iterable=subset, desc='Unpacking MJSynth', total=len(subset)):
            # Parse the label from the file name only, so that underscores in
            # parent directory names cannot shift the field index.
            label = [os.path.basename(path).split('_')[1]]
            # Entries are relative to `img_folder`; drop the leading "./" when present
            rel_path = path[2:] if path.startswith("./") else path
            img_path = os.path.join(img_folder, rel_path)

            self.data.append((img_path, dict(labels=label)))

    def extra_repr(self) -> str:
        """Return the subset flag as a ``train=<bool>`` string."""
        return f"train={self.train}"
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,3 +599,29 @@ def mock_ic03_dataset(tmpdir_factory, mock_image_stream):
archive_path = root.join('ic03_train.zip')
shutil.make_archive(root.join('ic03_train'), 'zip', str(ic03_root))
return str(archive_path)


@pytest.fixture(scope="session")
def mock_mjsynth_dataset(tmpdir_factory, mock_image_stream):
    """Build a tiny MJSynth-like dataset (5 images + label file) in a temp folder.

    Returns:
        a tuple (root folder, label file path), both as strings
    """
    root = tmpdir_factory.mktemp('datasets')
    mjsynth_root = root.mkdir('mjsynth')
    image_folder = mjsynth_root.mkdir("images")
    label_file = mjsynth_root.join("imlist.txt")

    # One relative image path per line, already newline-terminated
    labels = [
        "./mjsynth/images/12_I_34.jpg\n",
        "./mjsynth/images/12_am_34.jpg\n",
        "./mjsynth/images/12_a_34.jpg\n",
        "./mjsynth/images/12_Jedi_34.jpg\n",
        "./mjsynth/images/12_!_34.jpg\n",
    ]
    with open(label_file, "w") as label_handle:
        label_handle.writelines(labels)

    # Write the same mock image once per word listed in the label file
    image_bytes = BytesIO(mock_image_stream)
    for word in ['I', 'am', 'a', 'Jedi', '!']:
        image_path = image_folder.join(f"12_{word}_34.jpg")
        with open(image_path, 'wb') as image_handle:
            image_handle.write(image_bytes.getbuffer())

    return str(root), str(label_file)
14 changes: 14 additions & 0 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,17 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset):
_validate_dataset_recognition_part(ds, input_size)
else:
_validate_dataset(ds, input_size, is_polygons=rotate)


# NOTE: the following datasets are for the recognition task only

def test_mjsynth_dataset(mock_mjsynth_dataset):
    """Check size, repr and sample format of MJSynth built from the mock fixture."""
    img_folder, label_path = mock_mjsynth_dataset
    input_size = (32, 128)
    resize = Resize(input_size, preserve_aspect_ratio=True)

    ds = datasets.MJSynth(img_folder, label_path, img_transforms=resize)

    # Actual set has 7581382 train and 1337891 test samples
    assert len(ds) == 4
    assert repr(ds) == f"MJSynth(train={True})"
    _validate_dataset_recognition_part(ds, input_size)
14 changes: 14 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,17 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset):
_validate_dataset_recognition_part(ds, input_size)
else:
_validate_dataset(ds, input_size, is_polygons=rotate)


# NOTE: the following datasets are for the recognition task only

def test_mjsynth_dataset(mock_mjsynth_dataset):
    """Check size, repr and sample format of MJSynth built from the mock fixture."""
    img_folder, label_path = mock_mjsynth_dataset
    input_size = (32, 128)
    resize = Resize(input_size, preserve_aspect_ratio=True)

    ds = datasets.MJSynth(img_folder, label_path, img_transforms=resize)

    # Actual set has 7581382 train and 1337891 test samples
    assert len(ds) == 4
    assert repr(ds) == f"MJSynth(train={True})"
    _validate_dataset_recognition_part(ds, input_size)

0 comments on commit f9a1912

Please sign in to comment.