
Commit

Merge pull request #47 from descriptinc/ps/crosstalk_tfm
Adding CrossTalk, making Choose better.
pseeth authored Aug 9, 2022
2 parents d68ee19 + 2d0d8f7 commit 3bb0fc3
Showing 10 changed files with 56 additions and 31 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.3.7"
+__version__ = "0.3.8"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
2 changes: 1 addition & 1 deletion audiotools/core/effects.py
@@ -353,7 +353,7 @@ def decompose_ir(self):
for idx in range(self.batch_size):
window_idx = early_idx[idx, 0].nonzero()
window[idx, ..., window_idx] = self.get_window(
"hanning", window_idx.shape[-1], self.device
"hann", window_idx.shape[-1], self.device
)
return early_response, late_field, window

62 changes: 42 additions & 20 deletions audiotools/data/transforms.py
@@ -168,15 +168,11 @@ def __iter__(self):
class Choose(Compose):
# Class logic is the same as Compose, but instead of applying all
# the transforms in sequence, it applies just a single transform,
-# which is picked deterministically by summing all of the `seed`
-# integers (which could be just one or a batch of integers), and then
-# using the sum as a seed to build a RandomState object that it then
-# calls `choice` on, with probabilities `self.weights``.
+# which is chosen for each item in the batch.
def __init__(
self,
*transforms: list,
weights: list = None,
-max_seed: int = 1000,
name: str = None,
prob: float = 1.0,
):
@@ -186,19 +182,16 @@ def __init__(
_len = len(self.transforms)
weights = [1 / _len for _ in range(_len)]
self.weights = np.array(weights)
-self.max_seed = max_seed

-def _transform(self, signal, seed, **kwargs):
-    state = seed.sum().item()
-    state = util.random_state(state)
-    idx = list(range(len(self.transforms)))
-    idx = state.choice(idx, p=self.weights)
-    return self.transforms[idx](signal, **kwargs)

def _instantiate(self, state: RandomState, signal: AudioSignal = None):
-parameters = super()._instantiate(state, signal)
-parameters["seed"] = state.randint(self.max_seed)
-return parameters
+kwargs = super()._instantiate(state, signal)
+tfm_idx = list(range(len(self.transforms)))
+tfm_idx = state.choice(tfm_idx, p=self.weights)
+for i, t in enumerate(self.transforms):
+    mask = kwargs[t.name]["mask"]
+    if mask.item():
+        kwargs[t.name]["mask"] = tt(i == tfm_idx)
+return kwargs
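
In other words, Choose no longer derives a fresh seed per call; it keeps the per-transform masks produced by Compose and zeroes out all but one weighted pick. A minimal standalone sketch of that selection step, assuming plain NumPy/PyTorch and illustrative transform names (this is not the audiotools API):

import numpy as np
import torch

def choose_masks(masks: dict, weights: list, state: np.random.RandomState) -> dict:
    # Pick one transform index with the configured weights, then keep only
    # that transform's mask; every other transform is forced off.
    names = list(masks.keys())
    tfm_idx = state.choice(len(names), p=weights)
    return {
        name: torch.tensor(bool(masks[name]) and i == tfm_idx)
        for i, name in enumerate(names)
    }

state = np.random.RandomState(42)
masks = {"LowPass": True, "HighPass": True, "Equalizer": False}  # illustrative names
print(choose_masks(masks, weights=[0.25, 0.25, 0.5], state=state))

The real _instantiate does the same thing to the kwargs dict returned by Compose, which is why max_seed and the seed-summing _transform above could be dropped.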


class Repeat(Compose):
@@ -225,16 +218,13 @@ def __init__(
transform,
max_repeat: int = 5,
weights: list = None,
-max_seed: int = 1000,
name: str = None,
prob: float = 1.0,
):
transforms = []
for n in range(1, max_repeat):
transforms.append(Repeat(transform, n_repeat=n))
-super().__init__(
-    transforms, name=name, prob=prob, weights=weights, max_seed=max_seed
-)
+super().__init__(transforms, name=name, prob=prob, weights=weights)

self.max_repeat = max_repeat
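
A hypothetical usage sketch (the wrapped transform and weights are illustrative; RepeatUpTo now forwards only name, prob, and weights to Choose):

# Apply the wrapped transform 1, 2, or 3 times (max_repeat is exclusive),
# with the repeat count drawn through the Choose machinery.
tfm = RepeatUpTo(RoomImpulseResponse(), max_repeat=4, weights=[0.5, 0.3, 0.2])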

@@ -366,6 +356,38 @@ def _transform(self, signal, bg_signal, snr, eq):
return signal.mix(bg_signal.clone(), snr, eq)


+class CrossTalk(BaseTransform):
+    def __init__(
+        self,
+        snr: tuple = ("uniform", -5.0, 5.0),
+        max_seed: int = 1000,
+        name: str = None,
+        prob: float = 1.0,
+    ):
+        """
+        min and max refer to SNR.
+        """
+        super().__init__(name=name, prob=prob)
+
+        self.snr = snr
+        self.max_seed = max_seed
+
+    def _instantiate(self, state: RandomState):
+        snr = util.sample_from_dist(self.snr, state)
+        seed = state.randint(self.max_seed)
+        return {"snr": snr, "seed": seed}
+
+    def _transform(self, signal, snr, seed):
+        state = seed.sum().item()
+        state = util.random_state(state)
+
+        idx = np.arange(signal.batch_size)
+        state.shuffle(idx)
+        input_loudness = signal.loudness()
+        mix = signal.mix(signal[idx.tolist()], snr)
+        return mix.normalize(input_loudness)
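
Conceptually, CrossTalk pairs every item in the batch with another (shuffled) item, mixes the pair at the sampled SNR, and rescales the result back to the original item's loudness. A rough standalone sketch of that idea, assuming plain NumPy and using RMS as a stand-in for the LUFS loudness that AudioSignal.loudness()/normalize() use (again, not the audiotools API):

import numpy as np

def rms(x):
    return np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True)) + 1e-8

def crosstalk_mix(batch: np.ndarray, snr_db: np.ndarray, seed: int) -> np.ndarray:
    # batch: [B, T] mono signals; snr_db: [B] target signal-to-crosstalk ratios in dB.
    state = np.random.RandomState(seed)
    idx = np.arange(batch.shape[0])
    state.shuffle(idx)                      # pair each item with another batch item
    other = batch[idx]
    # Scale the partner so the target-to-partner level ratio equals snr_db.
    gain = rms(batch) / rms(other) * 10.0 ** (-snr_db[:, None] / 20.0)
    mix = batch + gain * other
    # Restore each item's original level after mixing (RMS here, LUFS in audiotools).
    return mix * rms(batch) / rms(mix)

batch = np.random.randn(4, 16000)           # e.g. four 1-second signals at 16 kHz
out = crosstalk_mix(batch, snr_db=np.random.uniform(-5.0, 5.0, size=4), seed=0)

In the transform itself, signal.mix handles the SNR scaling (as in BackgroundNoise above) and normalize(input_loudness) restores the measured loudness.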


class RoomImpulseResponse(BaseTransform):
def __init__(
self,
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.3.7",
version="0.3.8",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
4 changes: 2 additions & 2 deletions tests/audio/spk.csv
@@ -1,2 +1,2 @@
-path
-tests/audio/spk/f10_script4_produced.wav
+path,loudness
+tests/audio/spk/f10_script4_produced.wav,-16
4 changes: 2 additions & 2 deletions tests/core/test_audio_signal.py
@@ -347,7 +347,7 @@ def test_device():

@pytest.mark.parametrize("window_length", [2048, 512])
@pytest.mark.parametrize("hop_length", [512, 128])
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hanning", None])
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None])
def test_stft(window_length, hop_length, window_type):
if hop_length >= window_length:
hop_length = window_length // 2
@@ -418,7 +418,7 @@ def test_log_magnitude():
@pytest.mark.parametrize("n_mels", [40, 80, 128])
@pytest.mark.parametrize("window_length", [2048, 512])
@pytest.mark.parametrize("hop_length", [512, 128])
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hanning", None])
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None])
def test_mel_spectrogram(n_mels, window_length, hop_length, window_type):
if hop_length >= window_length:
hop_length = window_length // 2
4 changes: 2 additions & 2 deletions tests/core/test_dsp.py
@@ -78,7 +78,7 @@ def test_low_pass():
f = 440
t = torch.arange(0, 1, 1 / sample_rate)
sine_wave = torch.sin(2 * np.pi * f * t)
window = AudioSignal.get_window("hanning", sine_wave.shape[-1], sine_wave.device)
window = AudioSignal.get_window("hann", sine_wave.shape[-1], sine_wave.device)
sine_wave = sine_wave * window
signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
out = signal.deepcopy().low_pass(220)
@@ -102,7 +102,7 @@ def test_high_pass():
f = 440
t = torch.arange(0, 1, 1 / sample_rate)
sine_wave = torch.sin(2 * np.pi * f * t)
window = AudioSignal.get_window("hanning", sine_wave.shape[-1], sine_wave.device)
window = AudioSignal.get_window("hann", sine_wave.shape[-1], sine_wave.device)
sine_wave = sine_wave * window
signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
out = signal.deepcopy().high_pass(220)
2 changes: 1 addition & 1 deletion tests/profilers/profile_transforms.py
@@ -14,7 +14,7 @@
transforms_to_demo = []
for x in dir(tfm):
if hasattr(getattr(tfm, x), "transform"):
-if x not in ["Compose", "Choose"]:
+if x not in ["Compose", "Choose", "Repeat", "RepeatUpTo"]:
transforms_to_demo.append(x)


3 changes: 3 additions & 0 deletions tests/regression/transforms/CrossTalk.wav
Git LFS file not shown
2 changes: 1 addition & 1 deletion tests/regression/transforms/RepeatUpTo.wav
Git LFS file not shown
