Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor test utilities #756

Merged
merged 3 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 0 additions & 216 deletions test/common_utils.py

This file was deleted.

31 changes: 31 additions & 0 deletions test/common_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .data_utils import (
get_asset_path,
get_whitenoise,
get_sinusoid,
)
from .backend_utils import (
set_audio_backend,
BACKENDS,
BACKENDS_MP3,
)
from .test_case_utils import (
TempDirMixin,
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
skipIfNoExtension,
skipIfNoSoxBackend,
)
from .wav_utils import (
get_wav_data,
normalize_wav,
load_wav,
save_wav,
)
from .parameterized_utils import (
load_params,
)
from . import sox_utils
41 changes: 41 additions & 0 deletions test/common_utils/backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest

import torchaudio

from .import data_utils


BACKENDS = torchaudio.list_audio_backends()


def _filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = data_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')

def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False

return [backend for backend in backends if supports_mp3(backend)]


BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS)


def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend

torchaudio.set_audio_backend(be)
78 changes: 78 additions & 0 deletions test/common_utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os.path
from typing import Union

import torch


_TEST_DIR_PATH = os.path.realpath(
os.path.join(os.path.dirname(__file__), '..'))


def get_asset_path(*paths):
"""Return full path of a test asset"""
return os.path.join(_TEST_DIR_PATH, 'assets', *paths)


def get_whitenoise(
*,
sample_rate: int = 16000,
duration: float = 1, # seconds
n_channels: int = 1,
seed: int = 0,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
):
"""Generate pseudo audio data with whitenoise

Args:
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
n_channels: Number of channels
seed: Seed value used for random number generation.
Note that this function does not modify global random generator state.
dtype: Torch dtype
device: device
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
shape = [n_channels, sample_rate * duration]
# According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices,
# so we only folk on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]):
torch.random.manual_seed(seed)
tensor = torch.randn(shape, dtype=dtype, device='cpu')
tensor /= 2.0
tensor.clamp_(-1.0, 1.0)
return tensor.to(device=device)


def get_sinusoid(
*,
frequency: float = 300,
sample_rate: int = 16000,
duration: float = 1, # seconds
n_channels: int = 1,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
):
"""Generate pseudo audio data with sine wave.

Args:
frequency: Frequency of sine wave
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
n_channels: Number of channels
dtype: Torch dtype
device: device

Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
pie2 = 2 * 3.141592653589793
end = pie2 * frequency * duration
theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device)
return torch.sin(theta, out=None).repeat([n_channels, 1])
10 changes: 10 additions & 0 deletions test/common_utils/parameterized_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import json

from parameterized import param

from .data_utils import get_asset_path


def load_params(*paths):
with open(get_asset_path(*paths), 'r') as file:
return [param(json.loads(line)) for line in file]
File renamed without changes.
Loading