Skip to content

Commit

Permalink
ENH subbands property (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n authored Aug 23, 2024
1 parent e1b890e commit 36464dd
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
2 changes: 2 additions & 0 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric"
raise ValueError("J_phi must be greater or equal to J")
self.T = [T*self.Q[j] for j in range(J)]
self.in_channels = in_channels
self.padding_mode = padding_mode
down = []
conv1d = []
self.dtcwt = murenn.DTCWT(
Expand Down Expand Up @@ -88,6 +89,7 @@ def forward(self, x):
UWx = torch.cat(UWx, dim=2)
return UWx


def to_conv1d(self):
"""
Compute the single-resolution equivalent impulse response of the MuReNN layer.
Expand Down
67 changes: 67 additions & 0 deletions murenn/dtcwt/transform1d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import dtcwt
import torch.nn
import bisect

from murenn.dtcwt.lowlevel import prep_filt
from murenn.dtcwt.transform_funcs import FWD_J1, FWD_J2PLUS, INV_J1, INV_J2PLUS
Expand Down Expand Up @@ -188,6 +189,72 @@ def forward(self, x):
return yl, yh


@property
def subbands(self):
"""
Return the subbands boundaries.
"""

N = 2 ** (self.J + 4)
x = torch.zeros(1, 1, N)
x[0, 0, N//2] = 1

idtcwt = DTCWTInverse(
J = self.J,
alternate_gh=self.alternate_gh,
)
# Compute the DTCWT of the impulse signal
x_phi, x_psis = self(x)
ys = []

for j in range(self.J):
y_phi = x_phi * 0
y_psis = [x_psis[k] * (j==k) for k in range(self.J)]
y_j_hat = torch.abs(torch.fft.fft(idtcwt(y_phi, y_psis).squeeze()))
ys.append(y_j_hat)

lp_psis = [x_psis[k] * 0 for k in range(self.J)]
y_lp_hat = torch.abs(torch.fft.fft(idtcwt(x_phi, lp_psis).squeeze()))
ys.append(y_lp_hat)

# Stack tensors to create a 2D tensor where each row is a tensor from the list
ys = torch.stack(ys)[:, :N//2]
# Define the threshold
threshold = 0.2
# Apply the threshold
valid_mask = ys >= threshold
ys = ys * valid_mask.float()
# Find the subbands of each frequency
max_values, max_indices = torch.max(ys, dim=0)
# Find the boundaries of the subbands
boundaries = torch.where(max_indices[:-1] != max_indices[1:])[0] + 1
boundaries = boundaries / N
boundaries = torch.cat((torch.tensor([0.]), boundaries, torch.tensor([0.5]))).flip(dims=(0,))
return boundaries.tolist()


def hz_to_octs(self, frequencies, sr=1.0):
"""
Convert a list of frequencies to their corresponding octave subband indices.
Parameters:
frequencies (list of float): List of frequencies to convert.
sr (float): Sampling rate, default is 1.0.
Returns:
list of int: List of octave subband indices corresponding to the input frequencies
-1 indicates out of range.
"""
subbands = [boundary * sr for boundary in self.subbands]
subbands.reverse()
js = []
for freq in frequencies:
i = bisect.bisect_left(subbands, freq)
j = len(subbands) - i - 1 if i > 0 else -1
js.append(j)
return js


class DTCWTInverse(DTCWT):
"""Performs a DTCWT reconstruction of a sequence of 1-D signals. DTCWTInverse
should be initialized in the same manner as DTCWTDirect.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ def test_avrg_energy(alternate_gh):
P_Ux = P_Ux + Ppsi_j
ratio = P_Ux / P_x
assert torch.abs(ratio - 1) <= 0.01


@pytest.mark.parametrize("J", list(range(1, 10)))
def test_subbands(J):
tfm = murenn.DTCWT(J=J)
subbands = tfm.subbands
# Test the number of subbands
# There are J band-pass subbands and 1 low-pass subband, so J+2 subbands boundaries in total.
assert len(subbands) == J + 2
# Test the min/max value
assert min(subbands) == 0.
assert max(subbands) == 0.5
# Check that it's sorted
assert all(subbands[i] > subbands[i+1] for i in range(len(subbands)-1)
)

@pytest.mark.parametrize("J", list(range(1, 10)))
def test_hz_to_octs(J):
sr = 16000
nyquist = 8000
dtcwt = murenn.DTCWT(J = J)
# Test with a very small frequency, expecting it to map to the highest subband index
assert dtcwt.hz_to_octs([1e-5], sr) == [J]
# Test with a frequency just above the Nyquist frequency, expecting it to map to -1 (out of range)
assert dtcwt.hz_to_octs([nyquist+1e-5], sr) == [-1]
# Test with a frequency just below the Nyquist frequency, expecting it to map to the lowest subband index
assert dtcwt.hz_to_octs([nyquist-1e-5], sr) == [0]

0 comments on commit 36464dd

Please sign in to comment.