-
Notifications
You must be signed in to change notification settings - Fork 673
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
631 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,95 @@ | ||
from typing import Optional, Tuple

import scipy.io.wavfile
import torch
|
||
|
||
def get_test_name(func, _, params):
    """Build the name of a parameterized test case.

    The name is the function's ``__name__`` followed by an underscore and the
    ``str`` form of every positional argument in ``params.args``, joined with
    underscores. The second argument is ignored.
    """
    suffix = "_".join(map(str, params.args))
    return f'{func.__name__}_{suffix}'
|
||
|
||
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
    """Convert integer PCM samples to float32.

    int32, int16 and uint8 tensors are cast to float32 (uint8 is first
    re-centred by subtracting 128); positive samples are then divided by the
    largest positive value of the source type and negative samples by the
    magnitude of its most negative value. A float32 tensor, or a tensor of
    any other dtype, is returned as-is (the very same object).
    """
    # dtype -> (zero offset, divisor for positive samples, divisor for negative samples)
    scales = {
        torch.int32: (0, 2147483647., 2147483648.),
        torch.int16: (0, 32767., 32768.),
        torch.uint8: (128, 127., 128.),
    }
    if tensor.dtype not in scales:
        return tensor

    offset, positive_peak, negative_peak = scales[tensor.dtype]
    # ``.to`` to a different dtype yields a new tensor, so the in-place
    # divisions below never touch the caller's data.
    samples = tensor.to(torch.float32)
    if offset:
        samples = samples - offset
    samples[samples > 0] /= positive_peak
    samples[samples < 0] /= negative_peak
    return samples
