Skip to content

Commit

Permalink
Refactorize yews.datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
lijunzh committed Apr 17, 2019
1 parent 278caba commit 5a7b0e8
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 146 deletions.
10 changes: 8 additions & 2 deletions yews/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# Public API of ``yews.datasets``. The former ``classification`` module was
# removed in this refactor; its classes now live in ``base``, ``file`` and
# ``dir``.
from .base import BaseDataset, PathDataset
from .file import FileDataset, DatasetArray
from .dir import DirDataset, DatasetFolder, DatasetArrayFolder

__all__ = (
    'BaseDataset',
    'PathDataset',
    'FileDataset',
    'DirDataset',
    'DatasetArray',
    'DatasetFolder',
    'DatasetArrayFolder',
)

108 changes: 108 additions & 0 deletions yews/datasets/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from torch.utils import data


class BaseDataset(data.Dataset):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``build_dataset`` which constructs the dataset-like object from root.

    A dataset-like object has both ``__len__`` and ``__getitem__`` implemented.
    Typical dataset-like objects include python list and numpy ndarray.

    Args:
        root (object): Source of the dataset. When ``None``, ``build_dataset``
            is not called and the dataset starts out empty.
        sample_transform (callable, optional): A function/transform that takes
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            a target and transform it.

    Attributes:
        samples (dataset-like object): Dataset-like object for samples.
        targets (dataset-like object): Dataset-like object for targets.
        size (int): Number of (sample, target) pairs in the dataset.

    Raises:
        ValueError: If ``build_dataset`` returns samples and targets of
            different lengths.

    """

    _repr_indent = 4

    def __init__(self, root=None, sample_transform=None, target_transform=None):
        self.root = root
        self.sample_transform = sample_transform
        self.target_transform = target_transform

        if self.root is not None:
            self.samples, self.targets = self.build_dataset()
        else:
            # Without a source there is nothing to build; keep the object
            # usable (len/repr) by starting from an empty dataset.
            self.samples, self.targets = [], []

        if len(self.samples) != len(self.targets):
            raise ValueError("Samples and targets have different lengths.")
        self.size = len(self.targets)

    def build_dataset(self):
        """Construct the dataset from ``self.root``.

        Must be overridden by subclasses.

        Returns:
            samples (dataset-like object): Samples of the dataset.
            targets (dataset-like object): Targets, one per sample.

        """
        raise NotImplementedError

    def __getitem__(self, index):
        """Return the (sample, target) pair at ``index`` after transforms."""
        sample = self.samples[index]
        target = self.targets[index]

        if self.sample_transform is not None:
            sample = self.sample_transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return self.size

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if self.sample_transform is not None:
            body += self._format_transform_repr(self.sample_transform,
                                                "Sample transforms: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        """Format a transform's repr as lines aligned under ``head``."""
        # Guard against an empty repr so that lines[0] is always available.
        lines = transform.__repr__().splitlines() or [""]
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        """Return extra lines for ``__repr__``; subclasses may override."""
        return ""


class PathDataset(BaseDataset):
    """An abstract class representing a Dataset defined by a Path.

    Args:
        root (object): Path to the dataset. Must be passed as a keyword
            argument.
        sample_transform (callable, optional): A function/transform that takes
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            a target and transform it.

    Attributes:
        root (pathlib.Path): Resolved path to the dataset.
        samples (list): List of samples in the dataset.
        targets (list): List of targets in the dataset.

    Raises:
        TypeError: If ``root`` is missing or ``None``.

    """

    def __init__(self, **kwargs):
        from pathlib import Path

        root = kwargs.get('root')
        if root is None:
            raise TypeError("PathDataset requires a 'root' path, got None.")

        # Resolve the path *before* delegating to the parent constructor:
        # the parent calls ``build_dataset``, and subclasses rely on
        # ``self.root`` already being a ``Path`` at that point.
        kwargs['root'] = Path(root).resolve()
        super(PathDataset, self).__init__(**kwargs)
144 changes: 0 additions & 144 deletions yews/datasets/classification.py

This file was deleted.

