-
Notifications
You must be signed in to change notification settings - Fork 672
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
662 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestRoundTripIO(TempDirMixin, PytorchTestCase): | ||
"""save/load round trip should not degrade data for lossless formats""" | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
[False, True] | ||
)), name_func=get_test_name) | ||
def test_roundtrip_wav(self, dtype, sample_rate, num_channels, normalize): | ||
"""save/load round trip should not degrade data for wav formats""" | ||
original = get_wav_data(dtype, num_channels, normalize=normalize) | ||
data = original | ||
for i in range(10): | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}_{i}.wav') | ||
sox_io_backend.save(path, data, sample_rate) | ||
data, sr = sox_io_backend.load(path, normalize=normalize) | ||
assert sr == sample_rate | ||
self.assertEqual(original, data) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
list(range(9)), | ||
)), name_func=get_test_name) | ||
def test_roundtrip_flac(self, sample_rate, num_channels, compression_level): | ||
"""save/load round trip should not degrade data for flac formats""" | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac') | ||
sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=compression_level) | ||
original = sox_io_backend.load(path)[0] | ||
|
||
data = original | ||
for i in range(10): | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{i}.flac') | ||
sox_io_backend.save(path, data, sample_rate) | ||
data, sr = sox_io_backend.load(path) | ||
assert sr == sample_rate | ||
self.assertEqual(original, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
load_wav, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
class SaveTestBase(TempDirMixin, PytorchTestCase): | ||
def assert_wav(self, dtype, sample_rate, num_channels, num_frames): | ||
"""`sox_io_backend.save` can save wav format.""" | ||
path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}.wav') | ||
expected = get_wav_data(dtype, num_channels, num_frames=num_frames) | ||
sox_io_backend.save(path, expected, sample_rate) | ||
found = load_wav(path)[0] | ||
self.assertEqual(found, expected) | ||
|
||
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): | ||
"""`sox_io_backend.save` can save mp3 format. | ||
mp3 encoding introduces delay and boundary effects so | ||
we convert the resulting mp3 to wav and compare the results there | ||
| | ||
| 1. Generate original wav with Sox | ||
| | ||
v | ||
-------------- wav ---------------- | ||
| | | ||
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox | ||
| then save with torchaudio | | ||
v v | ||
mp3 mp3 | ||
| | | ||
| 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox | ||
| | | ||
v v | ||
wav wav | ||
| | | ||
| 2.3. load with scipy | 3.3. load with scipy | ||
| | | ||
v v | ||
tensor -------> compare <--------- tensor | ||
""" | ||
src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav') | ||
mp3_path = f'{src_path}.mp3' | ||
wav_path = f'{mp3_path}.wav' | ||
mp3_path_sox = f'{src_path}.sox.mp3' | ||
wav_path_sox = f'{mp3_path_sox}.wav' | ||
|
||
# 1. Generate original wav | ||
sox_utils.gen_audio_file( | ||
src_path, sample_rate, num_channels, | ||
bit_depth=32, | ||
encoding='floating-point', | ||
duration=duration, | ||
) | ||
# 2.1. Convert the original wav to mp3 with torchaudio | ||
sox_io_backend.save( | ||
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate) | ||
# 2.2. Convert the mp3 to wav with Sox | ||
sox_utils.convert_audio_file(mp3_path, wav_path) | ||
# 2.3. Load | ||
found = load_wav(wav_path)[0] | ||
|
||
# 3.1. Convert the original wav to mp3 with SoX | ||
sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate) | ||
# 3.2. Convert the mp3 to wav with Sox | ||
sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox) | ||
# 3.3. Load | ||
expected = load_wav(wav_path_sox)[0] | ||
|
||
self.assertEqual(found, expected) | ||
|
||
def assert_flac(self, sample_rate, num_channels, compression_level, duration): | ||
"""`sox_io_backend.save` can save flac format. | ||
This test takes the same strategy as mp3 to compare the result | ||
""" | ||
src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav') | ||
flac_path = f'{src_path}.flac' | ||
wav_path = f'{flac_path}.wav' | ||
flac_path_sox = f'{src_path}.sox.flac' | ||
wav_path_sox = f'{flac_path_sox}.wav' | ||
|
||
# 1. Generate original wav | ||
sox_utils.gen_audio_file( | ||
src_path, sample_rate, num_channels, | ||
bit_depth=32, | ||
encoding='floating-point', | ||
duration=duration, | ||
) | ||
# 2.1. Convert the original wav to flac with torchaudio | ||
sox_io_backend.save( | ||
flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level) | ||
# 2.2. Convert the flac to wav with Sox | ||
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. | ||
sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32) | ||
# 2.3. Load | ||
found = load_wav(wav_path)[0] | ||
|
||
# 3.1. Convert the original wav to flac with SoX | ||
sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level) | ||
# 3.2. Convert the flac to wav with Sox | ||
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. | ||
sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32) | ||
# 3.3. Load | ||
expected = load_wav(wav_path_sox)[0] | ||
|
||
self.assertEqual(found, expected) | ||
|
||
def assert_vorbis(self, sample_rate, num_channels, quality_level, duration): | ||
"""`sox_io_backend.save` can save vorbis format. | ||
This test takes the same strategy as mp3 to compare the result | ||
""" | ||
src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav') | ||
vorbis_path = f'{src_path}.vorbis' | ||
wav_path = f'{vorbis_path}.wav' | ||
vorbis_path_sox = f'{src_path}.sox.vorbis' | ||
wav_path_sox = f'{vorbis_path_sox}.wav' | ||
|
||
# 1. Generate original wav | ||
sox_utils.gen_audio_file( | ||
src_path, sample_rate, num_channels, | ||
bit_depth=16, | ||
encoding='signed-integer', | ||
duration=duration, | ||
) | ||
# 2.1. Convert the original wav to vorbis with torchaudio | ||
sox_io_backend.save( | ||
vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level) | ||
# 2.2. Convert the vorbis to wav with Sox | ||
sox_utils.convert_audio_file(vorbis_path, wav_path) | ||
# 2.3. Load | ||
found = load_wav(wav_path)[0] | ||
|
||
# 3.1. Convert the original wav to vorbis with SoX | ||
sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level) | ||
# 3.2. Convert the vorbis to wav with Sox | ||
sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox) | ||
# 3.3. Load | ||
expected = load_wav(wav_path_sox)[0] | ||
|
||
# vorbis encoding seems to have some randomness, | ||
# so setting absolute torelance a bit higher | ||
self.assertEqual(found, expected, rtol=1.3e-06, atol=1e-03) | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestSave(SaveTestBase): | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
)), name_func=get_test_name) | ||
def test_wav(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.save` can save wav format.""" | ||
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32'], | ||
[16000], | ||
[2], | ||
)), name_func=get_test_name) | ||
def test_wav_large(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.save` can save large wav file.""" | ||
two_hours = 2 * 60 * 60 * sample_rate | ||
self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[4, 8, 16, 32], | ||
)), name_func=get_test_name) | ||
def test_multiple_channels(self, dtype, num_channels): | ||
"""`sox_io_backend.save` can save wav with more than 2 channels.""" | ||
sample_rate = 8000 | ||
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[96, 128, 160, 192, 224, 256, 320], | ||
)), name_func=get_test_name) | ||
def test_mp3(self, sample_rate, num_channels, bit_rate): | ||
"""`sox_io_backend.save` can save mp3 format.""" | ||
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[16000], | ||
[2], | ||
[128], | ||
)), name_func=get_test_name) | ||
def test_mp3_large(self, sample_rate, num_channels, bit_rate): | ||
"""`sox_io_backend.save` can save large mp3 file.""" | ||
two_hours = 2 * 60 * 60 | ||
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
list(range(9)), | ||
)), name_func=get_test_name) | ||
def test_flac(self, sample_rate, num_channels, compression_level): | ||
"""`sox_io_backend.save` can save flac format.""" | ||
self.assert_flac(sample_rate, num_channels, compression_level, duration=1) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[16000], | ||
[2], | ||
[0], | ||
)), name_func=get_test_name) | ||
def test_flac_large(self, sample_rate, num_channels, compression_level): | ||
"""`sox_io_backend.save` can save large flac file.""" | ||
two_hours = 2 * 60 * 60 | ||
self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[-1, 0, 1, 2, 3, 3.6, 5, 10], | ||
)), name_func=get_test_name) | ||
def test_vorbis_mono(self, sample_rate, quality_level): | ||
"""`sox_io_backend.save` can save vorbis format.""" | ||
num_channels = 1 | ||
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000], | ||
[-1, 0, 1, 2, 3, 3.6, 5, 10], | ||
)), name_func=get_test_name) | ||
def test_vorbis_stereo(self, sample_rate, quality_level): | ||
"""`sox_io_backend.save` can save vorbis format.""" | ||
# note: though sample_rate16000 mostly works fine, it is omitted because | ||
# it gives slighly higher descrepency for few samples. | ||
# such as 762 out of 32000 not within atol=0.0001. | ||
num_channels = 2 | ||
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) | ||
|
||
# note: torchaudio can load large vorbis file, but cannot save large volbis file | ||
# the following test causes Segmentation fault | ||
# | ||
''' | ||
@parameterized.expand(list(itertools.product( | ||
[16000], | ||
[2], | ||
[10], | ||
)), name_func=get_test_name) | ||
def test_vorbis_large(self, sample_rate, num_channels, quality_level): | ||
"""`sox_io_backend.save` can save large vorbis file correctly.""" | ||
two_hours = 2 * 60 * 60 | ||
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) | ||
''' | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestSaveParams(TempDirMixin, PytorchTestCase): | ||
"""Test the correctness of optional parameters of `sox_io_backend.save`""" | ||
@parameterized.expand([(True, ), (False, )], name_func=get_test_name) | ||
def test_channels_first(self, channels_first): | ||
"""channels_first swaps axes""" | ||
path = self.get_temp_path('test_channel_first_{channels_first}.wav') | ||
data = get_wav_data('int32', 2, channels_first=channels_first) | ||
sox_io_backend.save( | ||
path, data, 8000, channels_first=channels_first) | ||
found = load_wav(path)[0] | ||
expected = data if channels_first else data.transpose(1, 0) | ||
self.assertEqual(found, expected) | ||
|
||
@parameterized.expand([ | ||
'float32', 'int32', 'int16', 'uint8' | ||
], name_func=get_test_name) | ||
def test_noncontiguous(self, dtype): | ||
"""Noncontiguous tensors are saved correctly""" | ||
path = self.get_temp_path('test_uncontiguous_{dtype}.wav') | ||
expected = get_wav_data(dtype, 4)[::2, ::2] | ||
assert not expected.is_contiguous() | ||
sox_io_backend.save(path, expected, 8000) | ||
found = load_wav(path)[0] | ||
self.assertEqual(found, expected) | ||
|
||
@parameterized.expand([ | ||
'float32', 'int32', 'int16', 'uint8', | ||
]) | ||
def test_tensor_preserve(self, dtype): | ||
"""save function should not alter Tensor""" | ||
path = self.get_temp_path(f'test_preserve_{dtype}.wav') | ||
expected = get_wav_data(dtype, 4)[::2, ::2] | ||
|
||
data = expected.clone() | ||
sox_io_backend.save(path, data, 8000) | ||
|
||
self.assertEqual(data, expected) |
Oops, something went wrong.