diff --git a/yews/datasets/__init__.py b/yews/datasets/__init__.py
index 3e76456..167cfe7 100644
--- a/yews/datasets/__init__.py
+++ b/yews/datasets/__init__.py
@@ -1,8 +1,14 @@
-from .classification import ClassificationDataset, DatasetArray, DatasetFolder
+from .base import BaseDataset, PathDataset
+from .file import FileDataset, DatasetArray
+from .dir import DirDataset, DatasetFolder, DatasetArrayFolder
 
 __all__ = (
-    'ClassificationDataset',
+    'BaseDataset',
+    'PathDataset',
+    'FileDataset',
+    'DirDataset',
     'DatasetArray',
     'DatasetFolder',
+    'DatasetArrayFolder',
 )
 
diff --git a/yews/datasets/base.py b/yews/datasets/base.py
new file mode 100644
index 0000000..b9711d4
--- /dev/null
+++ b/yews/datasets/base.py
@@ -0,0 +1,111 @@
+from pathlib import Path
+
+from torch.utils import data
+
+
+class BaseDataset(data.Dataset):
+    """An abstract class representing a Dataset.
+
+    All other datasets should subclass it. All subclasses should override
+    ``build_dataset``, which constructs the dataset-like object from root.
+
+    A dataset-like object has both ``__len__`` and ``__getitem__`` implemented.
+    Typical dataset-like objects include python lists and numpy ndarrays.
+
+    Args:
+        root (object): Source of the dataset.
+        sample_transform (callable, optional): A function/transform that takes
+            a sample and returns a transformed version.
+        target_transform (callable, optional): A function/transform that takes
+            a target and transforms it.
+
+    Attributes:
+        samples (dataset-like object): Dataset-like object for samples.
+        targets (dataset-like object): Dataset-like object for targets.
+
+    """
+
+    _repr_indent = 4
+
+    def __init__(self, root=None, sample_transform=None, target_transform=None):
+        self.root = root
+
+        if self.root is not None:
+            self.samples, self.targets = self.build_dataset()
+
+            if len(self.samples) == len(self.targets):
+                self.size = len(self.targets)
+            else:
+                raise ValueError("Samples and targets have different lengths.")
+
+        self.sample_transform = sample_transform
+        self.target_transform = target_transform
+
+    def build_dataset(self):
+        """
+        Returns:
+            samples (ndarray): List of samples.
+            targets (ndarray): List of targets.
+
+        """
+        raise NotImplementedError
+
+    def __getitem__(self, index):
+        sample = self.samples[index]
+        target = self.targets[index]
+
+        if self.sample_transform is not None:
+            sample = self.sample_transform(sample)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def __len__(self):
+        return self.size
+
+    def __repr__(self):
+        head = "Dataset " + self.__class__.__name__
+        body = ["Number of datapoints: {}".format(self.__len__())]
+        if self.root is not None:
+            body.append("Root location: {}".format(self.root))
+        body += self.extra_repr().splitlines()
+        if self.sample_transform is not None:
+            body += self._format_transform_repr(self.sample_transform,
+                                                "Sample transforms: ")
+        if self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform,
+                                                "Target transforms: ")
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return '\n'.join(lines)
+
+    def _format_transform_repr(self, transform, head):
+        lines = transform.__repr__().splitlines()
+        return (["{}{}".format(head, lines[0])] +
+                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+    def extra_repr(self):
+        return ""
+
+
+class PathDataset(BaseDataset):
+    """An abstract class representing a Dataset defined by a Path.
+
+    Args:
+        root (object): Path to the dataset.
+        sample_transform (callable, optional): A function/transform that takes
+            a sample and returns a transformed version.
+        target_transform (callable, optional): A function/transform that takes
+            a target and transforms it.
+
+    Attributes:
+        samples (list): List of samples in the dataset.
+        targets (list): List of targets in the dataset.
+
+    """
+
+    def __init__(self, **kwargs):
+        # Resolve root to a Path before BaseDataset.__init__ calls build_dataset.
+        kwargs['root'] = Path(kwargs['root']).resolve()
+        super(PathDataset, self).__init__(**kwargs)
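# Usage sketch; not part of the patch. A minimal subclass showing the
# ``build_dataset`` contract defined above: return a pair of dataset-like
# objects (anything with __len__ and __getitem__) and let BaseDataset handle
# indexing, transforms, and __repr__. ``InMemoryDataset`` and the toy arrays
# below are hypothetical.

import numpy as np

from yews.datasets import BaseDataset


class InMemoryDataset(BaseDataset):
    """Toy dataset whose ``root`` is an in-memory (samples, targets) tuple."""

    def build_dataset(self):
        samples, targets = self.root   # root holds the two arrays directly
        return samples, targets


waveforms = np.random.randn(10, 3, 100)     # 10 samples, 3 components, 100 points
labels = np.random.randint(0, 2, size=10)   # binary targets
ds = InMemoryDataset(root=(waveforms, labels),
                     sample_transform=lambda x: x.astype("float32"))
print(len(ds))          # -> 10
print(ds[0][0].dtype)   # -> float32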
diff --git a/yews/datasets/classification.py b/yews/datasets/classification.py
deleted file mode 100644
index 05517b4..0000000
--- a/yews/datasets/classification.py
+++ /dev/null
@@ -1,144 +0,0 @@
-from pathlib import Path
-from torch.utils.data import Dataset
-
-class ClassificationDataset(Dataset):
-    """Basic dataset class for classification task.
-
-    """
-
-    _repr_indent = 4
-
-    def __init__(self, source=None, transform=None, target_transform=None):
-        self.source = source
-        self.transform = transform
-        self.target_transform = target_transform
-
-        self.classes = self._find_classes()
-        self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
-        self.samples = self._make_dataset()
-        self.targets = [self.class_to_idx[s[1]] for s in self.samples]
-
-        self.transform = transform
-        self.target_transform = target_transform
-
-    def _find_classes(self):
-        raise NotImplementedError
-
-    def _make_dataset(self):
-        raise NotImplementedError
-
-    def __len__(self):
-        return len(self.targets)
-
-    def __repr__(self):
-        head = f"Dataset {self.__class__.__name__}"
-        body = [f"Number of datapoints: {self.__len__()}"]
-        if type(self.source) is str:
-            body.append(f"Source location: {self.source}")
-        else:
-            body.append(f"Source location: array")
-        if hasattr(self, 'transform') and self.transform is not None:
-            body.append(self._foramt_transform_repr(self.transform, "Transforms: "))
-        if hasattr(self, 'targe_transform') and self.target_transform is not None:
-            body.append(self._foramt_transform_repr(self.target_transform, "Target Transforms: "))
-
-        lines = [head] + [" " * self._repr_indent + line for line in body]
-        return '\n'.join(lines)
-
-    def _foramt_transform_repr(self, transform, head):
-        lines = transform.__repr__().splitlines()
-        return (["{}{}".format(head, lines[0])] +
-                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
-
-
-class DatasetArray(ClassificationDataset):
-    """A generic dataloader for classification task where samples and targets
-    are stored in numpy arrays.
-
-    Args:
-        samples (ndarray): Numpy array of seismic data where each row is a
-            sample. The first column is a single/multi-component waveform. The
-            second column is the target. The rest columns are additional info
-            about the dataset.
-        transform (callable, optional): A function/transform that takes in
-            a sample and returns a transformed version.
-        target_transform (callable, optional): A function/transform that takes
-            in the target and transforms it.
-
-    Attributes:
-        classes (list): List of the class names.
-        class_to_idx (dict): Dict with items (class_name, class_index).
-        samples (list): List of (sample, class_index) tuples.
-        targets (list): The class_index value for each sample in the dataset.
-    """
-
-    def _find_classes(self):
-        classes = list(set(self.source[:, 1]))
-        return classes
-
-    def _make_dataset(self):
-        return [(s[0], s[1]) for s in self.source]
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            tuple: (sample, target) where target is class_index of the target class.
-        """
-        sample = self.samples[index][0]
-        target = self.targets[index]
-
-        if self.transform is not None:
-            sample = self.transform(sample)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
-
-        return sample, target
-
-
-class DatasetFolder(ClassificationDataset):
-    """A generic dataloader for classification task where the samples are
-    arranaged in a folder:
-
-    Args:
-        root (string): Root directory path.
-        loader (callable): A function to load a sample given its path.
-        transform (callable, optional): A function/transform that takes in
-            a sample and returns a transformed version.
-        target_transform (callable, optional): A function/transform that takes
-            in the target and transforms it.
-
-    Attributes:
-        classes (list): List of the class names.
-        class_to_idx (dict): Dict with items (class_name, class_index).
-        samples (list): List of (sample path, class_index) tuples
-        targets (list): The class_index value for each image in the dataset
-    """
-
-    def __init__(self, root, loader, transform=None, target_transform=None):
-        super(ClassificationDatasetFolder, self).__init__(root, transform=transform,
-                                                          target_transform=target_transform)
-        self.loader = loader
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            tuple: (sample, target) where target is class_index of the target class.
-        """
-
-        path, target = self.samples[index]
-        sample = self.loader(path)
-        if self.transform is not None:
-            sample = self.transform(sample)
-        if self.target_transform is not None:
-            target = self.target_transform(target)
-
-        return sample, target
-
-
- """ - sample = self.samples[index][0] - target = self.targets[index] - - if self.transform is not None: - sample = self.transform(sample) - - if self.target_transform is not None: - target = self.target_transform(target) - - return sample, target - - -class DatasetFolder(ClassificationDataset): - """A generic dataloader for classification task where the samples are - arranaged in a folder: - - Args: - root (string): Root directory path. - loader (callable): A function to load a sample given its path. - transform (callable, optional): A function/transform that takes in - a sample and returns a transformed version. - target_transform (callable, optional): A function/transform that takes - in the target and transforms it. - - Attributes: - classes (list): List of the class names. - class_to_idx (dict): Dict with items (class_name, class_index). - samples (list): List of (sample path, class_index) tuples - targets (list): The class_index value for each image in the dataset - """ - - def __init__(self, root, loader, transform=None, target_transform=None): - super(ClassificationDatasetFolder, self).__init__(root, transform=transform, - target_transform=target_transform) - self.loader = loader - - def __getitem__(self, index): - """ - Args: - index (int): Index - - Returns: - tuple: (sample, target) where target is class_index of the target class. - """ - - path, target = self.samples[index] - sample = self.loader(path) - if self.transform is not None: - sample = self.transform(sample) - if self.target_transform is not None: - target = self.target_transform(target) - - return sample, target - - diff --git a/yews/datasets/dir.py b/yews/datasets/dir.py new file mode 100644 index 0000000..ce5f1e7 --- /dev/null +++ b/yews/datasets/dir.py @@ -0,0 +1,110 @@ +from .base import PathDataset + + +class DirDataset(PathDataset): + """An abstract class representing a Dataset in a directory. + + Args: + root (object): Directory of the dataset. + sample_transform (callable, optional): A function/transform that takes + a sample and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + a target and transform it. + + Attributes: + samples (list): List of samples in the dataset. + targets (list): List of targets in teh dataset. + + """ + + def __init__(self, **kwargs): + super(DirDataset, self).__init__(**kwargs) + if not self.root.is_dir(): + raise ValueError(f"{self.root} is not a directory.") + + +class DatasetArrayFolder(DirDataset): + """A generic data loader for a folder of ``.npy`` files where samples are + arranged in the following way: :: + + root/samples.npy: each row is a sample + root/targets.npy: each row is a label + + where both samples and targets can be arrays. + + Args: + root (object): Path to the dataset. + sample_transform (callable, optional): A function/transform that takes + a sample and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + a target and transform it. + + Attributes: + samples (list): List of samples in the dataset. + targets (list): List of targets in teh dataset. 
+ + """ + + def build_dataset(self): + samples = np.load(self.root / 'samples.npy', mmap_mode='r') + targets = np.load(self.root / 'targets.npy', mmap_mode='r') + + return samples, targets + + +class DatasetFolder(DirDataset): + """A generic data loader for a folder where samples are arranged in the + following way: :: + + root/.../class_x.xxx + root/.../class_x.sdf3 + root/.../class_x.asd932 + + root/.../class_y.yyy + root/.../class_y.as4h + root/.../blass_y.jlk2 + + Args: + root (path): Path to the dataset. + loader (callable): Function that load one sample from a file. + sample_transform (callable, optional): A function/transform that takes + a sample and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + a target and transform it. + + Attributes: + samples (list): List of samples in the dataset. + targets (list): List of targets in teh dataset. + + + """ + + class FilesLoader(object): + """A dataset-like class for loading a list of files given a loader. + + Args: + files (list): List of file paths + loader (callable): Function that load one file. + """ + + def __init__(self, files, loader): + self.files = files + self.loader = loader + + def __getitem__(self, index): + return self.loader(self.file_list[index]) + + def __len__(self): + return len(file_list) + + def __init__(self, loader, **kwargs): + super(DatasetFolder, self).__init__(**kwargs) + self.loader = loader + + def make_dataset(self): + files = [p for p in self.root.glob("**/*") if p.is_file()] + labels = [p.name.split('.')[0] for p in files] + samples = self.FilesLoader(files, self.loader) + + return samples, labels + diff --git a/yews/datasets/file.py b/yews/datasets/file.py new file mode 100644 index 0000000..7801643 --- /dev/null +++ b/yews/datasets/file.py @@ -0,0 +1,58 @@ +from pathlib import Path +import numpy as np + +from .base import PathDataset + + +class FileDataset(PathDataset): + """An abstract class representing a Dataset in a file. + + Args: + root (object): File of the dataset. + sample_transform (callable, optional): A function/transform that takes + a sample and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + a target and transform it. + + Attributes: + samples (list): List of samples in the dataset. + targets (list): List of targets in teh dataset. + + """ + + def __init__(self, **kwargs): + super(FileDataset, self).__init__(**kwargs) + if not self.root.is_file(): + raise ValueError(f"{self.root} is not a file.") + + +class DatasetArray(FileDataset): + """A generic data loader for ``.npy`` file where samples are arranged is + the following way: :: + + array = [ + [sample0, target0], + [sample1, target1], + ... + ] + + where both samples and targets can be arrays. + + Args: + root (object): Path to the dataset. + sample_transform (callable, optional): A function/transform that takes + a sample and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + a target and transform it. + + Attributes: + samples (list): List of samples in the dataset. + targets (list): List of targets in teh dataset. + + """ + + def build_dataset(self): + data = np.load(self.root) + return data[:, 0], data[:, 1] + +