110 changes: 110 additions & 0 deletions yews/datasets/dir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from .base import PathDataset


class DirDataset(PathDataset):
    """An abstract class representing a Dataset in a directory.

    Args:
        root (object): Directory of the dataset.
        sample_transform (callable, optional): A function/transform that takes
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            a target and transform it.

    Attributes:
        samples (list): List of samples in the dataset.
        targets (list): List of targets in the dataset.

    Raises:
        ValueError: If ``root`` does not point to a directory.

    """

    def __init__(self, **kwargs):
        from pathlib import Path

        # Validate the directory before the parent constructor runs, so that
        # a bad root is reported as such rather than as whatever error the
        # dataset construction would raise on it.
        root = kwargs.get('root')
        if root is not None:
            resolved = Path(root).resolve()
            if not resolved.is_dir():
                raise ValueError(f"{resolved} is not a directory.")

        super(DirDataset, self).__init__(**kwargs)


class DatasetArrayFolder(DirDataset):
    """A generic data loader for a folder of ``.npy`` files where samples are
    arranged in the following way: ::

        root/samples.npy: each row is a sample
        root/targets.npy: each row is a label

    where both samples and targets can be arrays.

    Args:
        root (object): Path to the dataset.
        sample_transform (callable, optional): A function/transform that takes
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            a target and transform it.

    Attributes:
        samples (ndarray): Samples loaded from ``samples.npy``.
        targets (ndarray): Targets loaded from ``targets.npy``.

    """

    def build_dataset(self):
        """Load ``samples.npy`` and ``targets.npy`` from the root directory.

        Returns:
            samples (ndarray): Array read from ``root/samples.npy``.
            targets (ndarray): Array read from ``root/targets.npy``.

        """
        # Imported locally because this module does not import numpy at the
        # top level.
        import numpy as np

        # Both arrays are opened with mmap_mode='r' (read-only memory map).
        samples = np.load(self.root / 'samples.npy', mmap_mode='r')
        targets = np.load(self.root / 'targets.npy', mmap_mode='r')

        return samples, targets


class DatasetFolder(DirDataset):
    """A generic data loader for a folder where samples are arranged in the
    following way: ::

        root/.../class_x.xxx
        root/.../class_x.sdf3
        root/.../class_x.asd932

        root/.../class_y.yyy
        root/.../class_y.as4h
        root/.../class_y.jlk2

    The label of each file is the part of its name before the first ``.``.

    Args:
        loader (callable): Function that load one sample from a file.
        root (path): Path to the dataset.
        sample_transform (callable, optional): A function/transform that takes
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            a target and transform it.

    Attributes:
        samples (FilesLoader): Lazily loaded samples in the dataset.
        targets (list): List of targets in the dataset.

    """

    class FilesLoader(object):
        """A dataset-like class for loading a list of files given a loader.

        Args:
            files (list): List of file paths
            loader (callable): Function that load one file.

        """

        def __init__(self, files, loader):
            self.files = files
            self.loader = loader

        def __getitem__(self, index):
            # Files are loaded on access, not up front.
            return self.loader(self.files[index])

        def __len__(self):
            return len(self.files)

    def __init__(self, loader, **kwargs):
        # The loader must be in place before the parent constructor runs,
        # because the base class builds the dataset during construction and
        # ``build_dataset`` hands ``self.loader`` to ``FilesLoader``.
        self.loader = loader
        super(DatasetFolder, self).__init__(**kwargs)

    def build_dataset(self):
        """Collect every file under root and derive its label from its name.

        Returns:
            samples (FilesLoader): Dataset-like object loading files lazily.
            labels (list): Label for each file (name up to the first ``.``).

        """
        # Sort so that indices are reproducible across runs and platforms;
        # glob order is otherwise filesystem-dependent.
        files = sorted(p for p in self.root.glob("**/*") if p.is_file())
        labels = [p.name.split('.')[0] for p in files]
        samples = self.FilesLoader(files, self.loader)

        return samples, labels

    # Backward-compatible alias for the previous method name.
    make_dataset = build_dataset

Loading

0 comments on commit 5a7b0e8

Please sign in to comment.