-
Notifications
You must be signed in to change notification settings - Fork 670
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
591 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from .. import common_utils | ||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
load_wav, | ||
save_wav, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers for `sox_io_backend.load` tests.

    Every helper generates an audio file with `sox_utils.gen_audio_file`
    at a path obtained from `self.get_temp_path`, loads it with
    `sox_io_backend.load` and compares the result with a reference
    obtained through `load_wav`.
    """

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`sox_io_backend.load` can load wav format correctly.

        Wav data loaded with sox_io backend should match those with scipy

        Args:
            dtype: Sample type name; mapped to a bit depth with
                `sox_utils.get_bit_depth`.
            sample_rate: Sample rate of the generated file, also the value
                `load` is expected to report.
            num_channels: Number of channels of the generated file.
            normalize: Forwarded to both `load_wav` and `sox_io_backend.load`.
            duration: Forwarded to `sox_utils.gen_audio_file`.
        """
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
        expected = load_wav(path, normalize=normalize)[0]
        data, sr = sox_io_backend.load(path, normalize=normalize)
        # Use the TestCase assertion rather than a bare `assert`, which is
        # removed when Python runs with -O.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, expected)

    def _assert_matches_converted_wav(
            self, path, sample_rate, num_channels, duration, atol, rtol,
            **gen_kwargs):
        """Generate `path` with Sox, convert it to wav, and compare both loads.

        The file at `path` is loaded with `sox_io_backend.load`; the wav
        converted from it (written to ``f'{path}.wav'``) is loaded with
        `load_wav`.  The two results must agree within `atol` / `rtol`.

        Args:
            path: Destination of the generated (encoded) audio file.
            sample_rate: Sample rate to generate and to expect from `load`.
            num_channels: Number of channels to generate.
            duration: Forwarded to `sox_utils.gen_audio_file`.
            atol: Absolute tolerance for the comparison.
            rtol: Relative tolerance for the comparison.
            **gen_kwargs: Extra keyword arguments forwarded unchanged to
                `sox_utils.gen_audio_file` (e.g. `compression`, `bit_depth`).
        """
        ref_path = f'{path}.wav'

        sox_utils.gen_audio_file(
            path, sample_rate, num_channels, duration=duration, **gen_kwargs)
        sox_utils.convert_audio_file(path, ref_path)

        data, sr = sox_io_backend.load(path)
        data_ref = load_wav(ref_path)[0]

        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, data_ref, atol=atol, rtol=rtol)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.load` can load mp3 format.

        mp3 encoding introduces delay and boundary effects so
        we create reference wav file from mp3

             x
             |
             |    Generate mp3 with Sox
             |
             v
            mp3 --- Convert to wav with Sox --> wav
             |                                   |
             | load with torchaudio              | load with scipy
             |                                   |
             v                                   v
           tensor --------> compare <--------- tensor

        Args:
            sample_rate: Sample rate of the generated file.
            num_channels: Number of channels of the generated file.
            bit_rate: Passed to `sox_utils.gen_audio_file` as `compression`.
            duration: Forwarded to `sox_utils.gen_audio_file`.
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
        self._assert_matches_converted_wav(
            path, sample_rate, num_channels, duration,
            atol=3e-03, rtol=1.3e-06,
            compression=bit_rate)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.load` can load flac format.

        This test takes the same strategy as mp3 to compare the result

        Args:
            sample_rate: Sample rate of the generated file.
            num_channels: Number of channels of the generated file.
            compression_level: Passed to `sox_utils.gen_audio_file` as
                `compression`.
            duration: Forwarded to `sox_utils.gen_audio_file`.
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
        self._assert_matches_converted_wav(
            path, sample_rate, num_channels, duration,
            atol=4e-05, rtol=1.3e-06,
            compression=compression_level, bit_depth=16)

    def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.load` can load vorbis format.

        This test takes the same strategy as mp3 to compare the result

        Args:
            sample_rate: Sample rate of the generated file.
            num_channels: Number of channels of the generated file.
            quality_level: Passed to `sox_utils.gen_audio_file` as
                `compression`.
            duration: Forwarded to `sox_utils.gen_audio_file`.
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
        self._assert_matches_converted_wav(
            path, sample_rate, num_channels, duration,
            atol=4e-05, rtol=1.3e-06,
            compression=quality_level, bit_depth=16)
|
||
|
||
@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestLoad(LoadTestBase):
    """Exercise `sox_io_backend.load` over several formats and parameters.

    Each test delegates to the matching `assert_*` helper inherited from
    `LoadTestBase`; the decorators only enumerate the parameter grid.
    """

    # NOTE(review): `testload_wav*` lack the underscore used by the other
    # `test_load_*` methods; the names are kept as they are so that the
    # generated test IDs do not change.
    @parameterized.expand(
        list(itertools.product(
            ['float32', 'int32', 'int16', 'uint8'],
            [8000, 16000],
            [1, 2],
            [False, True],
        )),
        name_func=get_test_name,
    )
    def testload_wav(self, dtype, sample_rate, num_channels, normalize):
        """Wav files of every listed dtype / rate / channel count load correctly."""
        self.assert_wav(
            dtype=dtype,
            sample_rate=sample_rate,
            num_channels=num_channels,
            normalize=normalize,
            duration=1,
        )

    @parameterized.expand(
        [('int16', 16000, 2, False)],
        name_func=get_test_name,
    )
    def testload_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """A large wav file loads correctly."""
        duration = 2 * 60 * 60  # two hours: 2 h * 60 min * 60 s
        self.assert_wav(
            dtype=dtype,
            sample_rate=sample_rate,
            num_channels=num_channels,
            normalize=normalize,
            duration=duration,
        )

    @parameterized.expand(
        list(itertools.product(
            ['float32', 'int32', 'int16', 'uint8'],
            [4, 8, 16, 32],
        )),
        name_func=get_test_name,
    )
    def test_load_multiple_channels(self, dtype, num_channels):
        """Wav files with more than two channels load correctly."""
        self.assert_wav(
            dtype=dtype,
            sample_rate=8000,
            num_channels=num_channels,
            normalize=False,
            duration=1,
        )

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000, 44100],
            [1, 2],
            [96, 128, 160, 192, 224, 256, 320],
        )),
        name_func=get_test_name,
    )
    def test_load_mp3(self, sample_rate, num_channels, bit_rate):
        """Mp3 files load correctly for every listed bit rate."""
        self.assert_mp3(
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_rate=bit_rate,
            duration=1,
        )

    @parameterized.expand(
        [(16000, 2, 128)],
        name_func=get_test_name,
    )
    def test_load_mp3_large(self, sample_rate, num_channels, bit_rate):
        """A large mp3 file loads correctly."""
        duration = 2 * 60 * 60  # two hours: 2 h * 60 min * 60 s
        self.assert_mp3(
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_rate=bit_rate,
            duration=duration,
        )

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],
            [1, 2],
            range(9),
        )),
        name_func=get_test_name,
    )
    def test_load_flac(self, sample_rate, num_channels, compression_level):
        """Flac files load correctly for compression levels 0 through 8."""
        self.assert_flac(
            sample_rate=sample_rate,
            num_channels=num_channels,
            compression_level=compression_level,
            duration=1,
        )

    @parameterized.expand(
        [(16000, 2, 0)],
        name_func=get_test_name,
    )
    def test_load_flac_large(self, sample_rate, num_channels, compression_level):
        """A large flac file loads correctly."""
        duration = 2 * 60 * 60  # two hours: 2 h * 60 min * 60 s
        self.assert_flac(
            sample_rate=sample_rate,
            num_channels=num_channels,
            compression_level=compression_level,
            duration=duration,
        )

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],
            [1, 2],
            [-1, 0, 1, 2, 3, 3.6, 5, 10],
        )),
        name_func=get_test_name,
    )
    def test_load_vorbis(self, sample_rate, num_channels, quality_level):
        """Vorbis files load correctly for every listed quality level."""
        self.assert_vorbis(
            sample_rate=sample_rate,
            num_channels=num_channels,
            quality_level=quality_level,
            duration=1,
        )

    @parameterized.expand(
        [(16000, 2, 10)],
        name_func=get_test_name,
    )
    def test_load_vorbis_large(self, sample_rate, num_channels, quality_level):
        """A large vorbis file loads correctly."""
        duration = 2 * 60 * 60  # two hours: 2 h * 60 min * 60 s
        self.assert_vorbis(
            sample_rate=sample_rate,
            num_channels=num_channels,
            quality_level=quality_level,
            duration=duration,
        )
|
||
|
||
@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Check the frame-range and `channel_first` arguments of `sox_io_backend.load`."""
    # Both attributes are filled in by setUp: the reference data and the
    # path of the file it was written to with `save_wav`.
    original = None
    path = None

    def setUp(self):
        """Write two-channel float32 reference data to a fresh file."""
        super().setUp()
        self.original = get_wav_data('float32', num_channels=2)
        self.path = self.get_temp_path('test.wave')
        save_wav(self.path, self.original, 8000)

    @parameterized.expand(
        list(itertools.product(
            [0, 1, 10, 100, 1000],
            [-1, 1, 10, 100, 1000],
        )),
        name_func=get_test_name,
    )
    def test_frame(self, frame_offset, num_frames):
        """`frame_offset` and `num_frames` select the expected slice of the data.

        A `num_frames` of -1 is compared against everything from
        `frame_offset` to the end of the reference data.
        """
        loaded, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
        if num_frames == -1:
            reference = self.original[:, frame_offset:]
        else:
            reference = self.original[:, frame_offset:frame_offset + num_frames]
        self.assertEqual(loaded, reference)

    @parameterized.expand([(True, ), (False, )], name_func=get_test_name)
    def test_load_channel_first(self, channel_first):
        """`channel_first=False` yields the reference data with its two axes swapped."""
        loaded, _ = sox_io_backend.load(self.path, channel_first=channel_first)
        reference = self.original
        if not channel_first:
            reference = reference.transpose(1, 0)
        self.assertEqual(loaded, reference)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.