-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
284 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
from .classification import ClassificationDataset, DatasetArray, DatasetFolder | ||
from .base import BaseDataset, PathDataset | ||
from .file import FileDataset, DatasetArray | ||
from .dir import DirDataset, DatasetFolder, DatasetArrayFolder | ||
|
||
__all__ = ( | ||
'ClassificationDataset', | ||
'BaseDataset', | ||
'PathDataset', | ||
'FileDataset', | ||
'DirDataset', | ||
'DatasetArray', | ||
'DatasetFolder', | ||
'DatasetArrayFolder', | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from torch.utils import data | ||
|
||
|
||
class BaseDataset(data.Dataset): | ||
"""An abstract class representing a Dataset. | ||
All other datasets should subclass it. All subclasses should override | ||
``build_dataset`` which construct the dataset-like object from root. | ||
A dataset-like object has both ``__len__`` and ``__getitem__`` implmented. | ||
Typical dataset-like objects include python list and numpy ndarray. | ||
Args: | ||
root (object): Source of the dataset. | ||
sample_transform (callable, optional): A function/transform that takes | ||
a sample and returns a transformed version. | ||
target_transform (callable, optional): A function/transform that takes | ||
a target and transform it. | ||
Attributes: | ||
samples (dataset-like object): Dataset-like object for samples. | ||
targets (dataset-like object): Dataset-like object for targets. | ||
""" | ||
|
||
_repr_indent = 4 | ||
|
||
def __init__(self, root=None, sample_transform=None, target_transform=None): | ||
self.root = root | ||
|
||
if self.root is not None: | ||
self.samples, self.targets = self.build_dataset() | ||
|
||
if len(samples) == len(targets): | ||
self.size = len(targets) | ||
else: | ||
raise ValueError("Samples and targets have different lengths.") | ||
|
||
self.sample_transform = sample_transform | ||
self.target_transform = target_transform | ||
|
||
def build_dataset(self): | ||
""" | ||
Returns: | ||
samples (ndarray): List of samples. | ||
labels (ndarray): List of labels. | ||
""" | ||
raise NotImplementedError | ||
|
||
def __getitem__(self, index): | ||
sample = self.samples[index] | ||
target = self.targets[index] | ||
|
||
if self.sample_transform is not None: | ||
sample = self.sample_transform(sample) | ||
|
||
if self.target_transform is not None: | ||
target = transform_transform(target) | ||
|
||
return sample, target | ||
|
||
def __len__(self): | ||
return self.size | ||
|
||
def __repr__(self): | ||
head = "Dataset " + self.__class__.__name__ | ||
body = ["Number of datapoints: {}".format(self.__len__())] | ||
if self.root is not None: | ||
body.append("Root location: {}".format(self.root)) | ||
body += self.extra_repr().splitlines() | ||
if self.sample_transform is not None: | ||
body += self._format_transform_repr(self.sample_transform, | ||
"Sample transforms: ") | ||
if self.target_transform is not None: | ||
body += self._format_transform_repr(self.target_transform, | ||
"Target transforms: ") | ||
lines = [head] + [" " * self._repr_indent + line for line in body] | ||
return '\n'.join(lines) | ||
|
||
def _format_transform_repr(self, transform, head): | ||
lines = transform.__repr__().splitlines() | ||
return (["{}{}".format(head, lines[0])] + | ||
["{}{}".format(" " * len(head), line) for line in lines[1:]]) | ||
|
||
def extra_repr(self): | ||
return "" | ||
|
||
|
||
class PathDataset(BaseDataset): | ||
"""An abstract class representing a Dataset defined by a Path. | ||
Args: | ||
root (object): Path to the dataset. | ||
sample_transform (callable, optional): A function/transform that takes | ||
a sample and returns a transformed version. | ||
target_transform (callable, optional): A function/transform that takes | ||
a target and transform it. | ||
Attributes: | ||
samples (list): List of samples in the dataset. | ||
targets (list): List of targets in teh dataset. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super(PathDataset, self).__init__(**kwargs) | ||
self.root = Path(self.root).resolve() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from .base import PathDataset | ||
|
||
|
||
class DirDataset(PathDataset): | ||
"""An abstract class representing a Dataset in a directory. | ||
Args: | ||
root (object): Directory of the dataset. | ||
sample_transform (callable, optional): A function/transform that takes | ||
a sample and returns a transformed version. | ||
target_transform (callable, optional): A function/transform that takes | ||
a target and transform it. | ||
Attributes: | ||
samples (list): List of samples in the dataset. | ||
targets (list): List of targets in teh dataset. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super(DirDataset, self).__init__(**kwargs) | ||
if not self.root.is_dir(): | ||
raise ValueError(f"{self.root} is not a directory.") | ||
|
||
|
||
class DatasetArrayFolder(DirDataset): | ||
"""A generic data loader for a folder of ``.npy`` files where samples are | ||
arranged in the following way: :: | ||
root/samples.npy: each row is a sample | ||
root/targets.npy: each row is a label | ||
where both samples and targets can be arrays. | ||
Args: | ||
root (object): Path to the dataset. | ||
sample_transform (callable, optional): A function/transform that takes | ||
a sample and returns a transformed version. | ||
target_transform (callable, optional): A function/transform that takes | ||
a target and transform it. | ||
Attributes: | ||
samples (list): List of samples in the dataset. | ||
targets (list): List of targets in teh dataset. | ||
""" | ||
|
||
def build_dataset(self): | ||
samples = np.load(self.root / 'samples.npy', mmap_mode='r') | ||
targets = np.load(self.root / 'targets.npy', mmap_mode='r') | ||
|
||
return samples, targets | ||
|
||
|
||
class DatasetFolder(DirDataset): | ||
"""A generic data loader for a folder where samples are arranged in the | ||
following way: :: | ||
root/.../class_x.xxx | ||
root/.../class_x.sdf3 | ||
root/.../class_x.asd932 | ||
root/.../class_y.yyy | ||
root/.../class_y.as4h | ||
root/.../blass_y.jlk2 | ||
Args: | ||
root (path): Path to the dataset. | ||
loader (callable): Function that load one sample from a file. | ||
sample_transform (callable, optional): A function/transform that takes | ||
a sample and returns a transformed version. | ||
target_transform (callable, optional): A function/transform that takes | ||
a target and transform it. | ||
Attributes: | ||
samples (list): List of samples in the dataset. | ||
targets (list): List of targets in teh dataset. | ||
""" | ||
|
||
class FilesLoader(object): | ||
"""A dataset-like class for loading a list of files given a loader. | ||
Args: | ||
files (list): List of file paths | ||
loader (callable): Function that load one file. | ||
""" | ||
|
||
def __init__(self, files, loader): | ||
self.files = files | ||
self.loader = loader | ||
|
||
def __getitem__(self, index): | ||
return self.loader(self.file_list[index]) | ||
|
||
def __len__(self): | ||
return len(file_list) | ||
|
||
def __init__(self, loader, **kwargs): | ||
super(DatasetFolder, self).__init__(**kwargs) | ||
self.loader = loader | ||
|
||
def make_dataset(self): | ||
files = [p for p in self.root.glob("**/*") if p.is_file()] | ||
labels = [p.name.split('.')[0] for p in files] | ||
samples = self.FilesLoader(files, self.loader) | ||
|
||
return samples, labels | ||
|
Oops, something went wrong.