|
||
|
||
def get_wav_data(
        dtype: str,
        num_channels: int,
        *,
        num_frames: Optional[int] = None,
        normalize: bool = True,
        channels_first: bool = True,
) -> torch.Tensor:
    """Generate linear signal of the given dtype and num_channels

    Data range is
        [-1.0, 1.0] for float32,
        [-2147483648, 2147483647] for int32
        [-32768, 32767] for int16
        [0, 255] for uint8

    num_frames allows changing the linear interpolation parameter.
    Default values are 256 for uint8, else 1 << 16.
    1 << 16 as default is so that int16 value range is completely covered.

    Args:
        dtype: One of ``'uint8'``, ``'int16'``, ``'int32'`` or ``'float32'``.
        num_channels: Number of identical channels in the result.
        num_frames: Number of samples per channel.
        normalize: When true, the result is passed through ``normalize_wav``.
        channels_first: When true the result is ``(num_channels, num_frames)``,
            otherwise ``(num_frames, num_channels)``.

    Returns:
        The generated 2-D tensor.

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    # Validate up front: without this an unsupported name either fails in
    # ``getattr`` or leaves ``base`` unbound further down.
    if dtype not in ('uint8', 'int16', 'int32', 'float32'):
        raise ValueError(
            f'Unsupported dtype: {dtype!r}. '
            'Expected one of "uint8", "int16", "int32" or "float32".')
    dtype_ = getattr(torch, dtype)

    if num_frames is None:
        num_frames = 256 if dtype == 'uint8' else 1 << 16

    if dtype == 'uint8':
        base = torch.linspace(0, 255, num_frames, dtype=dtype_)
    elif dtype == 'float32':
        base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
    elif dtype == 'int32':
        # torch.linspace is broken when dtype=torch.int32
        # https://github.com/pytorch/pytorch/issues/40118
        base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=torch.float32)
        base = base.to(torch.int32)
        # Pin the end points explicitly after the float32 round trip.
        base[0] = -2147483648
        base[-1] = 2147483647
    else:  # int16
        base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_)

    data = base.repeat([num_channels, 1])
    if not channels_first:
        data = data.transpose(1, 0)
    if normalize:
        data = normalize_wav(data)
    return data
|
||
|
||
def load_wav(path: str, normalize=True, channels_first=True) -> Tuple[torch.Tensor, int]:
    """Load wav file without torchaudio

    Args:
        path: Path of the wav file, read with ``scipy.io.wavfile.read``.
        normalize: When true, the samples are passed through ``normalize_wav``.
        channels_first: When true, the two axes of the loaded data are
            swapped before returning.

    Returns:
        A ``(data, sample_rate)`` pair. ``data`` is always 2-D: a 1-D
        (single channel) read gets a trailing axis of size one before the
        optional transpose.
    """
    sample_rate, data = scipy.io.wavfile.read(path)
    # Copy so the tensor owns writable memory independent of scipy's buffer.
    data = torch.from_numpy(data.copy())
    if data.ndim == 1:
        data = data.unsqueeze(1)
    if normalize:
        data = normalize_wav(data)
    if channels_first:
        data = data.transpose(1, 0)
    return data, sample_rate
|
||
|
||
def save_wav(path, data, sample_rate, channels_first=True):
    """Write a tensor to a wav file through scipy, without torchaudio.

    When ``channels_first`` is true the two axes of ``data`` are swapped
    before its numpy view is handed to ``scipy.io.wavfile.write``.
    """
    frames = data.transpose(1, 0) if channels_first else data
    scipy.io.wavfile.write(path, sample_rate, frames.numpy())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
load_wav, | ||
save_wav, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertions comparing `sox_io_backend.load` against scipy-loaded references."""

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`sox_io_backend.load` can load wav format correctly.

        Wav data loaded with sox_io backend should match those with scipy
        """
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
        expected = load_wav(path, normalize=normalize)[0]
        data, sr = sox_io_backend.load(path, normalize=normalize)
        # Use the TestCase assertion (not a bare ``assert``) so the check
        # survives ``python -O`` and reports both values on failure.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, expected)

    def _assert_matches_sox_reference(self, path, sample_rate, num_channels, atol, **gen_kwargs):
        """Generate ``path`` with sox and compare torchaudio's load with a sox-converted wav.

        ``gen_kwargs`` are forwarded to ``sox_utils.gen_audio_file``. The
        reference wav is written next to ``path`` as ``<path>.wav``.
        """
        ref_path = f'{path}.wav'

        # 1. Generate the encoded file with sox
        sox_utils.gen_audio_file(path, sample_rate, num_channels, **gen_kwargs)
        # 2. Convert to wav with sox
        sox_utils.convert_audio_file(path, ref_path)
        # 3. Load the encoded file with torchaudio
        data, sr = sox_io_backend.load(path)
        # 4. Load wav with scipy
        data_ref = load_wav(ref_path)[0]
        # 5. Compare
        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, data_ref, atol=atol, rtol=1.3e-06)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.load` can load mp3 format.

        mp3 encoding introduces delay and boundary effects so
        we create reference wav file from mp3

             x
             |
             |    1. Generate mp3 with Sox
             |
             v    2. Convert to wav with Sox
            mp3 ------------------------------> wav
             |                                   |
             |    3. Load with torchaudio        | 4. Load with scipy
             |                                   |
             v                                   v
          tensor ----------> x <----------- tensor
                       5. Compare

        Underlying assumptions are:
        i. Conversion of mp3 to wav with Sox does not alter data.
        ii. Loading wav file with scipy is correct.

        By combining i & ii, step 2. and 4. allow loading reference mp3 data
        without using torchaudio
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
        self._assert_matches_sox_reference(
            path, sample_rate, num_channels, atol=3e-03,
            compression=bit_rate, duration=duration)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.load` can load flac format.

        This test takes the same strategy as mp3 to compare the result
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
        self._assert_matches_sox_reference(
            path, sample_rate, num_channels, atol=4e-05,
            compression=compression_level, bit_depth=16, duration=duration)

    def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.load` can load vorbis format.

        This test takes the same strategy as mp3 to compare the result
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
        self._assert_matches_sox_reference(
            path, sample_rate, num_channels, atol=4e-05,
            compression=quality_level, bit_depth=16, duration=duration)
|
||
|
||
@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoad(LoadTestBase):
    """Test the correctness of `sox_io_backend.load` for various formats"""

    # Duration, in seconds, used by the "large file" variants below.
    _TWO_HOURS = 2 * 60 * 60

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load wav format correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand([('int16', 16000, 2, False)], name_func=get_test_name)
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load large wav file correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, self._TWO_HOURS)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [4, 8, 16, 32],
    )), name_func=get_test_name)
    def test_multiple_channels(self, dtype, num_channels):
        """`sox_io_backend.load` can load wav file with more than 2 channels."""
        self.assert_wav(dtype, 8000, num_channels, False, duration=1)

    @parameterized.expand(list(itertools.product(
        [8000, 16000, 44100],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=get_test_name)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.load` can load mp3 format correctly."""
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)

    @parameterized.expand([(16000, 2, 128)], name_func=get_test_name)
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.load` can load large mp3 file correctly."""
        self.assert_mp3(sample_rate, num_channels, bit_rate, self._TWO_HOURS)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.load` can load flac format correctly."""
        self.assert_flac(sample_rate, num_channels, compression_level, duration=1)

    @parameterized.expand([(16000, 2, 0)], name_func=get_test_name)
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.load` can load large flac file correctly."""
        self.assert_flac(sample_rate, num_channels, compression_level, self._TWO_HOURS)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=get_test_name)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.load` can load vorbis format correctly."""
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)

    @parameterized.expand([(16000, 2, 10)], name_func=get_test_name)
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.load` can load large vorbis file correctly."""
        self.assert_vorbis(sample_rate, num_channels, quality_level, self._TWO_HOURS)
|
||
|
||
@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of frame parameters of `sox_io_backend.load`"""
    # Both are populated per test in ``setUp``.
    original = None  # reference float32 signal written to ``path``
    path = None  # location of the wav file holding ``original``

    def setUp(self):
        """Write a 2-channel float32 reference signal to a temporary wav file."""
        super().setUp()
        self.original = get_wav_data('float32', num_channels=2)
        self.path = self.get_temp_path('test.wave')
        save_wav(self.path, self.original, 8000)

    @parameterized.expand(list(itertools.product(
        [0, 1, 10, 100, 1000],
        [-1, 1, 10, 100, 1000],
    )), name_func=get_test_name)
    def test_frame(self, frame_offset, num_frames):
        """num_frames and frame_offset correctly specify the region of data"""
        if num_frames == -1:
            # -1 means "everything from frame_offset onwards".
            expected = self.original[:, frame_offset:]
        else:
            expected = self.original[:, frame_offset:frame_offset + num_frames]
        found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
        self.assertEqual(found, expected)

    @parameterized.expand([(True, ), (False, )], name_func=get_test_name)
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        expected = self.original
        if not channels_first:
            expected = expected.transpose(1, 0)
        found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
        self.assertEqual(found, expected)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.