Skip to content

Commit

Permalink
apply changes
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 24, 2021
1 parent 73c4c97 commit e719b33
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 30 deletions.
4 changes: 2 additions & 2 deletions doctr/datasets/synthtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
np_dtype = np.float32

for img_path, word_boxes, txt in tqdm(iterable=zip(paths, boxes, labels),
desc='Load SynthText', total=len(paths)):
desc='Loading SynthText...', total=len(paths)):

# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path[0])):
Expand All @@ -72,7 +72,7 @@ def __init__(
mins = word_boxes.min(axis=1)
maxs = word_boxes.max(axis=1)
box_targets = np.concatenate(
((mins + maxs) / 2, maxs - mins, np.expand_dims(np.zeros(word_boxes.shape[0]), axis=1)), axis=1)
((mins + maxs) / 2, maxs - mins, np.zeros((word_boxes.shape[0], 1))), axis=1)
else:
# xmin, ymin, xmax, ymax
box_targets = np.concatenate((word_boxes.min(axis=1), word_boxes.max(axis=1)), axis=1)
Expand Down
23 changes: 9 additions & 14 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from doctr import datasets
from doctr.transforms import Resize
from doctr.utils.data import download_from_url


def test_visiondataset():
Expand Down Expand Up @@ -33,25 +32,21 @@ def test_visiondataset():
['IIIT5K', False, [32, 128], 3000, False],
['SVT', True, [512, 512], 100, True],
['SVT', False, [512, 512], 249, False],
['SynthText', True, [512, 512], 27, True], # org: 772875
['SynthText', False, [512, 512], 3, False], # org: 85875
['SynthText', True, [512, 512], 27, True], # Actual set has 772875 samples
['SynthText', False, [512, 512], 3, False], # Actual set has 85875 samples
],
)
def test_dataset(dataset_name, train, input_size, size, rotate):

if dataset_name.lower() == 'synthtext':
download = False

# Download the subsample of SynthText dataset and save in cache
URL = 'https://github.com/mindee/doctr/releases/download/v0.4.1/synthtext_samples-89fd1445.zip'
FILE_HASH = '89fd1445457b9ad8391e17620c6ae1b45134be2bf5449f36e7e4275176cc16ac'
FILE_NAME = 'SynthText.zip'
download_from_url(URL, FILE_NAME, FILE_HASH, cache_subdir='datasets')
else:
download = True
if dataset_name.lower() == "synthtext":
# Monkeypatch the class to download a subsample
datasets.__dict__[
dataset_name
].URL = 'https://github.com/mindee/doctr/releases/download/v0.4.1/synthtext_samples-89fd1445.zip'
datasets.__dict__[dataset_name].SHA256 = '89fd1445457b9ad8391e17620c6ae1b45134be2bf5449f36e7e4275176cc16ac'

ds = datasets.__dict__[dataset_name](
train=train, download=download, sample_transforms=Resize(input_size), rotated_bbox=rotate,
train=train, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
)

assert len(ds) == size
Expand Down
23 changes: 9 additions & 14 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from doctr import datasets
from doctr.datasets import DataLoader
from doctr.transforms import Resize
from doctr.utils.data import download_from_url


@pytest.mark.parametrize(
Expand All @@ -23,25 +22,21 @@
['IIIT5K', False, [32, 128], 3000, False],
['SVT', True, [512, 512], 100, True],
['SVT', False, [512, 512], 249, False],
['SynthText', True, [512, 512], 27, True], # org: 772875
['SynthText', False, [512, 512], 3, False], # org: 85875
['SynthText', True, [512, 512], 27, True], # Actual set has 772875 samples
['SynthText', False, [512, 512], 3, False], # Actual set has 85875 samples
],
)
def test_dataset(dataset_name, train, input_size, size, rotate):

if dataset_name.lower() == 'synthtext':
download = False

# Download the subsample of SynthText dataset and save in cache
URL = 'https://github.com/mindee/doctr/releases/download/v0.4.1/synthtext_samples-89fd1445.zip'
FILE_HASH = '89fd1445457b9ad8391e17620c6ae1b45134be2bf5449f36e7e4275176cc16ac'
FILE_NAME = 'SynthText.zip'
download_from_url(URL, FILE_NAME, FILE_HASH, cache_subdir='datasets')
else:
download = True
if dataset_name.lower() == "synthtext":
# Monkeypatch the class to download a subsample
datasets.__dict__[
dataset_name
].URL = 'https://github.com/mindee/doctr/releases/download/v0.4.1/synthtext_samples-89fd1445.zip'
datasets.__dict__[dataset_name].SHA256 = '89fd1445457b9ad8391e17620c6ae1b45134be2bf5449f36e7e4275176cc16ac'

ds = datasets.__dict__[dataset_name](
train=train, download=download, sample_transforms=Resize(input_size), rotated_bbox=rotate,
train=train, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
)

assert len(ds) == size
Expand Down

0 comments on commit e719b33

Please sign in to comment.