Skip to content

Commit

Permalink
Refactor test utilities (#756)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jul 1, 2020
1 parent 6b15905 commit a20da5e
Show file tree
Hide file tree
Showing 17 changed files with 421 additions and 371 deletions.
216 changes: 0 additions & 216 deletions test/common_utils.py

This file was deleted.

31 changes: 31 additions & 0 deletions test/common_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .data_utils import (
get_asset_path,
get_whitenoise,
get_sinusoid,
)
from .backend_utils import (
set_audio_backend,
BACKENDS,
BACKENDS_MP3,
)
from .test_case_utils import (
TempDirMixin,
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
skipIfNoExtension,
skipIfNoSoxBackend,
)
from .wav_utils import (
get_wav_data,
normalize_wav,
load_wav,
save_wav,
)
from .parameterized_utils import (
load_params,
)
from . import sox_utils
41 changes: 41 additions & 0 deletions test/common_utils/backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest

import torchaudio

from .import data_utils


BACKENDS = torchaudio.list_audio_backends()


def _filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = data_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')

def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False

return [backend for backend in backends if supports_mp3(backend)]


BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS)


def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend

torchaudio.set_audio_backend(be)
78 changes: 78 additions & 0 deletions test/common_utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os.path
from typing import Union

import torch


_TEST_DIR_PATH = os.path.realpath(
os.path.join(os.path.dirname(__file__), '..'))


def get_asset_path(*paths):
"""Return full path of a test asset"""
return os.path.join(_TEST_DIR_PATH, 'assets', *paths)


def get_whitenoise(
*,
sample_rate: int = 16000,
duration: float = 1, # seconds
n_channels: int = 1,
seed: int = 0,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
):
"""Generate pseudo audio data with whitenoise
Args:
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
n_channels: Number of channels
seed: Seed value used for random number generation.
Note that this function does not modify global random generator state.
dtype: Torch dtype
device: device
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
shape = [n_channels, sample_rate * duration]
# According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices,
# so we only folk on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]):
torch.random.manual_seed(seed)
tensor = torch.randn(shape, dtype=dtype, device='cpu')
tensor /= 2.0
tensor.clamp_(-1.0, 1.0)
return tensor.to(device=device)


def get_sinusoid(
*,
frequency: float = 300,
sample_rate: int = 16000,
duration: float = 1, # seconds
n_channels: int = 1,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
):
"""Generate pseudo audio data with sine wave.
Args:
frequency: Frequency of sine wave
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
n_channels: Number of channels
dtype: Torch dtype
device: device
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
pie2 = 2 * 3.141592653589793
end = pie2 * frequency * duration
theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device)
return torch.sin(theta, out=None).repeat([n_channels, 1])
10 changes: 10 additions & 0 deletions test/common_utils/parameterized_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import json

from parameterized import param

from .data_utils import get_asset_path


def load_params(*paths):
with open(get_asset_path(*paths), 'r') as file:
return [param(json.loads(line)) for line in file]
File renamed without changes.
Loading

0 comments on commit a20da5e

Please sign in to comment.