diff --git a/test/common_utils.py b/test/common_utils.py deleted file mode 100644 index 953b1b8790..0000000000 --- a/test/common_utils.py +++ /dev/null @@ -1,216 +0,0 @@ -import os -import shutil -import tempfile -import unittest -from typing import Union -from shutil import copytree - -import torch -from torch.testing._internal.common_utils import TestCase as PytorchTestCase -import torchaudio -from torchaudio._internal.module_utils import is_module_available - -_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) -BACKENDS = torchaudio.list_audio_backends() - - -def get_asset_path(*paths): - """Return full path of a test asset""" - return os.path.join(_TEST_DIR_PATH, 'assets', *paths) - - -def create_temp_assets_dir(): - """ - Creates a temporary directory and moves all files from test/assets there. - Returns a Tuple[string, TemporaryDirectory] which is the folder path - and object. - """ - tmp_dir = tempfile.TemporaryDirectory() - copytree(os.path.join(_TEST_DIR_PATH, "assets"), - os.path.join(tmp_dir.name, "assets")) - return tmp_dir.name, tmp_dir - - -def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32): - """ Generates random tensors given a seed and size - https://en.wikipedia.org/wiki/Linear_congruential_generator - X_{n + 1} = (a * X_n + c) % m - Using Borland C/C++ values - - The tensor will have values between [0,1) - Inputs: - seed (int): an int - size (Tuple[int]): the size of the output tensor - a (int): the multiplier constant to the generator - c (int): the additive constant to the generator - m (int): the modulus constant to the generator - """ - num_elements = 1 - for s in size: - num_elements *= s - - arr = [(a * seed + c) % m] - for i in range(num_elements - 1): - arr.append((a * arr[i] + c) % m) - - return torch.tensor(arr).float().view(size) / m - - -def filter_backends_with_mp3(backends): - # Filter out backends that do not support mp3 - test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3') - - def supports_mp3(backend): - torchaudio.set_audio_backend(backend) - try: - torchaudio.load(test_filepath) - return True - except (RuntimeError, ImportError): - return False - - return [backend for backend in backends if supports_mp3(backend)] - - -BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS) - - -def set_audio_backend(backend): - """Allow additional backend value, 'default'""" - if backend == 'default': - if 'sox' in BACKENDS: - be = 'sox' - elif 'soundfile' in BACKENDS: - be = 'soundfile' - else: - raise unittest.SkipTest('No default backend available') - else: - be = backend - - torchaudio.set_audio_backend(be) - - -class TempDirMixin: - """Mixin to provide easy access to temp dir""" - temp_dir_ = None - base_temp_dir = None - temp_dir = None - - @classmethod - def setUpClass(cls): - super().setUpClass() - # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. - # this is handy for debugging. - key = 'TORCHAUDIO_TEST_TEMP_DIR' - if key in os.environ: - cls.base_temp_dir = os.environ[key] - else: - cls.temp_dir_ = tempfile.TemporaryDirectory() - cls.base_temp_dir = cls.temp_dir_.name - - @classmethod - def tearDownClass(cls): - super().tearDownClass() - if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory): - cls.temp_dir_.cleanup() - - def setUp(self): - self.temp_dir = os.path.join(self.base_temp_dir, self.id()) - - def get_temp_path(self, *paths): - path = os.path.join(self.temp_dir, *paths) - os.makedirs(os.path.dirname(path), exist_ok=True) - return path - - -class TestBaseMixin: - """Mixin to provide consistent way to define device/dtype/backend aware TestCase""" - dtype = None - device = None - backend = None - - def setUp(self): - super().setUp() - set_audio_backend(self.backend) - - -class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): - pass - - -def skipIfNoExec(cmd): - return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available') - - -def skipIfNoModule(module, display_name=None): - display_name = display_name or module - return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') - - -skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available') -skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') -skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension') - - -def get_whitenoise( - *, - sample_rate: int = 16000, - duration: float = 1, # seconds - n_channels: int = 1, - seed: int = 0, - dtype: Union[str, torch.dtype] = "float32", - device: Union[str, torch.device] = "cpu", -): - """Generate pseudo audio data with whitenoise - - Args: - sample_rate: Sampling rate - duration: Length of the resulting Tensor in seconds. - n_channels: Number of channels - seed: Seed value used for random number generation. - Note that this function does not modify global random generator state. - dtype: Torch dtype - device: device - Returns: - Tensor: shape of (n_channels, sample_rate * duration) - """ - if isinstance(dtype, str): - dtype = getattr(torch, dtype) - shape = [n_channels, sample_rate * duration] - # According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices, - # so we only folk on CPU, generate values and move the data to the given device - with torch.random.fork_rng([]): - torch.random.manual_seed(seed) - tensor = torch.randn(shape, dtype=dtype, device='cpu') - tensor /= 2.0 - tensor.clamp_(-1.0, 1.0) - return tensor.to(device=device) - - -def get_sinusoid( - *, - frequency: float = 300, - sample_rate: int = 16000, - duration: float = 1, # seconds - n_channels: int = 1, - dtype: Union[str, torch.dtype] = "float32", - device: Union[str, torch.device] = "cpu", -): - """Generate pseudo audio data with sine wave. - - Args: - frequency: Frequency of sine wave - sample_rate: Sampling rate - duration: Length of the resulting Tensor in seconds. - n_channels: Number of channels - dtype: Torch dtype - device: device - - Returns: - Tensor: shape of (n_channels, sample_rate * duration) - """ - if isinstance(dtype, str): - dtype = getattr(torch, dtype) - pie2 = 2 * 3.141592653589793 - end = pie2 * frequency * duration - theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device) - return torch.sin(theta, out=None).repeat([n_channels, 1]) diff --git a/test/common_utils/__init__.py b/test/common_utils/__init__.py new file mode 100644 index 0000000000..d67f0dea61 --- /dev/null +++ b/test/common_utils/__init__.py @@ -0,0 +1,31 @@ +from .data_utils import ( + get_asset_path, + get_whitenoise, + get_sinusoid, +) +from .backend_utils import ( + set_audio_backend, + BACKENDS, + BACKENDS_MP3, +) +from .test_case_utils import ( + TempDirMixin, + TestBaseMixin, + PytorchTestCase, + TorchaudioTestCase, + skipIfNoCuda, + skipIfNoExec, + skipIfNoModule, + skipIfNoExtension, + skipIfNoSoxBackend, +) +from .wav_utils import ( + get_wav_data, + normalize_wav, + load_wav, + save_wav, +) +from .parameterized_utils import ( + load_params, +) +from . import sox_utils diff --git a/test/common_utils/backend_utils.py b/test/common_utils/backend_utils.py new file mode 100644 index 0000000000..158fde87ed --- /dev/null +++ b/test/common_utils/backend_utils.py @@ -0,0 +1,41 @@ +import unittest + +import torchaudio + +from .import data_utils + + +BACKENDS = torchaudio.list_audio_backends() + + +def _filter_backends_with_mp3(backends): + # Filter out backends that do not support mp3 + test_filepath = data_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3') + + def supports_mp3(backend): + torchaudio.set_audio_backend(backend) + try: + torchaudio.load(test_filepath) + return True + except (RuntimeError, ImportError): + return False + + return [backend for backend in backends if supports_mp3(backend)] + + +BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS) + + +def set_audio_backend(backend): + """Allow additional backend value, 'default'""" + if backend == 'default': + if 'sox' in BACKENDS: + be = 'sox' + elif 'soundfile' in BACKENDS: + be = 'soundfile' + else: + raise unittest.SkipTest('No default backend available') + else: + be = backend + + torchaudio.set_audio_backend(be) diff --git a/test/common_utils/data_utils.py b/test/common_utils/data_utils.py new file mode 100644 index 0000000000..e3e3972c4c --- /dev/null +++ b/test/common_utils/data_utils.py @@ -0,0 +1,78 @@ +import os.path +from typing import Union + +import torch + + +_TEST_DIR_PATH = os.path.realpath( + os.path.join(os.path.dirname(__file__), '..')) + + +def get_asset_path(*paths): + """Return full path of a test asset""" + return os.path.join(_TEST_DIR_PATH, 'assets', *paths) + + +def get_whitenoise( + *, + sample_rate: int = 16000, + duration: float = 1, # seconds + n_channels: int = 1, + seed: int = 0, + dtype: Union[str, torch.dtype] = "float32", + device: Union[str, torch.device] = "cpu", +): + """Generate pseudo audio data with whitenoise + + Args: + sample_rate: Sampling rate + duration: Length of the resulting Tensor in seconds. + n_channels: Number of channels + seed: Seed value used for random number generation. + Note that this function does not modify global random generator state. + dtype: Torch dtype + device: device + Returns: + Tensor: shape of (n_channels, sample_rate * duration) + """ + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + shape = [n_channels, sample_rate * duration] + # According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices, + # so we only folk on CPU, generate values and move the data to the given device + with torch.random.fork_rng([]): + torch.random.manual_seed(seed) + tensor = torch.randn(shape, dtype=dtype, device='cpu') + tensor /= 2.0 + tensor.clamp_(-1.0, 1.0) + return tensor.to(device=device) + + +def get_sinusoid( + *, + frequency: float = 300, + sample_rate: int = 16000, + duration: float = 1, # seconds + n_channels: int = 1, + dtype: Union[str, torch.dtype] = "float32", + device: Union[str, torch.device] = "cpu", +): + """Generate pseudo audio data with sine wave. + + Args: + frequency: Frequency of sine wave + sample_rate: Sampling rate + duration: Length of the resulting Tensor in seconds. + n_channels: Number of channels + dtype: Torch dtype + device: device + + Returns: + Tensor: shape of (n_channels, sample_rate * duration) + """ + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + pie2 = 2 * 3.141592653589793 + end = pie2 * frequency * duration + theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device) + return torch.sin(theta, out=None).repeat([n_channels, 1]) diff --git a/test/common_utils/parameterized_utils.py b/test/common_utils/parameterized_utils.py new file mode 100644 index 0000000000..24404a6edd --- /dev/null +++ b/test/common_utils/parameterized_utils.py @@ -0,0 +1,10 @@ +import json + +from parameterized import param + +from .data_utils import get_asset_path + + +def load_params(*paths): + with open(get_asset_path(*paths), 'r') as file: + return [param(json.loads(line)) for line in file] diff --git a/test/sox_io_backend/sox_utils.py b/test/common_utils/sox_utils.py similarity index 100% rename from test/sox_io_backend/sox_utils.py rename to test/common_utils/sox_utils.py diff --git a/test/common_utils/test_case_utils.py b/test/common_utils/test_case_utils.py new file mode 100644 index 0000000000..f3b0c343a6 --- /dev/null +++ b/test/common_utils/test_case_utils.py @@ -0,0 +1,75 @@ +import shutil +import os.path +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import TestCase as PytorchTestCase +import torchaudio +from torchaudio._internal.module_utils import is_module_available + +from .backend_utils import set_audio_backend + + +class TempDirMixin: + """Mixin to provide easy access to temp dir""" + temp_dir_ = None + base_temp_dir = None + temp_dir = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. + # this is handy for debugging. + key = 'TORCHAUDIO_TEST_TEMP_DIR' + if key in os.environ: + cls.base_temp_dir = os.environ[key] + else: + cls.temp_dir_ = tempfile.TemporaryDirectory() + cls.base_temp_dir = cls.temp_dir_.name + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory): + cls.temp_dir_.cleanup() + + def setUp(self): + super().setUp() + self.temp_dir = os.path.join(self.base_temp_dir, self.id()) + + def get_temp_path(self, *paths): + path = os.path.join(self.temp_dir, *paths) + os.makedirs(os.path.dirname(path), exist_ok=True) + return path + + +class TestBaseMixin: + """Mixin to provide consistent way to define device/dtype/backend aware TestCase""" + dtype = None + device = None + backend = None + + def setUp(self): + super().setUp() + set_audio_backend(self.backend) + + +class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): + pass + + +def skipIfNoExec(cmd): + return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available') + + +def skipIfNoModule(module, display_name=None): + display_name = display_name or module + return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') + + +skipIfNoSoxBackend = unittest.skipIf( + 'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available') +skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') +skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension') diff --git a/test/common_utils/wav_utils.py b/test/common_utils/wav_utils.py new file mode 100644 index 0000000000..bc122ec6cb --- /dev/null +++ b/test/common_utils/wav_utils.py @@ -0,0 +1,86 @@ +from typing import Optional + +import torch +import scipy.io.wavfile + + +def normalize_wav(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.float32: + pass + elif tensor.dtype == torch.int32: + tensor = tensor.to(torch.float32) + tensor[tensor > 0] /= 2147483647. + tensor[tensor < 0] /= 2147483648. + elif tensor.dtype == torch.int16: + tensor = tensor.to(torch.float32) + tensor[tensor > 0] /= 32767. + tensor[tensor < 0] /= 32768. + elif tensor.dtype == torch.uint8: + tensor = tensor.to(torch.float32) - 128 + tensor[tensor > 0] /= 127. + tensor[tensor < 0] /= 128. + return tensor + + +def get_wav_data( + dtype: str, + num_channels: int, + *, + num_frames: Optional[int] = None, + normalize: bool = True, + channels_first: bool = True, +): + """Generate linear signal of the given dtype and num_channels + + Data range is + [-1.0, 1.0] for float32, + [-2147483648, 2147483647] for int32 + [-32768, 32767] for int16 + [0, 255] for uint8 + + num_frames allow to change the linear interpolation parameter. + Default values are 256 for uint8, else 1 << 16. + 1 << 16 as default is so that int16 value range is completely covered. + """ + dtype_ = getattr(torch, dtype) + + if num_frames is None: + if dtype == 'uint8': + num_frames = 256 + else: + num_frames = 1 << 16 + + if dtype == 'uint8': + base = torch.linspace(0, 255, num_frames, dtype=dtype_) + if dtype == 'float32': + base = torch.linspace(-1., 1., num_frames, dtype=dtype_) + if dtype == 'int32': + base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) + if dtype == 'int16': + base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_) + data = base.repeat([num_channels, 1]) + if not channels_first: + data = data.transpose(1, 0) + if normalize: + data = normalize_wav(data) + return data + + +def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor: + """Load wav file without torchaudio""" + sample_rate, data = scipy.io.wavfile.read(path) + data = torch.from_numpy(data.copy()) + if data.ndim == 1: + data = data.unsqueeze(1) + if normalize: + data = normalize_wav(data) + if channels_first: + data = data.transpose(1, 0) + return data, sample_rate + + +def save_wav(path, data, sample_rate, channels_first=True): + """Save wav file without torchaudio""" + if channels_first: + data = data.transpose(1, 0) + scipy.io.wavfile.write(path, sample_rate, data.numpy()) diff --git a/test/functional_cpu_test.py b/test/functional_cpu_test.py index 470d6ab770..817a645087 100644 --- a/test/functional_cpu_test.py +++ b/test/functional_cpu_test.py @@ -10,6 +10,31 @@ from .functional_impl import Lfilter +def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32): + """ Generates random tensors given a seed and size + https://en.wikipedia.org/wiki/Linear_congruential_generator + X_{n + 1} = (a * X_n + c) % m + Using Borland C/C++ values + + The tensor will have values between [0,1) + Inputs: + seed (int): an int + size (Tuple[int]): the size of the output tensor + a (int): the multiplier constant to the generator + c (int): the additive constant to the generator + m (int): the modulus constant to the generator + """ + num_elements = 1 + for s in size: + num_elements *= s + + arr = [(a * seed + c) % m] + for i in range(num_elements - 1): + arr.append((a * arr[i] + c) % m) + + return torch.tensor(arr).float().view(size) / m + + class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): dtype = torch.float32 device = torch.device('cpu') @@ -49,7 +74,7 @@ def _test_istft_is_inverse_of_stft(kwargs): for data_size in [(2, 20), (3, 15), (4, 10)]: for i in range(100): - sound = common_utils.random_float_tensor(i, data_size) + sound = random_float_tensor(i, data_size) stft = torch.stft(sound, **kwargs) estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) @@ -211,8 +236,8 @@ def test_istft_of_sine(self): def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8): for i in range(self.number_of_trials): - tensor1 = common_utils.random_float_tensor(i, data_size) - tensor2 = common_utils.random_float_tensor(i * 2, data_size) + tensor1 = random_float_tensor(i, data_size) + tensor2 = random_float_tensor(i * 2, data_size) a, b = torch.rand(2) istft1 = torchaudio.functional.istft(tensor1, **kwargs) istft2 = torchaudio.functional.istft(tensor2, **kwargs) diff --git a/test/kaldi_compatibility_impl.py b/test/kaldi_compatibility_impl.py index 0488724de9..8d39e34057 100644 --- a/test/kaldi_compatibility_impl.py +++ b/test/kaldi_compatibility_impl.py @@ -1,5 +1,4 @@ """Test suites for checking numerical compatibility against Kaldi""" -import json import subprocess import kaldi_io @@ -8,7 +7,8 @@ import torchaudio.compliance.kaldi from . import common_utils -from parameterized import parameterized, param +from .common_utils import load_params +from parameterized import parameterized def _convert_args(**kwargs): @@ -43,11 +43,6 @@ def _run_kaldi(command, input_type, input_value): return torch.from_numpy(result.copy()) # copy supresses some torch warning -def _load_params(path): - with open(path, 'r') as file: - return [param(json.loads(line)) for line in file] - - class Kaldi(common_utils.TestBaseMixin): backend = 'sox' @@ -71,7 +66,7 @@ def test_sliding_window_cmn(self): kaldi_result = _run_kaldi(command, 'ark', tensor) self.assert_equal(result, expected=kaldi_result) - @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json'))) + @parameterized.expand(load_params('kaldi_test_fbank_args.json')) @common_utils.skipIfNoExec('compute-fbank-feats') def test_fbank(self, kwargs): """fbank should be numerically compatible with compute-fbank-feats""" @@ -82,7 +77,7 @@ def test_fbank(self, kwargs): kaldi_result = _run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) - @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json'))) + @parameterized.expand(load_params('kaldi_test_spectrogram_args.json')) @common_utils.skipIfNoExec('compute-spectrogram-feats') def test_spectrogram(self, kwargs): """spectrogram should be numerically compatible with compute-spectrogram-feats""" @@ -93,7 +88,7 @@ def test_spectrogram(self, kwargs): kaldi_result = _run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) - @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json'))) + @parameterized.expand(load_params('kaldi_test_mfcc_args.json')) @common_utils.skipIfNoExec('compute-mfcc-feats') def test_mfcc(self, kwargs): """mfcc should be numerically compatible with compute-mfcc-feats""" diff --git a/test/sox_io_backend/common.py b/test/sox_io_backend/common.py index d477852e12..eb85937236 100644 --- a/test/sox_io_backend/common.py +++ b/test/sox_io_backend/common.py @@ -1,90 +1,2 @@ -from typing import Optional - -import torch -import scipy.io.wavfile - - -def get_test_name(func, _, params): - return f'{func.__name__}_{"_".join(str(p) for p in params.args)}' - - -def normalize_wav(tensor: torch.Tensor) -> torch.Tensor: - if tensor.dtype == torch.float32: - pass - elif tensor.dtype == torch.int32: - tensor = tensor.to(torch.float32) - tensor[tensor > 0] /= 2147483647. - tensor[tensor < 0] /= 2147483648. - elif tensor.dtype == torch.int16: - tensor = tensor.to(torch.float32) - tensor[tensor > 0] /= 32767. - tensor[tensor < 0] /= 32768. - elif tensor.dtype == torch.uint8: - tensor = tensor.to(torch.float32) - 128 - tensor[tensor > 0] /= 127. - tensor[tensor < 0] /= 128. - return tensor - - -def get_wav_data( - dtype: str, - num_channels: int, - *, - num_frames: Optional[int] = None, - normalize: bool = True, - channels_first: bool = True, -): - """Generate linear signal of the given dtype and num_channels - - Data range is - [-1.0, 1.0] for float32, - [-2147483648, 2147483647] for int32 - [-32768, 32767] for int16 - [0, 255] for uint8 - - num_frames allow to change the linear interpolation parameter. - Default values are 256 for uint8, else 1 << 16. - 1 << 16 as default is so that int16 value range is completely covered. - """ - dtype_ = getattr(torch, dtype) - - if num_frames is None: - if dtype == 'uint8': - num_frames = 256 - else: - num_frames = 1 << 16 - - if dtype == 'uint8': - base = torch.linspace(0, 255, num_frames, dtype=dtype_) - if dtype == 'float32': - base = torch.linspace(-1., 1., num_frames, dtype=dtype_) - if dtype == 'int32': - base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) - if dtype == 'int16': - base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_) - data = base.repeat([num_channels, 1]) - if not channels_first: - data = data.transpose(1, 0) - if normalize: - data = normalize_wav(data) - return data - - -def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor: - """Load wav file without torchaudio""" - sample_rate, data = scipy.io.wavfile.read(path) - data = torch.from_numpy(data.copy()) - if data.ndim == 1: - data = data.unsqueeze(1) - if normalize: - data = normalize_wav(data) - if channels_first: - data = data.transpose(1, 0) - return data, sample_rate - - -def save_wav(path, data, sample_rate, channels_first=True): - """Save wav file without torchaudio""" - if channels_first: - data = data.transpose(1, 0) - scipy.io.wavfile.write(path, sample_rate, data.numpy()) +def name_func(func, _, params): + return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' diff --git a/test/sox_io_backend/test_info.py b/test/sox_io_backend/test_info.py index 27ce064e03..25c2a67dc8 100644 --- a/test/sox_io_backend/test_info.py +++ b/test/sox_io_backend/test_info.py @@ -8,13 +8,13 @@ PytorchTestCase, skipIfNoExec, skipIfNoExtension, -) -from .common import ( - get_test_name, + sox_utils, get_wav_data, save_wav, ) -from . import sox_utils +from .common import ( + name_func, +) @skipIfNoExec('sox') @@ -24,7 +24,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` can check wav file correctly""" duration = 1 @@ -40,7 +40,7 @@ def test_wav(self, dtype, sample_rate, num_channels): ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [4, 8, 16, 32], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` can check wav file with channels more than 2 correctly""" duration = 1 @@ -56,7 +56,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): [8000, 16000], [1, 2], [96, 128, 160, 192, 224, 256, 320], - )), name_func=get_test_name) + )), name_func=name_func) def test_mp3(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.info` can check mp3 file correctly""" duration = 1 @@ -75,7 +75,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate): [8000, 16000], [1, 2], list(range(9)), - )), name_func=get_test_name) + )), name_func=name_func) def test_flac(self, sample_rate, num_channels, compression_level): """`sox_io_backend.info` can check flac file correctly""" duration = 1 @@ -93,7 +93,7 @@ def test_flac(self, sample_rate, num_channels, compression_level): [8000, 16000], [1, 2], [-1, 0, 1, 2, 3, 3.6, 5, 10], - )), name_func=get_test_name) + )), name_func=name_func) def test_vorbis(self, sample_rate, num_channels, quality_level): """`sox_io_backend.info` can check vorbis file correctly""" duration = 1 diff --git a/test/sox_io_backend/test_load.py b/test/sox_io_backend/test_load.py index bf25f16ee8..7c78efbf85 100644 --- a/test/sox_io_backend/test_load.py +++ b/test/sox_io_backend/test_load.py @@ -8,14 +8,14 @@ PytorchTestCase, skipIfNoExec, skipIfNoExtension, -) -from .common import ( - get_test_name, get_wav_data, load_wav, save_wav, + sox_utils, +) +from .common import ( + name_func, ) -from . import sox_utils class LoadTestBase(TempDirMixin, PytorchTestCase): @@ -129,7 +129,7 @@ class TestLoad(LoadTestBase): [8000, 16000], [1, 2], [False, True], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav(self, dtype, sample_rate, num_channels, normalize): """`sox_io_backend.load` can load wav format correctly.""" self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) @@ -139,7 +139,7 @@ def test_wav(self, dtype, sample_rate, num_channels, normalize): [16000], [2], [False], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav_large(self, dtype, sample_rate, num_channels, normalize): """`sox_io_backend.load` can load large wav file correctly.""" two_hours = 2 * 60 * 60 @@ -148,7 +148,7 @@ def test_wav_large(self, dtype, sample_rate, num_channels, normalize): @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], [4, 8, 16, 32], - )), name_func=get_test_name) + )), name_func=name_func) def test_multiple_channels(self, dtype, num_channels): """`sox_io_backend.load` can load wav file with more than 2 channels.""" sample_rate = 8000 @@ -159,7 +159,7 @@ def test_multiple_channels(self, dtype, num_channels): [8000, 16000, 44100], [1, 2], [96, 128, 160, 192, 224, 256, 320], - )), name_func=get_test_name) + )), name_func=name_func) def test_mp3(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.load` can load mp3 format correctly.""" self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) @@ -168,7 +168,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate): [16000], [2], [128], - )), name_func=get_test_name) + )), name_func=name_func) def test_mp3_large(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.load` can load large mp3 file correctly.""" two_hours = 2 * 60 * 60 @@ -178,7 +178,7 @@ def test_mp3_large(self, sample_rate, num_channels, bit_rate): [8000, 16000], [1, 2], list(range(9)), - )), name_func=get_test_name) + )), name_func=name_func) def test_flac(self, sample_rate, num_channels, compression_level): """`sox_io_backend.load` can load flac format correctly.""" self.assert_flac(sample_rate, num_channels, compression_level, duration=1) @@ -187,7 +187,7 @@ def test_flac(self, sample_rate, num_channels, compression_level): [16000], [2], [0], - )), name_func=get_test_name) + )), name_func=name_func) def test_flac_large(self, sample_rate, num_channels, compression_level): """`sox_io_backend.load` can load large flac file correctly.""" two_hours = 2 * 60 * 60 @@ -197,7 +197,7 @@ def test_flac_large(self, sample_rate, num_channels, compression_level): [8000, 16000], [1, 2], [-1, 0, 1, 2, 3, 3.6, 5, 10], - )), name_func=get_test_name) + )), name_func=name_func) def test_vorbis(self, sample_rate, num_channels, quality_level): """`sox_io_backend.load` can load vorbis format correctly.""" self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) @@ -206,7 +206,7 @@ def test_vorbis(self, sample_rate, num_channels, quality_level): [16000], [2], [10], - )), name_func=get_test_name) + )), name_func=name_func) def test_vorbis_large(self, sample_rate, num_channels, quality_level): """`sox_io_backend.load` can load large vorbis file correctly.""" two_hours = 2 * 60 * 60 @@ -230,14 +230,14 @@ def setUp(self): @parameterized.expand(list(itertools.product( [0, 1, 10, 100, 1000], [-1, 1, 10, 100, 1000], - )), name_func=get_test_name) + )), name_func=name_func) def test_frame(self, frame_offset, num_frames): """num_frames and frame_offset correctly specify the region of data""" found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) frame_end = None if num_frames == -1 else frame_offset + num_frames self.assertEqual(found, self.original[:, frame_offset:frame_end]) - @parameterized.expand([(True, ), (False, )], name_func=get_test_name) + @parameterized.expand([(True, ), (False, )], name_func=name_func) def test_channels_first(self, channels_first): """channels_first swaps axes""" found, _ = sox_io_backend.load(self.path, channels_first=channels_first) diff --git a/test/sox_io_backend/test_roundtrip.py b/test/sox_io_backend/test_roundtrip.py index 0284ae6e57..2a051bebd5 100644 --- a/test/sox_io_backend/test_roundtrip.py +++ b/test/sox_io_backend/test_roundtrip.py @@ -8,10 +8,10 @@ PytorchTestCase, skipIfNoExec, skipIfNoExtension, + get_wav_data, ) from .common import ( - get_test_name, - get_wav_data, + name_func, ) @@ -23,7 +23,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase): ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav(self, dtype, sample_rate, num_channels): """save/load round trip should not degrade data for wav formats""" original = get_wav_data(dtype, num_channels, normalize=False) @@ -39,7 +39,7 @@ def test_wav(self, dtype, sample_rate, num_channels): [8000, 16000], [1, 2], list(range(9)), - )), name_func=get_test_name) + )), name_func=name_func) def test_flac(self, sample_rate, num_channels, compression_level): """save/load round trip should not degrade data for flac formats""" original = get_wav_data('float32', num_channels) diff --git a/test/sox_io_backend/test_save.py b/test/sox_io_backend/test_save.py index ac3395fb52..53588c456b 100644 --- a/test/sox_io_backend/test_save.py +++ b/test/sox_io_backend/test_save.py @@ -8,14 +8,14 @@ PytorchTestCase, skipIfNoExec, skipIfNoExtension, -) -from .common import ( - get_test_name, get_wav_data, load_wav, save_wav, + sox_utils, +) +from .common import ( + name_func, ) -from . import sox_utils class SaveTestBase(TempDirMixin, PytorchTestCase): @@ -176,7 +176,7 @@ class TestSave(SaveTestBase): ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav(self, dtype, sample_rate, num_channels): """`sox_io_backend.save` can save wav format.""" self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) @@ -185,7 +185,7 @@ def test_wav(self, dtype, sample_rate, num_channels): ['float32'], [16000], [2], - )), name_func=get_test_name) + )), name_func=name_func) def test_wav_large(self, dtype, sample_rate, num_channels): """`sox_io_backend.save` can save large wav file.""" two_hours = 2 * 60 * 60 * sample_rate @@ -194,7 +194,7 @@ def test_wav_large(self, dtype, sample_rate, num_channels): @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], [4, 8, 16, 32], - )), name_func=get_test_name) + )), name_func=name_func) def test_multiple_channels(self, dtype, num_channels): """`sox_io_backend.save` can save wav with more than 2 channels.""" sample_rate = 8000 @@ -204,7 +204,7 @@ def test_multiple_channels(self, dtype, num_channels): [8000, 16000], [1, 2], [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], - )), name_func=get_test_name) + )), name_func=name_func) def test_mp3(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.save` can save mp3 format.""" self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) @@ -213,7 +213,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate): [16000], [2], [128], - )), name_func=get_test_name) + )), name_func=name_func) def test_mp3_large(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.save` can save large mp3 file.""" two_hours = 2 * 60 * 60 @@ -223,7 +223,7 @@ def test_mp3_large(self, sample_rate, num_channels, bit_rate): [8000, 16000], [1, 2], list(range(9)), - )), name_func=get_test_name) + )), name_func=name_func) def test_flac(self, sample_rate, num_channels, compression_level): """`sox_io_backend.save` can save flac format.""" self.assert_flac(sample_rate, num_channels, compression_level, duration=1) @@ -232,7 +232,7 @@ def test_flac(self, sample_rate, num_channels, compression_level): [16000], [2], [0], - )), name_func=get_test_name) + )), name_func=name_func) def test_flac_large(self, sample_rate, num_channels, compression_level): """`sox_io_backend.save` can save large flac file.""" two_hours = 2 * 60 * 60 @@ -242,7 +242,7 @@ def test_flac_large(self, sample_rate, num_channels, compression_level): [8000, 16000], [1, 2], [-1, 0, 1, 2, 3, 3.6, 5, 10], - )), name_func=get_test_name) + )), name_func=name_func) def test_vorbis(self, sample_rate, num_channels, quality_level): """`sox_io_backend.save` can save vorbis format.""" self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20) @@ -255,7 +255,7 @@ def test_vorbis(self, sample_rate, num_channels, quality_level): [16000], [2], [10], - )), name_func=get_test_name) + )), name_func=name_func) def test_vorbis_large(self, sample_rate, num_channels, quality_level): """`sox_io_backend.save` can save large vorbis file correctly.""" two_hours = 2 * 60 * 60 @@ -267,7 +267,7 @@ def test_vorbis_large(self, sample_rate, num_channels, quality_level): @skipIfNoExtension class TestSaveParams(TempDirMixin, PytorchTestCase): """Test the correctness of optional parameters of `sox_io_backend.save`""" - @parameterized.expand([(True, ), (False, )], name_func=get_test_name) + @parameterized.expand([(True, ), (False, )], name_func=name_func) def test_channels_first(self, channels_first): """channels_first swaps axes""" path = self.get_temp_path('data.wav') @@ -280,7 +280,7 @@ def test_channels_first(self, channels_first): @parameterized.expand([ 'float32', 'int32', 'int16', 'uint8' - ], name_func=get_test_name) + ], name_func=name_func) def test_noncontiguous(self, dtype): """Noncontiguous tensors are saved correctly""" path = self.get_temp_path('data.wav') diff --git a/test/sox_io_backend/test_torchscript.py b/test/sox_io_backend/test_torchscript.py index dc9a0fb120..fd7d6f8b64 100644 --- a/test/sox_io_backend/test_torchscript.py +++ b/test/sox_io_backend/test_torchscript.py @@ -10,14 +10,14 @@ TorchaudioTestCase, skipIfNoExec, skipIfNoExtension, -) -from .common import ( - get_test_name, get_wav_data, save_wav, load_wav, + sox_utils, +) +from .common import ( + name_func, ) -from . import sox_utils def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: @@ -47,7 +47,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2], - )), name_func=get_test_name) + )), name_func=name_func) def test_info_wav(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` is torchscript-able and returns the same result""" audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') @@ -71,7 +71,7 @@ def test_info_wav(self, dtype, sample_rate, num_channels): [1, 2], [False, True], [False, True], - )), name_func=get_test_name) + )), name_func=name_func) def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): """`sox_io_backend.load` is torchscript-able and returns the same result""" audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') @@ -94,7 +94,7 @@ def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_fi ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2], - )), name_func=get_test_name) + )), name_func=name_func) def test_save_wav(self, dtype, sample_rate, num_channels): script_path = self.get_temp_path('save_func.zip') torch.jit.script(py_save_func).save(script_path) @@ -119,7 +119,7 @@ def test_save_wav(self, dtype, sample_rate, num_channels): [8000, 16000], [1, 2], list(range(9)), - )), name_func=get_test_name) + )), name_func=name_func) def test_save_flac(self, sample_rate, num_channels, compression_level): script_path = self.get_temp_path('save_func.zip') torch.jit.script(py_save_func).save(script_path) diff --git a/test/test_io.py b/test/test_io.py index f58f66ed11..3c5f00006d 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,11 +1,24 @@ import os import math +import shutil +import tempfile import unittest import torch import torchaudio -from .common_utils import BACKENDS, BACKENDS_MP3, create_temp_assets_dir +from .common_utils import BACKENDS, BACKENDS_MP3, get_asset_path + + +def create_temp_assets_dir(): + """ + Creates a temporary directory and moves all files from test/assets there. + Returns a Tuple[string, TemporaryDirectory] which is the folder path + and object. + """ + tmp_dir = tempfile.TemporaryDirectory() + shutil.copytree(get_asset_path(), os.path.join(tmp_dir.name, "assets")) + return tmp_dir.name, tmp_dir class Test_LoadSave(unittest.TestCase):