-
Notifications
You must be signed in to change notification settings - Fork 670
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
772 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,90 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
import scipy.io.wavfile | ||
|
||
|
||
def get_test_name(func, _, params): | ||
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}' | ||
|
||
|
||
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor: | ||
if tensor.dtype == torch.float32: | ||
pass | ||
elif tensor.dtype == torch.int32: | ||
tensor = tensor.to(torch.float32) | ||
tensor[tensor > 0] /= 2147483647. | ||
tensor[tensor < 0] /= 2147483648. | ||
elif tensor.dtype == torch.int16: | ||
tensor = tensor.to(torch.float32) | ||
tensor[tensor > 0] /= 32767. | ||
tensor[tensor < 0] /= 32768. | ||
elif tensor.dtype == torch.uint8: | ||
tensor = tensor.to(torch.float32) - 128 | ||
tensor[tensor > 0] /= 127. | ||
tensor[tensor < 0] /= 128. | ||
return tensor | ||
|
||
|
||
def get_wav_data( | ||
dtype: str, | ||
num_channels: int, | ||
*, | ||
num_frames: Optional[int] = None, | ||
normalize: bool = True, | ||
channels_first: bool = True, | ||
): | ||
"""Generate linear signal of the given dtype and num_channels | ||
Data range is | ||
[-1.0, 1.0] for float32, | ||
[-2147483648, 2147483647] for int32 | ||
[-32768, 32767] for int16 | ||
[0, 255] for uint8 | ||
num_frames allow to change the linear interpolation parameter. | ||
Default values are 256 for uint8, else 1 << 16. | ||
1 << 16 as default is so that int16 value range is completely covered. | ||
""" | ||
dtype_ = getattr(torch, dtype) | ||
|
||
if num_frames is None: | ||
if dtype == 'uint8': | ||
num_frames = 256 | ||
else: | ||
num_frames = 1 << 16 | ||
|
||
if dtype == 'uint8': | ||
base = torch.linspace(0, 255, num_frames, dtype=dtype_) | ||
if dtype == 'float32': | ||
base = torch.linspace(-1., 1., num_frames, dtype=dtype_) | ||
if dtype == 'int32': | ||
base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) | ||
if dtype == 'int16': | ||
base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_) | ||
data = base.repeat([num_channels, 1]) | ||
if not channels_first: | ||
data = data.transpose(1, 0) | ||
if normalize: | ||
data = normalize_wav(data) | ||
return data | ||
|
||
|
||
def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor: | ||
"""Load wav file without torchaudio""" | ||
sample_rate, data = scipy.io.wavfile.read(path) | ||
data = torch.from_numpy(data.copy()) | ||
if data.ndim == 1: | ||
data = data.unsqueeze(1) | ||
if normalize: | ||
data = normalize_wav(data) | ||
if channels_first: | ||
data = data.transpose(1, 0) | ||
return data, sample_rate | ||
|
||
|
||
def save_wav(path, data, sample_rate, channels_first=True): | ||
"""Save wav file without torchaudio""" | ||
if channels_first: | ||
data = data.transpose(1, 0) | ||
scipy.io.wavfile.write(path, sample_rate, data.numpy()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.