-
Notifications
You must be signed in to change notification settings - Fork 672
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
728 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestRoundTripIO(TempDirMixin, PytorchTestCase): | ||
"""save/load round trip should not degrade data for lossless formats""" | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
[False, True] | ||
)), name_func=get_test_name) | ||
def test_roundtrip_wav(self, dtype, sample_rate, num_channels, normalize): | ||
"""save/load round trip should not degrade data for wav formats""" | ||
original = get_wav_data(dtype, num_channels, normalize=normalize) | ||
data = original | ||
for i in range(10): | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}_{i}.wav') | ||
sox_io_backend.save(path, data, sample_rate) | ||
data, sr = sox_io_backend.load(path, normalize=normalize) | ||
assert sr == sample_rate | ||
self.assertEqual(original, data) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
list(range(9)), | ||
)), name_func=get_test_name) | ||
def test_roundtrip_flac(self, sample_rate, num_channels, compression_level): | ||
"""save/load round trip should not degrade data for flac formats""" | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac') | ||
sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=compression_level) | ||
original = sox_io_backend.load(path)[0] | ||
|
||
data = original | ||
for i in range(10): | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{i}.flac') | ||
sox_io_backend.save(path, data, sample_rate) | ||
data, sr = sox_io_backend.load(path) | ||
assert sr == sample_rate | ||
self.assertEqual(original, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,313 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
load_wav, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
class SaveTestBase(TempDirMixin, PytorchTestCase): | ||
def assert_wav(self, dtype, sample_rate, num_channels, num_frames): | ||
"""`sox_io_backend.save` can save wav format.""" | ||
path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}.wav') | ||
expected = get_wav_data(dtype, num_channels, num_frames=num_frames) | ||
sox_io_backend.save(path, expected, sample_rate) | ||
found = load_wav(path)[0] | ||
self.assertEqual(found, expected) | ||
|
||
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): | ||
"""`sox_io_backend.save` can save mp3 format. | ||
mp3 encoding introduces delay and boundary effects so | ||
we convert the resulting mp3 to wav and compare the results there | ||
| | ||
| 1. Generate original wav with Sox | ||
| | ||
v | ||
-------------- wav ---------------- | ||
| | | ||
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox | ||
| then save with torchaudio | | ||
v v | ||
mp3 mp3 | ||
| | | ||
| 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox | ||
| | | ||
v v | ||
wav wav | ||
| | | ||
| 2.3. load with scipy | 3.3. load with scipy | ||
| | | ||
v v | ||
tensor -------> compare <--------- tensor | ||
""" | ||
src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav') | ||
mp3_path = f'{src_path}.mp3' | ||
wav_path = f'{mp3_path}.wav' | ||
mp3_path_sox = f'{src_path}.sox.mp3' | ||
wav_path_sox = f'{mp3_path_sox}.wav' | ||
|
||
# 1. Generate original wav | ||
sox_utils.gen_audio_file( | ||
src_path, sample_rate, num_channels, | ||
bit_depth=32, | ||
encoding='floating-point', | ||
duration=duration, | ||
) | ||
# 2.1. Convert the original wav to mp3 with torchaudio | ||
sox_io_backend.save( | ||
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate) | ||
# 2.2. Convert the mp3 to wav with Sox | ||
sox_utils.convert_audio_file(mp3_path, wav_path) | ||
# 2.3. Load | ||
found = load_wav(wav_path)[0] | ||
|
||
# 3.1. Convert the original wav to mp3 with SoX | ||
sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate) | ||
# 3.2. Convert the mp3 to wav with Sox | ||
sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox) | ||
# 3.3. Load | ||
expected = load_wav(wav_path_sox)[0] | ||
|
||
self.assertEqual(found, expected) | ||
|
||
def assert_flac(self, sample_rate, num_channels, compression_level, duration): | ||
"""`sox_io_backend.save` can save flac format. | ||
This test takes the same strategy as mp3 to compare the result | ||
""" | ||
src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav') | ||
flac_path = f'{src_path}.flac' | ||
wav_path = f'{flac_path}.wav' | ||
flac_path_sox = f'{src_path}.sox.flac' | ||
wav_path_sox = f'{flac_path_sox}.wav' | ||
|
||
# 1. Generate original wav | ||
sox_utils.gen_audio_file( | ||
src_path, sample_rate, num_channels, | ||
bit_depth=32, | ||
encoding='floating-point', | ||
duration=duration, | ||
) | ||
# 2.1. Convert the original wav to flac with torchaudio | ||
sox_io_backend.save( | ||
flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level) | ||
# 2.2. Convert the flac to wav with Sox | ||
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. | ||
sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32) | ||
# 2.3. Load | ||
found = load_wav(wav_path)[0] | ||
|
||
# 3.1. Convert the original wav to flac with SoX | ||
sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level) | ||
# 3.2. Convert the flac to wav with Sox | ||
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. | ||
sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32) | ||
# 3.3. Load | ||
expected = load_wav(wav_path_sox)[0] | ||
|
||
self.assertEqual(found, expected) | ||
|
||
def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration): | ||
"""`sox_io_backend.save` can save vorbis format. | ||
This test takes the same strategy as mp3 to compare the result | ||
""" | ||
src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav') | ||
vorbis_path = f'{src_path}.vorbis' | ||
wav_path = f'{vorbis_path}.wav' | ||
vorbis_path_sox = f'{src_path}.sox.vorbis' | ||
wav_path_sox = f'{vorbis_path_sox}.wav' | ||
|
||
# 1. Generate original wav | ||
sox_utils.gen_audio_file( | ||
src_path, sample_rate, num_channels, | ||
bit_depth=16, | ||
encoding='signed-integer', | ||
duration=duration, | ||
) | ||
# 2.1. Convert the original wav to vorbis with torchaudio | ||
sox_io_backend.save( | ||
vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level) | ||
# 2.2. Convert the vorbis to wav with Sox | ||
sox_utils.convert_audio_file(vorbis_path, wav_path) | ||
# 2.3. Load | ||
found = load_wav(wav_path)[0] | ||
|
||
# 3.1. Convert the original wav to vorbis with SoX | ||
sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level) | ||
# 3.2. Convert the vorbis to wav with Sox | ||
sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox) | ||
# 3.3. Load | ||
expected = load_wav(wav_path_sox)[0] | ||
|
||
# sox's vorbis encoding has some randomness, which cause small number of samples yields | ||
# higher descrepency than the others. | ||
# so we allow small portions of data to be outside of absolute torelance. | ||
atol = 1.0e-4 | ||
max_failure_allowed = 0.05 # this percent of samples are allowed to outside of atol. | ||
failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel() | ||
if failure_ratio > max_failure_allowed: | ||
# it's failed and this will give a better error message. | ||
self.assertEqual(found, expected, atol=atol, rtol=1.3e-6) | ||
|
||
def assert_vorbis(self, *args, **kwargs): | ||
# sox's vorbis encoding has some randomness, so we run tests multiple time | ||
max_retry = 5 | ||
error = None | ||
for _ in range(max_retry): | ||
try: | ||
self._assert_vorbis(*args, **kwargs) | ||
break | ||
except AssertionError as e: | ||
error = e | ||
else: | ||
raise error | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestSave(SaveTestBase): | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
)), name_func=get_test_name) | ||
def test_wav(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.save` can save wav format.""" | ||
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32'], | ||
[16000], | ||
[2], | ||
)), name_func=get_test_name) | ||
def test_wav_large(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.save` can save large wav file.""" | ||
two_hours = 2 * 60 * 60 * sample_rate | ||
self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[4, 8, 16, 32], | ||
)), name_func=get_test_name) | ||
def test_multiple_channels(self, dtype, num_channels): | ||
"""`sox_io_backend.save` can save wav with more than 2 channels.""" | ||
sample_rate = 8000 | ||
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[96, 128, 160, 192, 224, 256, 320], | ||
)), name_func=get_test_name) | ||
def test_mp3(self, sample_rate, num_channels, bit_rate): | ||
"""`sox_io_backend.save` can save mp3 format.""" | ||
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[16000], | ||
[2], | ||
[128], | ||
)), name_func=get_test_name) | ||
def test_mp3_large(self, sample_rate, num_channels, bit_rate): | ||
"""`sox_io_backend.save` can save large mp3 file.""" | ||
two_hours = 2 * 60 * 60 | ||
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
list(range(9)), | ||
)), name_func=get_test_name) | ||
def test_flac(self, sample_rate, num_channels, compression_level): | ||
"""`sox_io_backend.save` can save flac format.""" | ||
self.assert_flac(sample_rate, num_channels, compression_level, duration=1) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[16000], | ||
[2], | ||
[0], | ||
)), name_func=get_test_name) | ||
def test_flac_large(self, sample_rate, num_channels, compression_level): | ||
"""`sox_io_backend.save` can save large flac file.""" | ||
two_hours = 2 * 60 * 60 | ||
self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[-1, 0, 1, 2, 3, 3.6, 5, 10], | ||
)), name_func=get_test_name) | ||
def test_vorbis(self, sample_rate, num_channels, quality_level): | ||
"""`sox_io_backend.save` can save vorbis format.""" | ||
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) | ||
|
||
# note: torchaudio can load large vorbis file, but cannot save large volbis file | ||
# the following test causes Segmentation fault | ||
# | ||
''' | ||
@parameterized.expand(list(itertools.product( | ||
[16000], | ||
[2], | ||
[10], | ||
)), name_func=get_test_name) | ||
def test_vorbis_large(self, sample_rate, num_channels, quality_level): | ||
"""`sox_io_backend.save` can save large vorbis file correctly.""" | ||
two_hours = 2 * 60 * 60 | ||
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) | ||
''' | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestSaveParams(TempDirMixin, PytorchTestCase): | ||
"""Test the correctness of optional parameters of `sox_io_backend.save`""" | ||
@parameterized.expand([(True, ), (False, )], name_func=get_test_name) | ||
def test_channels_first(self, channels_first): | ||
"""channels_first swaps axes""" | ||
path = self.get_temp_path('test_channel_first_{channels_first}.wav') | ||
data = get_wav_data('int32', 2, channels_first=channels_first) | ||
sox_io_backend.save( | ||
path, data, 8000, channels_first=channels_first) | ||
found = load_wav(path)[0] | ||
expected = data if channels_first else data.transpose(1, 0) | ||
self.assertEqual(found, expected) | ||
|
||
@parameterized.expand([ | ||
'float32', 'int32', 'int16', 'uint8' | ||
], name_func=get_test_name) | ||
def test_noncontiguous(self, dtype): | ||
"""Noncontiguous tensors are saved correctly""" | ||
path = self.get_temp_path('test_uncontiguous_{dtype}.wav') | ||
expected = get_wav_data(dtype, 4)[::2, ::2] | ||
assert not expected.is_contiguous() | ||
sox_io_backend.save(path, expected, 8000) | ||
found = load_wav(path)[0] | ||
self.assertEqual(found, expected) | ||
|
||
@parameterized.expand([ | ||
'float32', 'int32', 'int16', 'uint8', | ||
]) | ||
def test_tensor_preserve(self, dtype): | ||
"""save function should not alter Tensor""" | ||
path = self.get_temp_path(f'test_preserve_{dtype}.wav') | ||
expected = get_wav_data(dtype, 4)[::2, ::2] | ||
|
||
data = expected.clone() | ||
sox_io_backend.save(path, data, 8000) | ||
|
||
self.assertEqual(data, expected) |
Oops, something went wrong.