Skip to content

Commit

Permalink
Add sox_io_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 18, 2020
1 parent 8bfcf6e commit f85d969
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 6 deletions.
10 changes: 6 additions & 4 deletions test/sox_io/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import torch
from torchaudio.backend import sox_io_backend
import torchaudio
from parameterized import parameterized

from .. import common_utils
Expand All @@ -19,11 +19,11 @@


def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)
return torchaudio.info(filepath)


def py_load_func(filepath: str):
return sox_io_backend.load(filepath)
return torchaudio.load(filepath)


def py_save_func(
Expand All @@ -33,12 +33,14 @@ def py_save_func(
channel_first: bool = False,
compression: Optional[float] = None,
):
sox_io_backend.save(filepath, tensor, sample_rate, channel_first, compression)
torchaudio.save(filepath, tensor, sample_rate, channel_first, compression)


@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
backend = 'sox_io'

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
Expand Down
8 changes: 6 additions & 2 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

import torchaudio

from . import common_utils
Expand Down Expand Up @@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase)
backend_module = torchaudio.backend.sox_backend


@common_utils.skipIfNoExtension
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend


@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
Expand Down
18 changes: 18 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ class Test_LoadSave(unittest.TestCase):

def test_1_save(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath, False)

for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath_wav, True)
Expand Down Expand Up @@ -68,6 +72,8 @@ def _test_1_save(self, test_filepath, normalization):

def test_1_save_sine(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save_sine()
Expand Down Expand Up @@ -101,11 +107,15 @@ def _test_1_save_sine(self):

def test_2_load(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath, 278756)

for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath_wav, 276858)
Expand Down Expand Up @@ -142,6 +152,8 @@ def _test_2_load(self, test_filepath, length):

def test_2_load_nonormalization(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load_nonormalization(self.test_filepath, 278756)
Expand All @@ -159,6 +171,8 @@ def _test_2_load_nonormalization(self, test_filepath, length):

def test_3_load_and_save_is_identity(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_3_load_and_save_is_identity()
Expand Down Expand Up @@ -197,6 +211,8 @@ def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):

def test_4_load_partial(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_4_load_partial()
Expand Down Expand Up @@ -239,6 +255,8 @@ def _test_4_load_partial(self):

def test_5_get_info(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_5_get_info()
Expand Down
6 changes: 6 additions & 0 deletions torchaudio/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import (
no_backend,
sox_backend,
sox_io_backend,
soundfile_backend,
)

Expand All @@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]:
backends.append('soundfile')
if is_module_available('torchaudio._torchaudio'):
backends.append('sox')
backends.append('sox_io')
return backends


Expand All @@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None:
module = no_backend
elif backend == 'sox':
module = sox_backend
elif backend == 'sox_io':
module = sox_io_backend
elif backend == 'soundfile':
module = soundfile_backend
else:
Expand All @@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]:
return None
if torchaudio.load == sox_backend.load:
return 'sox'
if torchaudio.load == sox_io_backend.load:
return 'sox_io'
if torchaudio.load == soundfile_backend.load:
return 'soundfile'
raise ValueError('Unknown backend.')
56 changes: 56 additions & 0 deletions torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,46 @@ struct SoxDescriptor {
sox_format_t* fd_;
};

void printHighPrecision(const torch::Tensor& t) {
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "print", [&] {
auto acc = t.accessor<scalar_t, 2>();
for (int i = 0; i < acc.size(0); i++) {
for (int j = 0; j < acc.size(1); j++) {
auto val = acc[i][j];
std::cout << val << " ";
}
std::cout << std::endl;
}
});
}

void printEnds(const torch::Tensor& t) {
std::cout << "++ first 10" << std::endl;
printHighPrecision(t.index({Slice(None, 10, None)}));
std::cout << "++ last 10" << std::endl;
printHighPrecision(t.index({Slice(-10, None, None)}));
}

void printSignalInfo(const sox_signalinfo_t signal) {
std::cout << " - rate: " << signal.rate << std::endl
<< " - channels: " << signal.channels << std::endl
<< " - precision: " << signal.precision << std::endl
<< " - length: " << signal.length << std::endl;
if (signal.mult) {
std::cout << " - mult: " << *(signal.mult) << std::endl;
}
}

void printEncodingInfo(const sox_encodinginfo_t encoding) {
std::cout << " - encoding: " << encoding.encoding << std::endl
<< " - bits_per_sample: " << encoding.bits_per_sample << std::endl
<< " - compression: " << encoding.compression << std::endl
<< " - reverse_bytes: " << encoding.reverse_bytes << std::endl
<< " - reverse_nibbles: " << encoding.reverse_nibbles << std::endl
<< " - reverse_bits: " << encoding.reverse_bits << std::endl
<< " - opposite_endian: " << encoding.opposite_endian << std::endl;
}

} // namespace

c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(
Expand All @@ -45,6 +85,9 @@ c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(
throw std::runtime_error("Error opening audio file");
}

printSignalInfo(sd->signal);
printEncodingInfo(sd->encoding);

return c10::make_intrusive<::torchaudio::SignalInfo>(
static_cast<int64_t>(sd->signal.rate),
static_cast<int64_t>(sd->signal.channels),
Expand Down Expand Up @@ -186,6 +229,7 @@ torch::Tensor load_audio_file(
option = option.dtype(torch::kFloat32);
break;
default:
std::cout << "encoding: " << ei.encoding << std::endl;
throw std::runtime_error("Unsupported encoding.");
}
tensor = tensor.to(option);
Expand Down Expand Up @@ -313,6 +357,10 @@ void save_audio_file(
/*reverse_bits=*/sox_option_default,
/*opposite_endian=*/sox_false};

// std::cout << "++ Input info: " << std::endl;
// printSignalInfo(signal_info);
// printEncodingInfo(encoding_info);

SoxDescriptor sd(sox_open_write(
file_name.c_str(),
&signal_info,
Expand All @@ -325,7 +373,13 @@ void save_audio_file(
throw std::runtime_error("Error saving audio file: failed to open file.");
}

std::cout << "++ Detected info: " << std::endl;
printSignalInfo(sd->signal);
printEncodingInfo(sd->encoding);

auto tensor_ = tensor;
// std::cout << "++ tensor before normalization" << std::endl;
// printEnds(tensor_);

// de-normalization
if (dtype == torch::kFloat32) {
Expand All @@ -343,6 +397,8 @@ void save_audio_file(
tensor_ -= 128;
tensor_ *= 16777216;
}
// std::cout << "++ tensor after nomalization" << std::endl;
// printEnds(tensor_);

// Format & clean up
if (channel_first) {
Expand Down

0 comments on commit f85d969

Please sign in to comment.