Skip to content

Commit

Permalink
Add sox_io_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 25, 2020
1 parent e88aba5 commit 13999b0
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
10 changes: 6 additions & 4 deletions test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import torch
from torchaudio.backend import sox_io_backend
import torchaudio
from parameterized import parameterized

from ..common_utils import (
Expand All @@ -21,11 +21,11 @@


def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)
return torchaudio.info(filepath)


def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return sox_io_backend.load(
return torchaudio.load(
filepath, normalize=normalize, channels_first=channels_first)


Expand All @@ -36,13 +36,15 @@ def py_save_func(
channels_first: bool = True,
compression: Optional[float] = None,
):
sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression)
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)


@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
Expand Down
8 changes: 6 additions & 2 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

import torchaudio

from . import common_utils
Expand Down Expand Up @@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase)
backend_module = torchaudio.backend.sox_backend


@common_utils.skipIfNoExtension
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend


@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
Expand Down
18 changes: 18 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ class Test_LoadSave(unittest.TestCase):

def test_1_save(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath, False)

for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath_wav, True)
Expand Down Expand Up @@ -68,6 +72,8 @@ def _test_1_save(self, test_filepath, normalization):

def test_1_save_sine(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save_sine()
Expand Down Expand Up @@ -101,11 +107,15 @@ def _test_1_save_sine(self):

def test_2_load(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath, 278756)

for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath_wav, 276858)
Expand Down Expand Up @@ -142,6 +152,8 @@ def _test_2_load(self, test_filepath, length):

def test_2_load_nonormalization(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load_nonormalization(self.test_filepath, 278756)
Expand All @@ -159,6 +171,8 @@ def _test_2_load_nonormalization(self, test_filepath, length):

def test_3_load_and_save_is_identity(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_3_load_and_save_is_identity()
Expand Down Expand Up @@ -197,6 +211,8 @@ def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):

def test_4_load_partial(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_4_load_partial()
Expand Down Expand Up @@ -239,6 +255,8 @@ def _test_4_load_partial(self):

def test_5_get_info(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_5_get_info()
Expand Down
6 changes: 6 additions & 0 deletions torchaudio/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import (
no_backend,
sox_backend,
sox_io_backend,
soundfile_backend,
)

Expand All @@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]:
backends.append('soundfile')
if is_module_available('torchaudio._torchaudio'):
backends.append('sox')
backends.append('sox_io')
return backends


Expand All @@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None:
module = no_backend
elif backend == 'sox':
module = sox_backend
elif backend == 'sox_io':
module = sox_io_backend
elif backend == 'soundfile':
module = soundfile_backend
else:
Expand All @@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]:
return None
if torchaudio.load == sox_backend.load:
return 'sox'
if torchaudio.load == sox_io_backend.load:
return 'sox_io'
if torchaudio.load == soundfile_backend.load:
return 'soundfile'
raise ValueError('Unknown backend.')

0 comments on commit 13999b0

Please sign in to comment.