diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 179206b..e43e075 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,7 +5,19 @@ import torch import numpy as np + +class DummpyBaseTransform(transforms.BaseTransform): + + def __init__(self, a=0, b=1): + self.a = a + self.b = b + + def __call__(self, data): + return data + + class TestIsNumpyWaveform: + def test_single_channel_waveform_vector(self): wav = np.empty(10) assert F._is_numpy_waveform(wav) @@ -26,7 +38,9 @@ def test_invalid_waveform_wrong_type(self): wav = torch.tensor(10) assert not F._is_numpy_waveform(wav) + class TestToTensor: + def test_type_exception(self): wav = torch.tensor(10) with pytest.raises(TypeError): @@ -42,7 +56,19 @@ def test_multi_channel_waveform(self): tensor = torch.zeros(3, 10,dtype=torch.float) assert torch.allclose(F._to_tensor(wav), tensor) +class TestBaseTransform: + + def test_raise_call_notimplementederror(self): + with pytest.raises(NotImplementedError): + t = transforms.BaseTransform() + t(0) + + def test_repr(self): + t = transforms.BaseTransform() + assert type(t.__repr__()) is str + class TestMandatoryMethods: + def test_call_method(self): assert all([hasattr(getattr(transforms, t), '__call__') for t in transforms.transforms.__all__]) @@ -51,6 +77,13 @@ def test_repr_method(self): assert all([hasattr(getattr(transforms, t), '__repr__') for t in transforms.transforms.__all__]) + +class TestComposeTransform: + + def test_repr(self): + t = transforms.Compose([DummpyBaseTransform()]) + assert type(t.__repr__()) is str + class TestTransformCorrectness: def test_compose(self): wav = np.array([1, 3]) @@ -80,3 +113,7 @@ def test_cut_waveform_value(self): assert np.allclose(transforms.CutWaveform(100, 1900)(wav), wav[:, 100:1900]) + def test_soft_clip(self): + wav = np.array([-1, -0.5, 0, 0.5, 1]) + assert np.allclose(transforms.SoftClip()(wav), + np.array([0.26894142, 0.37754067, 0.5, 0.62245933, 0.73105858])) diff --git a/yews/transforms/__init__.py b/yews/transforms/__init__.py index 7986cdd..118c9bf 100644 --- a/yews/transforms/__init__.py +++ b/yews/transforms/__init__.py @@ -1 +1,2 @@ +from .base import BaseTransform, Compose from .transforms import * diff --git a/yews/transforms/base.py b/yews/transforms/base.py new file mode 100644 index 0000000..a40ae78 --- /dev/null +++ b/yews/transforms/base.py @@ -0,0 +1,55 @@ +class BaseTransform(object): + """An abstract class representing a Transform. + + All other transform should subclass it. All subclasses should override + ``__call__`` which performs the transform. + + Args: + root (object): Source of the dataset. + sample_transform (callable, optional): A function/transform that takes + a sample and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + a target and transform it. + + Attributes: + samples (dataset-like object): Dataset-like object for samples. + targets (dataset-like object): Dataset-like object for targets. + + """ + + def __call__(self, data): + raise NotImplementedError + + def __repr__(self): + head = self.__class__.__name__ + content = [f"{key} = {val}" for key, val in self.__dict__.items()] + body = ", ".join(content) + return f"{head}({body})" + + +class Compose(BaseTransform): + """Composes several transforms together. + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, wav): + for t in self.transforms: + wav = t(wav) + return wav + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string diff --git a/yews/transforms/transforms.py b/yews/transforms/transforms.py index 43f36f9..0a2ddf2 100644 --- a/yews/transforms/transforms.py +++ b/yews/transforms/transforms.py @@ -1,42 +1,14 @@ +from .base import BaseTransform from . import functional as F __all__ = [ - "Compose", "ToTensor", "ZeroMean", "SoftClip", "CutWaveform", ] -class Compose(object): - """Composes several transforms together. - Args: - transforms (list of ``Transform`` objects): list of transforms to compose. - Example: - >>> transforms.Compose([ - >>> transforms.CenterCrop(10), - >>> transforms.ToTensor(), - >>> ]) - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, wav): - for t in self.transforms: - wav = t(wav) - return wav - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class ToTensor(object): +class ToTensor(BaseTransform): """Convert a ``numpy.ndarray`` to tensor. Converts a numpy.ndarray (C x S) to a torch.FloatTensor of shape (C x S). @@ -51,11 +23,8 @@ def __call__(self, wav): """ return F._to_tensor(wav) - def __repr__(self): - return self.__class__.__name__ + '()' - -class SoftClip(object): +class SoftClip(BaseTransform): """Soft clip input to compress large amplitude signals """ @@ -66,11 +35,8 @@ def __init__(self, scale=1): def __call__(self, wav): return F.expit(wav * self.scale) - def __repr__(self): - return self.__class__.__name__ + f'(scale = {self.scale})' - -class ZeroMean(object): +class ZeroMean(BaseTransform): """Remove mean from each waveforms """ @@ -80,11 +46,8 @@ def __call__(self, wav): wav -= wav.mean(axis=0) return wav.T - def __repr__(self): - return self.__class__.__name__ + '()' - -class CutWaveform(object): +class CutWaveform(BaseTransform): """Cut a portion of waveform. """ @@ -96,5 +59,3 @@ def __init__(self, samplestart, sampleend): def __call__(self, wav): return wav[:, self.start:self.end] - def __repr__(self): - return self.__call__.__name__ + f'(start = {self.start}, end = {self.end})'