From f85d96969df9c7ff9222f4df259cda99966d1d6e Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 18 Jun 2020 17:27:01 +0000 Subject: [PATCH] Add sox_io_backend --- test/sox_io/test_torchscript.py | 10 +++--- test/test_backend.py | 8 +++-- test/test_io.py | 18 +++++++++++ torchaudio/backend/utils.py | 6 ++++ torchaudio/csrc/sox_io.cpp | 56 +++++++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 6 deletions(-) diff --git a/test/sox_io/test_torchscript.py b/test/sox_io/test_torchscript.py index 601a496b2d0..a0cabcb5c89 100644 --- a/test/sox_io/test_torchscript.py +++ b/test/sox_io/test_torchscript.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from torchaudio.backend import sox_io_backend +import torchaudio from parameterized import parameterized from .. import common_utils @@ -19,11 +19,11 @@ def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: - return sox_io_backend.info(filepath) + return torchaudio.info(filepath) def py_load_func(filepath: str): - return sox_io_backend.load(filepath) + return torchaudio.load(filepath) def py_save_func( @@ -33,12 +33,14 @@ def py_save_func( channel_first: bool = False, compression: Optional[float] = None, ): - sox_io_backend.save(filepath, tensor, sample_rate, channel_first, compression) + torchaudio.save(filepath, tensor, sample_rate, channel_first, compression) @common_utils.skipIfNoExec('sox') @common_utils.skipIfNoExtension class SoxIO(TempDirMixin, TorchaudioTestCase): + backend = 'sox_io' + @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], diff --git a/test/test_backend.py b/test/test_backend.py index a13dd5088ab..6b67cb28986 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,5 +1,3 @@ -import unittest - import torchaudio from . import common_utils @@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase) backend_module = torchaudio.backend.sox_backend +@common_utils.skipIfNoExtension +class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase): + backend = 'sox_io' + backend_module = torchaudio.backend.sox_io_backend + + @common_utils.skipIfNoModule('soundfile') class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): backend = 'soundfile' diff --git a/test/test_io.py b/test/test_io.py index f58f66ed119..08427c9a60b 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -17,11 +17,15 @@ class Test_LoadSave(unittest.TestCase): def test_1_save(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_1_save(self.test_filepath, False) for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_1_save(self.test_filepath_wav, True) @@ -68,6 +72,8 @@ def _test_1_save(self, test_filepath, normalization): def test_1_save_sine(self): for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_1_save_sine() @@ -101,11 +107,15 @@ def _test_1_save_sine(self): def test_2_load(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_2_load(self.test_filepath, 278756) for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_2_load(self.test_filepath_wav, 276858) @@ -142,6 +152,8 @@ def _test_2_load(self, test_filepath, length): def test_2_load_nonormalization(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_2_load_nonormalization(self.test_filepath, 278756) @@ -159,6 +171,8 @@ def _test_2_load_nonormalization(self, test_filepath, length): def test_3_load_and_save_is_identity(self): for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_3_load_and_save_is_identity() @@ -197,6 +211,8 @@ def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2): def test_4_load_partial(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_4_load_partial() @@ -239,6 +255,8 @@ def _test_4_load_partial(self): def test_5_get_info(self): for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_5_get_info() diff --git a/torchaudio/backend/utils.py b/torchaudio/backend/utils.py index d537f01daf1..cb53b3e02ff 100644 --- a/torchaudio/backend/utils.py +++ b/torchaudio/backend/utils.py @@ -7,6 +7,7 @@ from . import ( no_backend, sox_backend, + sox_io_backend, soundfile_backend, ) @@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]: backends.append('soundfile') if is_module_available('torchaudio._torchaudio'): backends.append('sox') + backends.append('sox_io') return backends @@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None: module = no_backend elif backend == 'sox': module = sox_backend + elif backend == 'sox_io': + module = sox_io_backend elif backend == 'soundfile': module = soundfile_backend else: @@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]: return None if torchaudio.load == sox_backend.load: return 'sox' + if torchaudio.load == sox_io_backend.load: + return 'sox_io' if torchaudio.load == soundfile_backend.load: return 'soundfile' raise ValueError('Unknown backend.') diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index 9a10df0a81e..218af367ef2 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -31,6 +31,46 @@ struct SoxDescriptor { sox_format_t* fd_; }; +void printHighPrecision(const torch::Tensor& t) { + AT_DISPATCH_ALL_TYPES(t.scalar_type(), "print", [&] { + auto acc = t.accessor(); + for (int i = 0; i < acc.size(0); i++) { + for (int j = 0; j < acc.size(1); j++) { + auto val = acc[i][j]; + std::cout << val << " "; + } + std::cout << std::endl; + } + }); +} + +void printEnds(const torch::Tensor& t) { + std::cout << "++ first 10" << std::endl; + printHighPrecision(t.index({Slice(None, 10, None)})); + std::cout << "++ last 10" << std::endl; + printHighPrecision(t.index({Slice(-10, None, None)})); +} + +void printSignalInfo(const sox_signalinfo_t signal) { + std::cout << " - rate: " << signal.rate << std::endl + << " - channels: " << signal.channels << std::endl + << " - precision: " << signal.precision << std::endl + << " - length: " << signal.length << std::endl; + if (signal.mult) { + std::cout << " - mult: " << *(signal.mult) << std::endl; + } +} + +void printEncodingInfo(const sox_encodinginfo_t encoding) { + std::cout << " - encoding: " << encoding.encoding << std::endl + << " - bits_per_sample: " << encoding.bits_per_sample << std::endl + << " - compression: " << encoding.compression << std::endl + << " - reverse_bytes: " << encoding.reverse_bytes << std::endl + << " - reverse_nibbles: " << encoding.reverse_nibbles << std::endl + << " - reverse_bits: " << encoding.reverse_bits << std::endl + << " - opposite_endian: " << encoding.opposite_endian << std::endl; +} + } // namespace c10::intrusive_ptr<::torchaudio::SignalInfo> get_info( @@ -45,6 +85,9 @@ c10::intrusive_ptr<::torchaudio::SignalInfo> get_info( throw std::runtime_error("Error opening audio file"); } + printSignalInfo(sd->signal); + printEncodingInfo(sd->encoding); + return c10::make_intrusive<::torchaudio::SignalInfo>( static_cast(sd->signal.rate), static_cast(sd->signal.channels), @@ -186,6 +229,7 @@ torch::Tensor load_audio_file( option = option.dtype(torch::kFloat32); break; default: + std::cout << "encoding: " << ei.encoding << std::endl; throw std::runtime_error("Unsupported encoding."); } tensor = tensor.to(option); @@ -313,6 +357,10 @@ void save_audio_file( /*reverse_bits=*/sox_option_default, /*opposite_endian=*/sox_false}; + // std::cout << "++ Input info: " << std::endl; + // printSignalInfo(signal_info); + // printEncodingInfo(encoding_info); + SoxDescriptor sd(sox_open_write( file_name.c_str(), &signal_info, @@ -325,7 +373,13 @@ void save_audio_file( throw std::runtime_error("Error saving audio file: failed to open file."); } + std::cout << "++ Detected info: " << std::endl; + printSignalInfo(sd->signal); + printEncodingInfo(sd->encoding); + auto tensor_ = tensor; + // std::cout << "++ tensor before normalization" << std::endl; + // printEnds(tensor_); // de-normalization if (dtype == torch::kFloat32) { @@ -343,6 +397,8 @@ void save_audio_file( tensor_ -= 128; tensor_ *= 16777216; } + // std::cout << "++ tensor after nomalization" << std::endl; + // printEnds(tensor_); // Format & clean up if (channel_first) {