This repository has been archived by the owner on Jan 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from PhoenixDL/loading
WIP: Loading
- Loading branch information
Showing
37 changed files
with
2,060 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,5 @@ coverage: | |
ignore: | ||
- "tests/" | ||
- "*/__init__.py" | ||
- "_version.py" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
numpy | ||
torch | ||
threadpoolctl | ||
pandas | ||
scikit-learn | ||
tqdm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from rising.loading.collate import numpy_collate | ||
from rising.loading.dataset import Dataset | ||
from rising.loading.loader import DataLoader | ||
from rising.loading.debug_mode import get_debug_mode, set_debug_mode, switch_debug_mode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
import torch | ||
import collections.abc | ||
from typing import Any | ||
|
||
|
||
# error message raised by :func:`numpy_collate` when a batch element
# has a type that cannot be collated
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")
|
||
|
||
def numpy_collate(batch: Any) -> Any:
    """
    Collate a batch of samples into a whole batch of numpy arrays.

    PyTorch tensors, scalar values and (nested) sequences/mappings are
    converted to numpy arrays automatically; strings are passed through
    unchanged.

    Parameters
    ----------
    batch : Any
        a batch of samples. In most cases this is either a sequence,
        a mapping or a mixture of them

    Returns
    -------
    Any
        collated batch with optionally converted type (to numpy array)

    Raises
    ------
    TypeError
        if the element type is not supported
    """
    sample = batch[0]
    if isinstance(sample, np.ndarray):
        return np.stack(batch, 0)
    if isinstance(sample, torch.Tensor):
        # move tensors to CPU, convert to arrays, then collate those
        return numpy_collate([item.detach().cpu().numpy() for item in batch])
    if isinstance(sample, (int, float)):
        return np.array(batch)
    if isinstance(sample, str):
        # strings are sequences too, but must not be transposed element-wise
        return batch
    if isinstance(sample, collections.abc.Mapping):
        return {key: numpy_collate([mapping[key] for mapping in batch])
                for key in sample}
    if isinstance(sample, tuple) and hasattr(sample, '_fields'):  # namedtuple
        return type(sample)(*(numpy_collate(group) for group in zip(*batch)))
    if isinstance(sample, collections.abc.Sequence):
        # transpose: one collated entry per position in the sequence
        return [numpy_collate(group) for group in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(type(sample)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
from __future__ import annotations | ||
|
||
import pandas as pd | ||
import typing | ||
import pathlib | ||
from collections import defaultdict | ||
|
||
from rising.loading.dataset import Dataset | ||
from rising.loading.splitter import SplitType | ||
|
||
|
||
class DataContainer:
    """
    Handles the splitting of a dataset into named subsets
    (e.g. train/val/test) from index dicts or CSV files, including
    kfold iteration.
    """

    def __init__(self, dataset: Dataset):
        """
        Parameters
        ----------
        dataset : Dataset
            the dataset to split
        """
        self._dataset = dataset
        # mapping of split name -> dataset subset; filled by the split_* methods
        self._dset = {}
        # index of the currently active fold; only set during kfold iteration
        self._fold = None
        super().__init__()

    def split_by_index(self, split: SplitType) -> None:
        """
        Splits dataset by a given split-dict

        Parameters
        ----------
        split : dict
            a dictionary containing tuples of strings and lists of indices
            for each split
        """
        for key, idx in split.items():
            self._dset[key] = self._dataset.get_subset(idx)

    def kfold_by_index(self, splits: typing.Iterable[SplitType]
                       ) -> typing.Iterator[DataContainer]:
        """
        Produces kfold splits based on the given indices.

        Parameters
        ----------
        splits : list
            list containing one split dict per fold

        Yields
        ------
        DataContainer
            the data container with updated dataset splits
        """
        for fold, split in enumerate(splits):
            self.split_by_index(split)
            self._fold = fold
            yield self
        # reset so :attr:`fold` raises again outside a kfold iteration
        self._fold = None

    def split_by_csv(self, path: typing.Union[pathlib.Path, str],
                     index_column: str, **kwargs) -> None:
        """
        Splits a dataset by splits given in a CSV file.
        Only the first non-index column of the file is used as the split.

        Parameters
        ----------
        path : str, pathlib.Path
            the path to the csv file
        index_column : str
            the label of the index column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)
        """
        df = pd.read_csv(path, **kwargs).set_index(index_column)
        first_col = list(df.columns)[0]
        self.split_by_index(self._read_split_from_df(df, first_col))

    def kfold_by_csv(self, path: typing.Union[pathlib.Path, str],
                     index_column: str, **kwargs
                     ) -> typing.Iterator[DataContainer]:
        """
        Produces kfold splits based on the given csv file
        (one column per fold).

        Parameters
        ----------
        path : str, pathlib.Path
            the path to the csv file
        index_column : str
            the label of the index column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)

        Yields
        ------
        DataContainer
            the data container with updated dataset splits
        """
        df = pd.read_csv(path, **kwargs).set_index(index_column)
        splits = [self._read_split_from_df(df, fold) for fold in df.columns]
        yield from self.kfold_by_index(splits)

    @staticmethod
    def _read_split_from_df(df: pd.DataFrame, col: str) -> SplitType:
        """
        Helper function to read a split from a given data frame

        Parameters
        ----------
        df : pandas.DataFrame
            the dataframe containing the split
        col : str
            the column inside the data frame containing the split

        Returns
        -------
        dict
            a dictionary of lists; maps each split name to the indices
            assigned to it
        """
        split = defaultdict(list)
        for index, row in df[[col]].iterrows():
            split[str(row[col])].append(index)
        return split

    @property
    def dset(self) -> typing.Dict[str, Dataset]:
        """
        The current split name -> subset mapping.

        Raises
        ------
        AttributeError
            if no split has been created yet
        """
        if not self._dset:
            raise AttributeError("No Split found.")
        return self._dset

    @property
    def fold(self) -> int:
        """
        The currently active fold index.

        Raises
        ------
        AttributeError
            if not inside a kfold iteration
        """
        if self._fold is None:
            raise AttributeError(
                "Fold not specified. Call `kfold_by_index` first.")
        return self._fold
|
||
|
||
class DataContainerID(DataContainer):
    """
    Data Container Class for datasets whose samples can be addressed by an ID
    """

    def split_by_id(self, split: SplitType) -> None:
        """
        Splits the internal dataset by the given splits

        Parameters
        ----------
        split : dict
            dictionary containing a string-list tuple per split; the lists
            hold sample IDs
        """
        split_idx = defaultdict(list)
        # translate IDs to positional indices before delegating to the
        # index-based splitting of the base class
        for key, ids in split.items():
            for _id in ids:
                split_idx[key].append(self._dataset.get_index_by_id(_id))
        super().split_by_index(split_idx)

    def kfold_by_id(self, splits: typing.Iterable[SplitType]
                    ) -> typing.Iterator[DataContainerID]:
        """
        Produces kfold splits by an ID

        Parameters
        ----------
        splits : list
            list of dicts each containing the splits for a separate fold

        Yields
        ------
        DataContainerID
            the data container with updated internal datasets
        """
        for fold, split in enumerate(splits):
            self.split_by_id(split)
            self._fold = fold
            yield self
        # reset so :attr:`fold` raises again outside a kfold iteration
        self._fold = None

    def split_by_csv_id(self, path: typing.Union[pathlib.Path, str],
                        id_column: str, **kwargs) -> None:
        """
        Splits the internal dataset by a given id column in a given csv file.
        Only the first non-id column of the file is used as the split.

        Parameters
        ----------
        path : str or pathlib.Path
            the path to the csv file
        id_column : str
            the key of the id_column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)
        """
        df = pd.read_csv(path, **kwargs).set_index(id_column)
        first_col = list(df.columns)[0]
        self.split_by_id(self._read_split_from_df(df, first_col))

    def kfold_by_csv_id(self, path: typing.Union[pathlib.Path, str],
                        id_column: str, **kwargs
                        ) -> typing.Iterator[DataContainerID]:
        """
        Produces kfold splits by the ID columns of a given csv file
        (one column per fold)

        Parameters
        ----------
        path : str or pathlib.Path
            the path to the csv file
        id_column : str
            the key of the id_column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)

        Yields
        ------
        DataContainerID
            the data container with updated internal datasets
        """
        df = pd.read_csv(path, **kwargs).set_index(id_column)
        splits = [self._read_split_from_df(df, col) for col in df.columns]
        yield from self.kfold_by_id(splits)

    def save_split_to_csv_id(self,
                             path: typing.Union[pathlib.Path, str],
                             id_key: str,
                             split_column: str = 'split',
                             **kwargs) -> None:
        """
        Saves the current split assignment (sample ID -> split name)
        to a csv file

        Parameters
        ----------
        path : str or pathlib.Path
            the path of the csv file
        id_key : str
            the id key inside the csv file
        split_column : str
            the name of the split_column inside the csv file
        **kwargs :
            additional keyword arguments (see :meth:`pd.DataFrame.to_csv`
            for details)
        """
        split_dict = {str(id_key): [], str(split_column): []}
        for key, subset in self._dset.items():
            for sample in subset:
                split_dict[str(id_key)].append(sample[id_key])
                split_dict[str(split_column)].append(str(key))
        pd.DataFrame(split_dict).to_csv(path, **kwargs)
Oops, something went wrong.