Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Mar 11, 2022
1 parent 2826794 commit 063b9a6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from doctr.transforms import Resize


def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, has_boxes=True):
def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, target_includes_boxes=True):

# Fetch one sample
img, target = ds[0]
assert isinstance(img, torch.Tensor)
assert img.shape == (3, *input_size)
assert img.dtype == torch.float32
assert isinstance(target, dict)
if has_boxes:
if target_includes_boxes:
assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32
if is_polygons:
assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2)
Expand Down Expand Up @@ -446,4 +446,4 @@ def test_mjsynth_dataset(mock_mjsynth_dataset):

assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples
assert repr(ds) == f"MJSynth(train={True})"
_validate_dataset(ds, input_size, has_boxes=False)
_validate_dataset(ds, input_size, target_includes_boxes=False)
6 changes: 3 additions & 3 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from doctr.transforms import Resize


def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, has_boxes=True):
def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, target_includes_boxes=True):

# Fetch one sample
img, target = ds[0]
assert isinstance(img, tf.Tensor)
assert img.shape == (*input_size, 3)
assert img.dtype == tf.float32
assert isinstance(target, dict)
if has_boxes:
if target_includes_boxes:
assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32
if is_polygons:
assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2)
Expand Down Expand Up @@ -434,4 +434,4 @@ def test_mjsynth_dataset(mock_mjsynth_dataset):

assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples
assert repr(ds) == f"MJSynth(train={True})"
_validate_dataset(ds, input_size, has_boxes=False)
_validate_dataset(ds, input_size, target_includes_boxes=False)

0 comments on commit 063b9a6

Please sign in to comment.