Skip to content

Commit

Permalink
Update based on the first review
Browse files Browse the repository at this point in the history
 - Fix typo
 - Allow all the codecs and update docstring
 - minor tweaks
  • Loading branch information
mthrok committed Jun 24, 2020
1 parent 18fcf3b commit 577dfcf
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 39 deletions.
4 changes: 2 additions & 2 deletions test/sox_io_backend/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
v 2. Convert to wav with Sox
mp3 ------------------------------> wav
| |
| 3. Load with torchaduio | 4. Load with scipy
| 3. Load with torchaudio | 4. Load with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
Underlying assumptions are;
i. Convertion of mp3 to wav with Sox does not alter data.
i. Conversion of mp3 to wav with Sox preserves data.
ii. Loading wav file with scipy is correct.
By combining i & ii, step 2. and 4. allows to load reference mp3 data
Expand Down
42 changes: 24 additions & 18 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,24 @@ def load(
) -> Tuple[torch.Tensor, int]:
"""Load audio data from file.
Supported formats are;
- WAV
- 32-bit floating-point
- 32-bit signed integer
- 16-bit signed integer
- 8-bit unsigned integer
- MP3
- FLAC
- OGG/VORBIS
This function can handle all the codecs that underlying libsox can handle, however note the
followings.
Note:
Currently torchaudio's binary release does not include codecs library required to handle
OGG/VORBIS. To use these formats, you need to build torchaudio from source with
libsox and codecs libraries. Refer to README for this.
Note 1:
Current torchaudio's binary release only contains codecs for MP3, FLAC and OGG/VORBIS.
If you need other formats, you need to build torchaudio from source with libsox and
the corresponding codecs. Refer to README for this.
Note 2:
This function is tested on the following formats;
- WAV
- 32-bit floating-point
- 32-bit signed integer
- 16-bit signed integer
- 8-bit unsigned integer
- MP3
- FLAC
- OGG/VORBIS
By default, this function returns Tensor with ``float32`` dtype and the shape of ``[channel, time]``.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
Expand All @@ -53,11 +57,13 @@ def load(
Args:
filepath: Path to audio file
frame_offset: Number of frames to skip before start reading data.
num_frames: Maximum number of frames to read. If there is not enough frames in
the given audio, this function does NOT raise an error.
normalize: When True and input file is integer WAV, the resulting Tensor type
becomes ``float32`` and values are normalized to ``[-1.0, 1.0]``.
This argument has no effect for other formats.
num_frames: Maximum number of frames to read. -1 reads all the remaining samples, starting
from ``frame_offset``. This function may return the less number of frames if there is
not enough frames in the given file.
normalize: When ``True``, this function always return ``float32``, and sample values are
normalized to ``[-1.0, 1.0]``. If input file is integer WAV, giving ``False`` will change
the resulting Tensor type to integer type. This argument has no effect for formats other
than integer WAV type.
channels_first: When True, the returned Tensor has dimension ``[channel, time]``.
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
Expand Down
14 changes: 6 additions & 8 deletions torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct SoxDescriptor {

} // namespace

c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(const std::string& path) {
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
SoxDescriptor sd(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
Expand All @@ -46,15 +46,15 @@ c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(const std::string& path) {
throw std::runtime_error("Error opening audio file");
}

return c10::make_intrusive<::torchaudio::SignalInfo>(
return c10::make_intrusive<torchaudio::SignalInfo>(
static_cast<int64_t>(sd->signal.rate),
static_cast<int64_t>(sd->signal.channels),
static_cast<int64_t>(sd->signal.length / sd->signal.channels));
}

torch::Tensor load_audio_file(
const std::string& path,
c10::intrusive_ptr<::torchaudio::SignalInfo> info,
c10::intrusive_ptr<torchaudio::SignalInfo> info,
const int64_t frame_offset,
const int64_t num_frames,
const bool normalize,
Expand All @@ -80,14 +80,12 @@ torch::Tensor load_audio_file(
if (sd->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding.");
}

const int64_t num_total_samples = sd->signal.length;
const int64_t num_channels = sd->signal.channels;

if (num_total_samples == 0) {
if (sd->signal.length == 0) {
throw std::runtime_error("Error reading audio file: unkown length.");
}

const int64_t num_channels = sd->signal.channels;
const int64_t num_total_samples = sd->signal.length;
const int64_t sample_start = sd->signal.channels * frame_offset;

if (sox_seek(sd.get(), sample_start, 0) == SOX_EOF) {
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/sox_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
namespace torchaudio {
namespace sox_io {

c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(const std::string& path);
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path);

torch::Tensor load_audio_file(
const std::string& path,
c10::intrusive_ptr<::torchaudio::SignalInfo> info,
c10::intrusive_ptr<torchaudio::SignalInfo> info,
const int64_t frame_offset = 0,
const int64_t num_frames = -1,
const bool normalize = true,
Expand Down
17 changes: 8 additions & 9 deletions torchaudio/csrc/sox_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ caffe2::TypeMeta get_dtype(
const unsigned precision) {
const auto dtype = [&]() {
switch (encoding) {
case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAVE
case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV
return torch::kUInt8;
case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAVE
case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV
switch (precision) {
case 16:
return torch::kInt16;
Expand All @@ -22,14 +22,13 @@ caffe2::TypeMeta get_dtype(
throw std::runtime_error(
"Only 16 and 32 bits are supported for signed PCM.");
}
case SOX_ENCODING_FLOAT: // 32-bit floating-point WAVE
case SOX_ENCODING_MP3:
case SOX_ENCODING_FLAC:
case SOX_ENCODING_VORBIS:
case SOX_ENCODING_OPUS:
return torch::kFloat32;
default:
throw std::runtime_error("Unsupported encoding.");
// default to float32 for the other formats, including
// 32-bit flaoting-point WAV,
// MP3,
// FLAC,
// VORBIS etc...
return torch::kFloat32;
}
}();
return c10::scalarTypeToTypeMeta(dtype);
Expand Down

0 comments on commit 577dfcf

Please sign in to comment.