Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TorchScript-able "save" func to sox_io backend #732

Merged
merged 3 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,34 @@ def set_audio_backend(backend):
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
base_temp_dir = None
temp_dir = None

def setUp(self):
super().setUp()
self._init_temp_dir()
@classmethod
def setUpClass(cls):
super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name

def tearDown(self):
@classmethod
def tearDownClass(cls):
super().tearDownClass()
self._clean_up_temp_dir()
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()

def _init_temp_dir(self):
self.temp_dir_ = tempfile.TemporaryDirectory()
self.temp_dir = self.temp_dir_.name

def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None
def setUp(self):
self.temp_dir = os.path.join(self.base_temp_dir, self.id())

def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths)
path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path


class TestBaseMixin:
Expand Down
13 changes: 11 additions & 2 deletions test/sox_io_backend/sox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@ def gen_audio_file(
'Use get_wav_data and save_wav to generate wav file for accurate result.')
command = [
'sox',
'-V', # verbose
'-V3', # verbose
'-R',
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
# search "sox_globals.repeatable"
]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
command += [
'--rate', str(sample_rate),
'--null', # no input
'--channels', str(num_channels),
Expand Down Expand Up @@ -60,7 +69,7 @@ def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', '-V', str(src_path)]
command = ['sox', '-V3', '-R', str(src_path)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
Expand Down
10 changes: 5 additions & 5 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
Expand All @@ -44,7 +44,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
Expand All @@ -60,7 +60,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
path = self.get_temp_path('data.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
Expand All @@ -79,7 +79,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
path = self.get_temp_path('data.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
Expand All @@ -97,7 +97,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
path = self.get_temp_path('data.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
Expand Down
14 changes: 7 additions & 7 deletions test/sox_io_backend/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):

Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
Expand Down Expand Up @@ -58,8 +58,8 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
By combining i & ii, step 2. and 4. allows to load reference mp3 data
without using torchaudio
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.mp3')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate mp3 with sox
sox_utils.gen_audio_file(
Expand All @@ -80,8 +80,8 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):

This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.flac')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate flac with sox
sox_utils.gen_audio_file(
Expand All @@ -102,8 +102,8 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):

This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.vorbis')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate vorbis with sox
sox_utils.gen_audio_file(
Expand Down
52 changes: 52 additions & 0 deletions test/sox_io_backend/test_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import itertools

from torchaudio.backend import sox_io_backend
from parameterized import parameterized

from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
)


@skipIfNoExec('sox')
@skipIfNoExtension
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate)
data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_flac(self, sample_rate, num_channels, compression_level):
"""save/load round trip should not degrade data for flac formats"""
original = get_wav_data('float32', num_channels)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.flac')
sox_io_backend.save(path, data, sample_rate, compression=compression_level)
data, sr = sox_io_backend.load(path)
assert sr == sample_rate
self.assertEqual(original, data)
Loading