From e160e000d96444f199f7642b04af983b7cdfd76b Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 23 Jun 2020 19:26:58 +0000 Subject: [PATCH 1/3] Add save function --- test/common_utils.py | 36 +-- test/sox_io_backend/sox_utils.py | 13 +- test/sox_io_backend/test_info.py | 10 +- test/sox_io_backend/test_load.py | 14 +- test/sox_io_backend/test_roundtrip.py | 52 ++++ test/sox_io_backend/test_save.py | 304 ++++++++++++++++++++++++ test/sox_io_backend/test_torchscript.py | 75 +++++- torchaudio/backend/sox_io_backend.py | 62 ++++- torchaudio/csrc/register.cpp | 8 + torchaudio/csrc/sox_io.cpp | 52 +++- torchaudio/csrc/sox_io.h | 5 + torchaudio/csrc/sox_utils.cpp | 136 ++++++++++- torchaudio/csrc/sox_utils.h | 27 ++- 13 files changed, 755 insertions(+), 39 deletions(-) create mode 100644 test/sox_io_backend/test_roundtrip.py create mode 100644 test/sox_io_backend/test_save.py diff --git a/test/common_utils.py b/test/common_utils.py index 9f9a888259..953b1b8790 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -92,28 +92,34 @@ def set_audio_backend(backend): class TempDirMixin: """Mixin to provide easy access to temp dir""" temp_dir_ = None + base_temp_dir = None temp_dir = None - def setUp(self): - super().setUp() - self._init_temp_dir() + @classmethod + def setUpClass(cls): + super().setUpClass() + # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. + # this is handy for debugging. 
+ key = 'TORCHAUDIO_TEST_TEMP_DIR' + if key in os.environ: + cls.base_temp_dir = os.environ[key] + else: + cls.temp_dir_ = tempfile.TemporaryDirectory() + cls.base_temp_dir = cls.temp_dir_.name - def tearDown(self): + @classmethod + def tearDownClass(cls): super().tearDownClass() - self._clean_up_temp_dir() + if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory): + cls.temp_dir_.cleanup() - def _init_temp_dir(self): - self.temp_dir_ = tempfile.TemporaryDirectory() - self.temp_dir = self.temp_dir_.name - - def _clean_up_temp_dir(self): - if self.temp_dir_ is not None: - self.temp_dir_.cleanup() - self.temp_dir_ = None - self.temp_dir = None + def setUp(self): + self.temp_dir = os.path.join(self.base_temp_dir, self.id()) def get_temp_path(self, *paths): - return os.path.join(self.temp_dir, *paths) + path = os.path.join(self.temp_dir, *paths) + os.makedirs(os.path.dirname(path), exist_ok=True) + return path class TestBaseMixin: diff --git a/test/sox_io_backend/sox_utils.py b/test/sox_io_backend/sox_utils.py index c30224158a..cd1c247b72 100644 --- a/test/sox_io_backend/sox_utils.py +++ b/test/sox_io_backend/sox_utils.py @@ -31,7 +31,16 @@ def gen_audio_file( 'Use get_wav_data and save_wav to generate wav file for accurate result.') command = [ 'sox', - '-V', # verbose + '-V3', # verbose + '-R', + # -R is supposed to be repeatable, though the implementation looks suspicious + # and not setting the seed to a fixed value. 
+ # https://fossies.org/dox/sox-14.4.2/sox_8c_source.html + # search "sox_globals.repeatable" + ] + if bit_depth is not None: + command += ['--bits', str(bit_depth)] + command += [ '--rate', str(sample_rate), '--null', # no input '--channels', str(num_channels), @@ -60,7 +69,7 @@ def convert_audio_file( src_path, dst_path, *, bit_depth=None, compression=None): """Convert audio file with `sox` command.""" - command = ['sox', '-V', str(src_path)] + command = ['sox', '-V3', '-R', str(src_path)] if bit_depth is not None: command += ['--bits', str(bit_depth)] if compression is not None: diff --git a/test/sox_io_backend/test_info.py b/test/sox_io_backend/test_info.py index 91c13278f6..27ce064e03 100644 --- a/test/sox_io_backend/test_info.py +++ b/test/sox_io_backend/test_info.py @@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): def test_wav(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` can check wav file correctly""" duration = 1 - path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') + path = self.get_temp_path('data.wav') data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) save_wav(path, data, sample_rate) info = sox_io_backend.info(path) @@ -44,7 +44,7 @@ def test_wav(self, dtype, sample_rate, num_channels): def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` can check wav file with channels more than 2 correctly""" duration = 1 - path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') + path = self.get_temp_path('data.wav') data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) save_wav(path, data, sample_rate) info = sox_io_backend.info(path) @@ -60,7 +60,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): def test_mp3(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.info` can check mp3 file correctly""" duration = 1 - path = 
self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3') + path = self.get_temp_path('data.mp3') sox_utils.gen_audio_file( path, sample_rate, num_channels, compression=bit_rate, duration=duration, @@ -79,7 +79,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate): def test_flac(self, sample_rate, num_channels, compression_level): """`sox_io_backend.info` can check flac file correctly""" duration = 1 - path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac') + path = self.get_temp_path('data.flac') sox_utils.gen_audio_file( path, sample_rate, num_channels, compression=compression_level, duration=duration, @@ -97,7 +97,7 @@ def test_flac(self, sample_rate, num_channels, compression_level): def test_vorbis(self, sample_rate, num_channels, quality_level): """`sox_io_backend.info` can check vorbis file correctly""" duration = 1 - path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis') + path = self.get_temp_path('data.vorbis') sox_utils.gen_audio_file( path, sample_rate, num_channels, compression=quality_level, duration=duration, diff --git a/test/sox_io_backend/test_load.py b/test/sox_io_backend/test_load.py index a04550a666..bf25f16ee8 100644 --- a/test/sox_io_backend/test_load.py +++ b/test/sox_io_backend/test_load.py @@ -24,7 +24,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): Wav data loaded with sox_io backend should match those with scipy """ - path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') + path = self.get_temp_path('reference.wav') data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) save_wav(path, data, sample_rate) expected = load_wav(path, normalize=normalize)[0] @@ -58,8 +58,8 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): By combining i & ii, step 2. and 4. 
allows to load reference mp3 data without using torchaudio """ - path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3') - ref_path = f'{path}.wav' + path = self.get_temp_path('1.original.mp3') + ref_path = self.get_temp_path('2.reference.wav') # 1. Generate mp3 with sox sox_utils.gen_audio_file( @@ -80,8 +80,8 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration): This test takes the same strategy as mp3 to compare the result """ - path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac') - ref_path = f'{path}.wav' + path = self.get_temp_path('1.original.flac') + ref_path = self.get_temp_path('2.reference.wav') # 1. Generate flac with sox sox_utils.gen_audio_file( @@ -102,8 +102,8 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration): This test takes the same strategy as mp3 to compare the result """ - path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis') - ref_path = f'{path}.wav' + path = self.get_temp_path('1.original.vorbis') + ref_path = self.get_temp_path('2.reference.wav') # 1. 
Generate vorbis with sox sox_utils.gen_audio_file( diff --git a/test/sox_io_backend/test_roundtrip.py b/test/sox_io_backend/test_roundtrip.py new file mode 100644 index 0000000000..0284ae6e57 --- /dev/null +++ b/test/sox_io_backend/test_roundtrip.py @@ -0,0 +1,52 @@ +import itertools + +from torchaudio.backend import sox_io_backend +from parameterized import parameterized + +from ..common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExec, + skipIfNoExtension, +) +from .common import ( + get_test_name, + get_wav_data, +) + + +@skipIfNoExec('sox') +@skipIfNoExtension +class TestRoundTripIO(TempDirMixin, PytorchTestCase): + """save/load round trip should not degrade data for lossless formats""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=get_test_name) + def test_wav(self, dtype, sample_rate, num_channels): + """save/load round trip should not degrade data for wav formats""" + original = get_wav_data(dtype, num_channels, normalize=False) + data = original + for i in range(10): + path = self.get_temp_path(f'{i}.wav') + sox_io_backend.save(path, data, sample_rate) + data, sr = sox_io_backend.load(path, normalize=False) + assert sr == sample_rate + self.assertEqual(original, data) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=get_test_name) + def test_flac(self, sample_rate, num_channels, compression_level): + """save/load round trip should not degrade data for flac formats""" + original = get_wav_data('float32', num_channels) + data = original + for i in range(10): + path = self.get_temp_path(f'{i}.flac') + sox_io_backend.save(path, data, sample_rate, compression=compression_level) + data, sr = sox_io_backend.load(path) + assert sr == sample_rate + self.assertEqual(original, data) diff --git a/test/sox_io_backend/test_save.py b/test/sox_io_backend/test_save.py new file mode 100644 index 
0000000000..ec4b992c8f --- /dev/null +++ b/test/sox_io_backend/test_save.py @@ -0,0 +1,304 @@ +import itertools + +from torchaudio.backend import sox_io_backend +from parameterized import parameterized + +from ..common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExec, + skipIfNoExtension, +) +from .common import ( + get_test_name, + get_wav_data, + load_wav, + save_wav, +) +from . import sox_utils + + +class SaveTestBase(TempDirMixin, PytorchTestCase): + def assert_wav(self, dtype, sample_rate, num_channels, num_frames): + """`sox_io_backend.save` can save wav format.""" + path = self.get_temp_path('data.wav') + expected = get_wav_data(dtype, num_channels, num_frames=num_frames) + sox_io_backend.save(path, expected, sample_rate) + found, sr = load_wav(path) + assert sample_rate == sr + self.assertEqual(found, expected) + + def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): + """`sox_io_backend.save` can save mp3 format. + + mp3 encoding introduces delay and boundary effects so + we convert the resulting mp3 to wav and compare the results there + + | + | 1. Generate original wav file with SciPy + | + v + -------------- wav ---------------- + | | + | 2.1. load with scipy | 3.1. Convert to mp3 with Sox + | then save with torchaudio | + v v + mp3 mp3 + | | + | 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox + | | + v v + wav wav + | | + | 2.3. load with scipy | 3.3. load with scipy + | | + v v + tensor -------> compare <--------- tensor + + """ + src_path = self.get_temp_path('1.reference.wav') + mp3_path = self.get_temp_path('2.1.torchaudio.mp3') + wav_path = self.get_temp_path('2.2.torchaudio.wav') + mp3_path_sox = self.get_temp_path('3.1.sox.mp3') + wav_path_sox = self.get_temp_path('3.2.sox.wav') + + # 1. Generate original wav + data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) + save_wav(src_path, data, sample_rate) + # 2.1. 
Convert the original wav to mp3 with torchaudio + sox_io_backend.save( + mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate) + # 2.2. Convert the mp3 to wav with Sox + sox_utils.convert_audio_file(mp3_path, wav_path) + # 2.3. Load + found = load_wav(wav_path)[0] + + # 3.1. Convert the original wav to mp3 with SoX + sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate) + # 3.2. Convert the mp3 to wav with Sox + sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox) + # 3.3. Load + expected = load_wav(wav_path_sox)[0] + + self.assertEqual(found, expected) + + def assert_flac(self, sample_rate, num_channels, compression_level, duration): + """`sox_io_backend.save` can save flac format. + + This test takes the same strategy as mp3 to compare the result + """ + src_path = self.get_temp_path('1.reference.wav') + flc_path = self.get_temp_path('2.1.torchaudio.flac') + wav_path = self.get_temp_path('2.2.torchaudio.wav') + flc_path_sox = self.get_temp_path('3.1.sox.flac') + wav_path_sox = self.get_temp_path('3.2.sox.wav') + + # 1. Generate original wav + data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) + save_wav(src_path, data, sample_rate) + # 2.1. Convert the original wav to flac with torchaudio + sox_io_backend.save( + flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level) + # 2.2. Convert the flac to wav with Sox + # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. + sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32) + # 2.3. Load + found = load_wav(wav_path)[0] + + # 3.1. Convert the original wav to flac with SoX + sox_utils.convert_audio_file(src_path, flc_path_sox, compression=compression_level) + # 3.2. Convert the flac to wav with Sox + # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. + sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32) + # 3.3. 
Load + expected = load_wav(wav_path_sox)[0] + + self.assertEqual(found, expected) + + def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration): + """`sox_io_backend.save` can save vorbis format. + + This test takes the same strategy as mp3 to compare the result + """ + src_path = self.get_temp_path('1.reference.wav') + vbs_path = self.get_temp_path('2.1.torchaudio.vorbis') + wav_path = self.get_temp_path('2.2.torchaudio.wav') + vbs_path_sox = self.get_temp_path('3.1.sox.vorbis') + wav_path_sox = self.get_temp_path('3.2.sox.wav') + + # 1. Generate original wav + data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(src_path, data, sample_rate) + # 2.1. Convert the original wav to vorbis with torchaudio + sox_io_backend.save( + vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level) + # 2.2. Convert the vorbis to wav with Sox + sox_utils.convert_audio_file(vbs_path, wav_path) + # 2.3. Load + found = load_wav(wav_path)[0] + + # 3.1. Convert the original wav to vorbis with SoX + sox_utils.convert_audio_file(src_path, vbs_path_sox, compression=quality_level) + # 3.2. Convert the vorbis to wav with Sox + sox_utils.convert_audio_file(vbs_path_sox, wav_path_sox) + # 3.3. Load + expected = load_wav(wav_path_sox)[0] + + # sox's vorbis encoding has some random boundary effect, which cause small number of + # samples yields higher descrepency than the others. + # so we allow small portions of data to be outside of absolute torelance. + # make sure to pass somewhat long duration + atol = 1.0e-4 + max_failure_allowed = 0.01 # this percent of samples are allowed to outside of atol. + failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel() + if failure_ratio > max_failure_allowed: + # it's failed and this will give a better error message. 
+ self.assertEqual(found, expected, atol=atol, rtol=1.3e-6) + + def assert_vorbis(self, *args, **kwargs): + # sox's vorbis encoding has some randomness, so we run tests multiple time + max_retry = 5 + error = None + for _ in range(max_retry): + try: + self._assert_vorbis(*args, **kwargs) + break + except AssertionError as e: + error = e + else: + raise error + + +@skipIfNoExec('sox') +@skipIfNoExtension +class TestSave(SaveTestBase): + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=get_test_name) + def test_wav(self, dtype, sample_rate, num_channels): + """`sox_io_backend.save` can save wav format.""" + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterized.expand(list(itertools.product( + ['float32'], + [16000], + [2], + )), name_func=get_test_name) + def test_wav_large(self, dtype, sample_rate, num_channels): + """`sox_io_backend.save` can save large wav file.""" + two_hours = 2 * 60 * 60 * sample_rate + self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [4, 8, 16, 32], + )), name_func=get_test_name) + def test_multiple_channels(self, dtype, num_channels): + """`sox_io_backend.save` can save wav with more than 2 channels.""" + sample_rate = 8000 + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [96, 128, 160, 192, 224, 256, 320], + )), name_func=get_test_name) + def test_mp3(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.save` can save mp3 format.""" + self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [128], + )), name_func=get_test_name) + def test_mp3_large(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.save` can save 
large mp3 file.""" + two_hours = 2 * 60 * 60 + self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=get_test_name) + def test_flac(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.save` can save flac format.""" + self.assert_flac(sample_rate, num_channels, compression_level, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [0], + )), name_func=get_test_name) + def test_flac_large(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.save` can save large flac file.""" + two_hours = 2 * 60 * 60 + self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [-1, 0, 1, 2, 3, 3.6, 5, 10], + )), name_func=get_test_name) + def test_vorbis(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.save` can save vorbis format.""" + self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20) + + # note: torchaudio can load large vorbis file, but cannot save large volbis file + # the following test causes Segmentation fault + # + ''' + @parameterized.expand(list(itertools.product( + [16000], + [2], + [10], + )), name_func=get_test_name) + def test_vorbis_large(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.save` can save large vorbis file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) + ''' + + +@skipIfNoExec('sox') +@skipIfNoExtension +class TestSaveParams(TempDirMixin, PytorchTestCase): + """Test the correctness of optional parameters of `sox_io_backend.save`""" + @parameterized.expand([(True, ), (False, )], name_func=get_test_name) + def test_channels_first(self, channels_first): + """channels_first swaps axes""" + path = self.get_temp_path('data.wav') + 
data = get_wav_data('int32', 2, channels_first=channels_first) + sox_io_backend.save( + path, data, 8000, channels_first=channels_first) + found = load_wav(path)[0] + expected = data if channels_first else data.transpose(1, 0) + self.assertEqual(found, expected) + + @parameterized.expand([ + 'float32', 'int32', 'int16', 'uint8' + ], name_func=get_test_name) + def test_noncontiguous(self, dtype): + """Noncontiguous tensors are saved correctly""" + path = self.get_temp_path('data.wav') + expected = get_wav_data(dtype, 4)[::2, ::2] + assert not expected.is_contiguous() + sox_io_backend.save(path, expected, 8000) + found = load_wav(path)[0] + self.assertEqual(found, expected) + + @parameterized.expand([ + 'float32', 'int32', 'int16', 'uint8', + ]) + def test_tensor_preserve(self, dtype): + """save function should not alter Tensor""" + path = self.get_temp_path('data.wav') + expected = get_wav_data(dtype, 4)[::2, ::2] + + data = expected.clone() + sox_io_backend.save(path, data, 8000) + + self.assertEqual(data, expected) diff --git a/test/sox_io_backend/test_torchscript.py b/test/sox_io_backend/test_torchscript.py index c6e9df41e1..dc9a0fb120 100644 --- a/test/sox_io_backend/test_torchscript.py +++ b/test/sox_io_backend/test_torchscript.py @@ -1,4 +1,5 @@ import itertools +from typing import Optional import torch from torchaudio.backend import sox_io_backend @@ -13,8 +14,10 @@ from .common import ( get_test_name, get_wav_data, - save_wav + save_wav, + load_wav, ) +from . 
import sox_utils def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: @@ -26,6 +29,16 @@ def py_load_func(filepath: str, normalize: bool, channels_first: bool): filepath, normalize=normalize, channels_first=channels_first) +def py_save_func( + filepath: str, + tensor: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, +): + sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression) + + @skipIfNoExec('sox') @skipIfNoExtension class SoxIO(TempDirMixin, TorchaudioTestCase): @@ -41,7 +54,7 @@ def test_info_wav(self, dtype, sample_rate, num_channels): data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) save_wav(audio_path, data, sample_rate) - script_path = self.get_temp_path('info_func') + script_path = self.get_temp_path('info_func.zip') torch.jit.script(py_info_func).save(script_path) ts_info_func = torch.jit.load(script_path) @@ -65,7 +78,7 @@ def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_fi data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) save_wav(audio_path, data, sample_rate) - script_path = self.get_temp_path('load_func') + script_path = self.get_temp_path('load_func.zip') torch.jit.script(py_load_func).save(script_path) ts_load_func = torch.jit.load(script_path) @@ -76,3 +89,59 @@ def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_fi self.assertEqual(py_sr, ts_sr) self.assertEqual(py_data, ts_data) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=get_test_name) + def test_save_wav(self, dtype, sample_rate, num_channels): + script_path = self.get_temp_path('save_func.zip') + torch.jit.script(py_save_func).save(script_path) + ts_save_func = torch.jit.load(script_path) + + expected = get_wav_data(dtype, num_channels) + py_path = 
self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav') + ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav') + + py_save_func(py_path, expected, sample_rate, True, None) + ts_save_func(ts_path, expected, sample_rate, True, None) + + py_data, py_sr = load_wav(py_path) + ts_data, ts_sr = load_wav(ts_path) + + self.assertEqual(sample_rate, py_sr) + self.assertEqual(sample_rate, ts_sr) + self.assertEqual(expected, py_data) + self.assertEqual(expected, ts_data) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=get_test_name) + def test_save_flac(self, sample_rate, num_channels, compression_level): + script_path = self.get_temp_path('save_func.zip') + torch.jit.script(py_save_func).save(script_path) + ts_save_func = torch.jit.load(script_path) + + expected = get_wav_data('float32', num_channels) + py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac') + ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac') + + py_save_func(py_path, expected, sample_rate, True, compression_level) + ts_save_func(ts_path, expected, sample_rate, True, compression_level) + + # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. 
+ py_path_wav = f'{py_path}.wav' + ts_path_wav = f'{ts_path}.wav' + sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32) + sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32) + + py_data, py_sr = load_wav(py_path_wav, normalize=True) + ts_data, ts_sr = load_wav(ts_path_wav, normalize=True) + + self.assertEqual(sample_rate, py_sr) + self.assertEqual(sample_rate, ts_sr) + self.assertEqual(expected, py_data) + self.assertEqual(expected, ts_data) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index a9bcffdd3a..f9d1b7ea68 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional import torch from torchaudio._internal import ( @@ -77,4 +77,64 @@ def load( return signal.get_tensor(), signal.get_sample_rate() +@_mod_utils.requires_module('torchaudio._torchaudio') +def save( + filepath: str, + tensor: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + frames_per_chunk: int = 65536, +): + """Save audio data to file. + + Supported formats are; + - WAV + - 32-bit floating-point + - 32-bit signed integer + - 16-bit signed integer + - 8-bit unsigned integer + - MP3 + - FLAC + - OGG/VORBIS + + Note: + Currently torchaudio's binary release does not include codecs library required to handle + OGG/VORBIS and OPUS. To use these formats, you need to build torchaudio from source. + Refer to README for this. + + Args: + filepath: Path to save file. + tensor: Audio data to save. must be 2D tensor. + sample_rate: sampling rate + channels_first: If True, the given tensor is interpreted as ``[channel, time]``. + frame_offset: Number of frames to skip before start reading data. + num_frames: Maximum number of frames to read. If there is not enough frames in + the given audio, this function does NOT raise an error. 
+ normalize: When True and input file is integer WAV, the resulting Tensor type + becomes ``float32`` and values are normalized to ``[-1.0, 1.0]``. + This argument has no effect for other formats. + channels_first: When True, the returned Tensor has dimension [channel, time]. + compression: Used for formats other than WAV. This corresponds to ``-C`` option + of ``sox`` command. + See the detail at http://sox.sourceforge.net/soxformat.html. + - MP3: bitrate [kbps]. + - FLAC: compression level. Whole number from 0 to 8. 8 is default and highest + compression. + - OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest + quality. Default value is 3. + """ + if compression is None: + compression = 0. + ext = str(filepath)[-3:].lower() + if ext == 'mp3': + compression = 128.2 + elif ext == 'flac': + compression = 8. + elif ext in ['ogg', 'vorbis']: + compression = 3. + signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first) + torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk) + + load_wav = load diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index 44f7826e5a..538c3f0ea0 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -46,6 +46,14 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op( decltype(sox_io::load_audio_file), &sox_io::load_audio_file>()); +static auto registerSaveAudioFile = torch::RegisterOperators().op( + torch::RegisterOperators::options() + .schema( + "torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression, int frames_per_chunk) -> ()") + .catchAllKernel< + decltype(sox_io::save_audio_file), + &sox_io::save_audio_file>()); + //////////////////////////////////////////////////////////////////////////////// // sox_effects.h //////////////////////////////////////////////////////////////////////////////// diff --git 
a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index 349e65c97d..1870dd8c9f 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -15,7 +15,7 @@ c10::intrusive_ptr get_info(const std::string& path) { /*encoding=*/nullptr, /*filetype=*/nullptr)); - if (sf.get() == nullptr) { + if (static_cast(sf) == nullptr) { throw std::runtime_error("Error opening audio file"); } @@ -52,7 +52,7 @@ c10::intrusive_ptr load_audio_file( const int64_t num_total_samples = sf->signal.length; const int64_t sample_start = sf->signal.channels * frame_offset; - if (sox_seek(sf.get(), sample_start, 0) == SOX_EOF) { + if (sox_seek(sf, sample_start, 0) == SOX_EOF) { throw std::runtime_error("Error reading audio file: offset past EOF."); } @@ -79,7 +79,7 @@ c10::intrusive_ptr load_audio_file( // Read samples into buffer std::vector buffer; buffer.reserve(max_samples); - const int64_t num_samples = sox_read(sf.get(), buffer.data(), max_samples); + const int64_t num_samples = sox_read(sf, buffer.data(), max_samples); if (num_samples == 0) { throw std::runtime_error( "Error reading audio file: empty file or read operation failed."); @@ -100,5 +100,51 @@ c10::intrusive_ptr load_audio_file( tensor, static_cast(sf->signal.rate), channels_first); } +void save_audio_file( + const std::string& file_name, + const c10::intrusive_ptr& signal, + const double compression, + const int64_t frames_per_chunk) { + const auto tensor = signal->getTensor(); + const auto sample_rate = signal->getSampleRate(); + const auto channels_first = signal->getChannelsFirst(); + + validate_input_tensor(tensor); + + const auto filetype = get_filetype(file_name); + const auto signal_info = + get_signalinfo(tensor, sample_rate, channels_first, filetype); + const auto encoding_info = + get_encodinginfo(filetype, tensor.dtype(), compression); + + SoxFormat sf(sox_open_write( + file_name.c_str(), + &signal_info, + &encoding_info, + /*filetype=*/filetype.c_str(), + /*oob=*/nullptr, + 
/*overwrite_permitted=*/nullptr)); + + if (static_cast(sf) == nullptr) { + throw std::runtime_error("Error saving audio file: failed to open file."); + } + + auto tensor_ = tensor; + if (channels_first) { + tensor_ = tensor_.t(); + } + + for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) { + auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()}); + chunk = unnormalize_wav(chunk).contiguous(); + + const size_t numel = chunk.numel(); + if (sox_write(sf, chunk.data_ptr(), numel) != numel) { + throw std::runtime_error( + "Error saving audio file: failed to write the entier buffer."); + } + } +} + } // namespace sox_io } // namespace torchaudio diff --git a/torchaudio/csrc/sox_io.h b/torchaudio/csrc/sox_io.h index 3751f22cf5..310687bb7d 100644 --- a/torchaudio/csrc/sox_io.h +++ b/torchaudio/csrc/sox_io.h @@ -17,6 +17,11 @@ c10::intrusive_ptr load_audio_file( const bool normalize = true, const bool channels_first = true); +void save_audio_file( + const std::string& file_name, + const c10::intrusive_ptr& signal, + const double compression = 0., + const int64_t frames_per_chunk = 65536); } // namespace sox_io } // namespace torchaudio diff --git a/torchaudio/csrc/sox_utils.cpp b/torchaudio/csrc/sox_utils.cpp index 4a7b3d8014..c1fd8383a8 100644 --- a/torchaudio/csrc/sox_utils.cpp +++ b/torchaudio/csrc/sox_utils.cpp @@ -32,12 +32,12 @@ SoxFormat::~SoxFormat() { sox_format_t* SoxFormat::operator->() const noexcept { return fd_; } -sox_format_t* SoxFormat::get() const noexcept { +SoxFormat::operator sox_format_t*() const noexcept { return fd_; } void validate_input_file(const SoxFormat& sf) { - if (sf.get() == nullptr) { + if (static_cast(sf) == nullptr) { throw std::runtime_error("Error loading audio file: failed to open file."); } if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { @@ -48,6 +48,23 @@ void validate_input_file(const SoxFormat& sf) { } } +void validate_input_tensor(const torch::Tensor tensor) { + if (!tensor.device().is_cpu()) { + 
throw std::runtime_error("Input tensor has to be on CPU."); + } + + if (tensor.ndimension() != 2) { + throw std::runtime_error("Input tensor has to be 2D."); + } + + const auto dtype = tensor.dtype(); + if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 || + dtype == torch::kInt16 || dtype == torch::kUInt8)) { + throw std::runtime_error( + "Input tensor has to be one of float32, int32, int16 or uint8 type."); + } +} + caffe2::TypeMeta get_dtype( const sox_encoding_t encoding, const unsigned precision) { @@ -109,5 +126,120 @@ torch::Tensor convert_to_tensor( return t.contiguous(); } +torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) { + const auto dtype = input_tensor.dtype(); + auto tensor = input_tensor; + if (dtype == torch::kFloat32) { + double multi_pos = 2147483647.; + double multi_neg = -2147483648.; + auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg; + tensor = tensor.to(torch::dtype(torch::kFloat64)); + tensor *= mult; + tensor.clamp_(multi_neg, multi_pos); + tensor = tensor.to(torch::dtype(torch::kInt32)); + } else if (dtype == torch::kInt32) { + // already denormalized + } else if (dtype == torch::kInt16) { + tensor = tensor.to(torch::dtype(torch::kInt32)); + tensor *= ((tensor != 0) * 65536); + } else if (dtype == torch::kUInt8) { + tensor = tensor.to(torch::dtype(torch::kInt32)); + tensor -= 128; + tensor *= 16777216; + } else { + throw std::runtime_error("Unexpected dtype."); + } + return tensor; +} + +const std::string get_filetype(const std::string path) { + std::string ext = path.substr(path.find_last_of(".") + 1); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + return ext; +} + +sox_encoding_t get_encoding( + const std::string filetype, + const caffe2::TypeMeta dtype) { + if (filetype == "mp3") + return SOX_ENCODING_MP3; + if (filetype == "flac") + return SOX_ENCODING_FLAC; + if (filetype == "ogg" || filetype == "vorbis") + return SOX_ENCODING_VORBIS; + if (filetype == "wav") { + if (dtype == 
torch::kUInt8) + return SOX_ENCODING_UNSIGNED; + if (dtype == torch::kInt16) + return SOX_ENCODING_SIGN2; + if (dtype == torch::kInt32) + return SOX_ENCODING_SIGN2; + if (dtype == torch::kFloat32) + return SOX_ENCODING_FLOAT; + throw std::runtime_error("Unsupported dtype."); + } + throw std::runtime_error("Unsupported file type."); +} + +unsigned get_precision( + const std::string filetype, + const caffe2::TypeMeta dtype) { + if (filetype == "mp3") + return SOX_UNSPEC; + if (filetype == "flac") + return 24; + if (filetype == "ogg" || filetype == "vorbis") + return SOX_UNSPEC; + if (filetype == "wav") { + if (dtype == torch::kUInt8) + return 8; + if (dtype == torch::kInt16) + return 16; + if (dtype == torch::kInt32) + return 32; + if (dtype == torch::kFloat32) + return 32; + throw std::runtime_error("Unsupported dtype."); + } + throw std::runtime_error("Unsupported file type."); +} + +sox_signalinfo_t get_signalinfo( + const torch::Tensor& tensor, + const int64_t sample_rate, + const bool channels_first, + const std::string filetype) { + return sox_signalinfo_t{ + /*rate=*/static_cast(sample_rate), + /*channels=*/static_cast(tensor.size(channels_first ? 
0 : 1)), + /*precision=*/get_precision(filetype, tensor.dtype()), + /*length=*/static_cast(tensor.numel())}; +} + +sox_encodinginfo_t get_encodinginfo( + const std::string filetype, + const caffe2::TypeMeta dtype, + const double compression) { + const double compression_ = [&]() { + if (filetype == "mp3") + return compression; + if (filetype == "flac") + return compression; + if (filetype == "ogg" || filetype == "vorbis") + return compression; + if (filetype == "wav") + return 0.; + throw std::runtime_error("Unsupported file type."); + }(); + + return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), + /*bits_per_sample=*/get_precision(filetype, dtype), + /*compression=*/compression_, + /*reverse_bytes=*/sox_option_default, + /*reverse_nibbles=*/sox_option_default, + /*reverse_bits=*/sox_option_default, + /*opposite_endian=*/sox_false}; +} + } // namespace sox_utils } // namespace torchaudio diff --git a/torchaudio/csrc/sox_utils.h b/torchaudio/csrc/sox_utils.h index cc61d67c77..665187c840 100644 --- a/torchaudio/csrc/sox_utils.h +++ b/torchaudio/csrc/sox_utils.h @@ -31,7 +31,7 @@ struct SoxFormat { SoxFormat& operator=(SoxFormat&& other) = delete; ~SoxFormat(); sox_format_t* operator->() const noexcept; - sox_format_t* get() const noexcept; + operator sox_format_t*() const noexcept; private: sox_format_t* fd_; @@ -41,6 +41,10 @@ struct SoxFormat { /// Verify that input file is found, has known encoding, and not empty void validate_input_file(const SoxFormat& sf); +/// +/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32 +void validate_input_tensor(const torch::Tensor); + /// /// Get target dtype for the given encoding and precision. caffe2::TypeMeta get_dtype( @@ -70,6 +74,27 @@ torch::Tensor convert_to_tensor( const bool normalize, const bool channels_first); +/// +/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox +/// conversion. 
+torch::Tensor unnormalize_wav(const torch::Tensor); + +/// Extract extension from file path +const std::string get_filetype(const std::string path); + +/// Get sox_signalinfo_t for passing a torch::Tensor object. +sox_signalinfo_t get_signalinfo( + const torch::Tensor& tensor, + const int64_t sample_rate, + const bool channels_first, + const std::string filetype); + +/// Get sox_encodinginfo_t for saving audio file +sox_encodinginfo_t get_encodinginfo( + const std::string filetype, + const caffe2::TypeMeta dtype, + const double compression); + } // namespace sox_utils } // namespace torchaudio #endif From 1fb15b376a4bb02315cd6cb7317716eadc1e9800 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Mon, 29 Jun 2020 19:23:58 +0000 Subject: [PATCH 2/3] Fix docstring --- test/sox_io_backend/test_save.py | 2 +- torchaudio/backend/sox_io_backend.py | 32 ++++++++-------------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/test/sox_io_backend/test_save.py b/test/sox_io_backend/test_save.py index ec4b992c8f..ac3395fb52 100644 --- a/test/sox_io_backend/test_save.py +++ b/test/sox_io_backend/test_save.py @@ -203,7 +203,7 @@ def test_multiple_channels(self, dtype, num_channels): @parameterized.expand(list(itertools.product( [8000, 16000], [1, 2], - [96, 128, 160, 192, 224, 256, 320], + [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], )), name_func=get_test_name) def test_mp3(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.save` can save mp3 format.""" diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index f9d1b7ea68..6517eb9fd6 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -25,12 +25,7 @@ def load( This function can handle all the codecs that underlying libsox can handle, however note the followings. - Note 1: - Current torchaudio's binary release only contains codecs for MP3, FLAC and OGG/VORBIS. 
- If you need other formats, you need to build torchaudio from source with libsox and - the corresponding codecs. Refer to README for this. - - Note 2: + Note: This function is tested on the following formats; - WAV - 32-bit floating-point @@ -98,37 +93,28 @@ def save( - FLAC - OGG/VORBIS - Note: - Currently torchaudio's binary release does not include codecs library required to handle - OGG/VORBIS and OPUS. To use these formats, you need to build torchaudio from source. - Refer to README for this. - Args: filepath: Path to save file. tensor: Audio data to save. must be 2D tensor. sample_rate: sampling rate channels_first: If True, the given tensor is interpreted as ``[channel, time]``. - frame_offset: Number of frames to skip before start reading data. - num_frames: Maximum number of frames to read. If there is not enough frames in - the given audio, this function does NOT raise an error. - normalize: When True and input file is integer WAV, the resulting Tensor type - becomes ``float32`` and values are normalized to ``[-1.0, 1.0]``. - This argument has no effect for other formats. - channels_first: When True, the returned Tensor has dimension [channel, time]. compression: Used for formats other than WAV. This corresponds to ``-C`` option of ``sox`` command. See the detail at http://sox.sourceforge.net/soxformat.html. - - MP3: bitrate [kbps]. - - FLAC: compression level. Whole number from 0 to 8. 8 is default and highest - compression. + - MP3: Either bitrate [kbps] with quality factor, such as ``128.2`` or + VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5`` + - FLAC: compression level. Whole number from ``0`` to ``8``. + ``8`` is default and highest compression. - OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest - quality. Default value is 3. + quality. Default: ``3``. + frames_per_chunk: The number of frames to process (convert to ``int32`` internally + then write to file) at a time. 
""" if compression is None: compression = 0. ext = str(filepath)[-3:].lower() if ext == 'mp3': - compression = 128.2 + compression = -4.5 elif ext == 'flac': compression = 8. elif ext in ['ogg', 'vorbis']: From 0130ca779dd4d27c85199973f05eaab457a4859f Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Mon, 29 Jun 2020 21:06:46 +0000 Subject: [PATCH 3/3] Raise error for unsupported file type --- torchaudio/backend/sox_io_backend.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 6517eb9fd6..2fef97917f 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -111,14 +111,17 @@ def save( then write to file) at a time. """ if compression is None: - compression = 0. ext = str(filepath)[-3:].lower() - if ext == 'mp3': + if ext == 'wav': + compression = 0. + elif ext == 'mp3': compression = -4.5 elif ext == 'flac': compression = 8. elif ext in ['ogg', 'vorbis']: compression = 3. + else: + raise RuntimeError(f'Unsupported file type: "{ext}"') signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first) torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk)