Skip to content

Commit

Permalink
add test to yews.datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
lijunzh committed Apr 17, 2019
1 parent ce4b445 commit df8c897
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 1 deletion.
110 changes: 110 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import yews.datasets as datasets
import yews.transforms as transforms

from pathlib import Path


def test_is_dataset():
assert not datasets.is_dataset(0)
assert datasets.is_dataset([])


class DummpyDatasetlike(object):

def __init__(self, size=1):
self.size = size
self.data = ['a item'] * self.size

def __getitem__(self, index):
return self.data[index]

def __len__(self):
return self.size


class DummyBaseDataset(datasets.BaseDataset):

def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()


class DummyBaseDatasetNoSamples(datasets.BaseDataset):

def build_dataset(self):
return 0, DummpyDatasetlike()


class DummyBaseDatasetNoTargets(datasets.BaseDataset):

def build_dataset(self):
return DummpyDatasetlike(), 0


class DummyBaseDatasetWrongLength(datasets.BaseDataset):

def build_dataset(self):
return DummpyDatasetlike(1), DummpyDatasetlike(2)


class DummyTransform(transforms.BaseTransform):

def __call__(self, data):
return "transformed"


class DummyPathDataset(datasets.PathDataset):

def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()


class TestBaseDataset:

def test_empty_construct(self):
dset = datasets.BaseDataset()
assert len(dset) == 0

def test_noempty_constrct(self):
dset = DummyBaseDataset(root='.')
assert len(dset) == 1

def test_raise_notimplmenetederror(self):
with pytest.raises(NotImplementedError):
dset = datasets.BaseDataset('.')

def test_no_samples(self):
with pytest.raises(ValueError):
dset = DummyBaseDatasetNoSamples(root='.')

def test_no_targets(self):
with pytest.raises(ValueError):
dset = DummyBaseDatasetNoTargets(root='.')

def test_samples_targets_not_match(self):
with pytest.raises(ValueError):
dset = DummyBaseDatasetWrongLength(root='.')

def test_getitem_with_transform(self):
dset = DummyBaseDataset(root='.',
sample_transform=DummyTransform(),
target_transform=DummyTransform())
assert dset[0] == ('transformed', 'transformed')
dset = DummyBaseDataset(root='.')
assert dset[0] == ('a item', 'a item')

def test_repr(self):
dset = DummyBaseDataset(root='.',
sample_transform='t',
target_transform='tt')
assert type(dset.__repr__()) is str
dset = DummyBaseDataset()
assert type(dset.__repr__()) is str


class TestPathDataset:

dset = DummyPathDataset(root='.')

def test_root_is_path(self):
assert self.dset.root == Path(self.dset.root).resolve()
2 changes: 1 addition & 1 deletion yews/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __getitem__(self, index):
sample = self.sample_transform(sample)

if self.target_transform is not None:
target = transform_transform(target)
target = self.target_transform(target)

return sample, target

Expand Down

0 comments on commit df8c897

Please sign in to comment.