Skip to content

Commit

Permalink
yews.transform under cover with 100% coverage.
Browse files Browse the repository at this point in the history
  • Loading branch information
lijunzh committed Apr 17, 2019
1 parent 2cf6108 commit ce4b445
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 44 deletions.
37 changes: 37 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@
import torch
import numpy as np


class DummpyBaseTransform(transforms.BaseTransform):

def __init__(self, a=0, b=1):
self.a = a
self.b = b

def __call__(self, data):
return data


class TestIsNumpyWaveform:

def test_single_channel_waveform_vector(self):
wav = np.empty(10)
assert F._is_numpy_waveform(wav)
Expand All @@ -26,7 +38,9 @@ def test_invalid_waveform_wrong_type(self):
wav = torch.tensor(10)
assert not F._is_numpy_waveform(wav)


class TestToTensor:

def test_type_exception(self):
wav = torch.tensor(10)
with pytest.raises(TypeError):
Expand All @@ -42,7 +56,19 @@ def test_multi_channel_waveform(self):
tensor = torch.zeros(3, 10,dtype=torch.float)
assert torch.allclose(F._to_tensor(wav), tensor)

class TestBaseTransform:

def test_raise_call_notimplementederror(self):
with pytest.raises(NotImplementedError):
t = transforms.BaseTransform()
t(0)

def test_repr(self):
t = transforms.BaseTransform()
assert type(t.__repr__()) is str

class TestMandatoryMethods:

def test_call_method(self):
assert all([hasattr(getattr(transforms, t), '__call__') for t in
transforms.transforms.__all__])
Expand All @@ -51,6 +77,13 @@ def test_repr_method(self):
assert all([hasattr(getattr(transforms, t), '__repr__') for t in
transforms.transforms.__all__])


class TestComposeTransform:

def test_repr(self):
t = transforms.Compose([DummpyBaseTransform()])
assert type(t.__repr__()) is str

class TestTransformCorrectness:
def test_compose(self):
wav = np.array([1, 3])
Expand Down Expand Up @@ -80,3 +113,7 @@ def test_cut_waveform_value(self):
assert np.allclose(transforms.CutWaveform(100, 1900)(wav),
wav[:, 100:1900])

def test_soft_clip(self):
wav = np.array([-1, -0.5, 0, 0.5, 1])
assert np.allclose(transforms.SoftClip()(wav),
np.array([0.26894142, 0.37754067, 0.5, 0.62245933, 0.73105858]))
1 change: 1 addition & 0 deletions yews/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .base import BaseTransform, Compose
from .transforms import *
55 changes: 55 additions & 0 deletions yews/transforms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
class BaseTransform(object):
"""An abstract class representing a Transform.
All other transform should subclass it. All subclasses should override
``__call__`` which performs the transform.
Args:
root (object): Source of the dataset.
sample_transform (callable, optional): A function/transform that takes
a sample and returns a transformed version.
target_transform (callable, optional): A function/transform that takes
a target and transform it.
Attributes:
samples (dataset-like object): Dataset-like object for samples.
targets (dataset-like object): Dataset-like object for targets.
"""

def __call__(self, data):
raise NotImplementedError

def __repr__(self):
head = self.__class__.__name__
content = [f"{key} = {val}" for key, val in self.__dict__.items()]
body = ", ".join(content)
return f"{head}({body})"


class Compose(BaseTransform):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""

def __init__(self, transforms):
self.transforms = transforms

def __call__(self, wav):
for t in self.transforms:
wav = t(wav)
return wav

def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
49 changes: 5 additions & 44 deletions yews/transforms/transforms.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,14 @@
from .base import BaseTransform
from . import functional as F

__all__ = [
"Compose",
"ToTensor",
"ZeroMean",
"SoftClip",
"CutWaveform",
]

class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""

def __init__(self, transforms):
self.transforms = transforms

def __call__(self, wav):
for t in self.transforms:
wav = t(wav)
return wav

def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string


class ToTensor(object):
class ToTensor(BaseTransform):
"""Convert a ``numpy.ndarray`` to tensor.
Converts a numpy.ndarray (C x S) to a torch.FloatTensor of shape (C x S).
Expand All @@ -51,11 +23,8 @@ def __call__(self, wav):
"""
return F._to_tensor(wav)

def __repr__(self):
return self.__class__.__name__ + '()'


class SoftClip(object):
class SoftClip(BaseTransform):
"""Soft clip input to compress large amplitude signals
"""
Expand All @@ -66,11 +35,8 @@ def __init__(self, scale=1):
def __call__(self, wav):
return F.expit(wav * self.scale)

def __repr__(self):
return self.__class__.__name__ + f'(scale = {self.scale})'


class ZeroMean(object):
class ZeroMean(BaseTransform):
"""Remove mean from each waveforms
"""
Expand All @@ -80,11 +46,8 @@ def __call__(self, wav):
wav -= wav.mean(axis=0)
return wav.T

def __repr__(self):
return self.__class__.__name__ + '()'


class CutWaveform(object):
class CutWaveform(BaseTransform):
"""Cut a portion of waveform.
"""
Expand All @@ -96,5 +59,3 @@ def __init__(self, samplestart, sampleend):
def __call__(self, wav):
return wav[:, self.start:self.end]

def __repr__(self):
return self.__call__.__name__ + f'(start = {self.start}, end = {self.end})'

0 comments on commit ce4b445

Please sign in to comment.