diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 58b4154..cababd5 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -32,6 +32,7 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" raise ValueError("J_phi must be greater or equal to J") self.T = [T*self.Q[j] for j in range(J)] self.in_channels = in_channels + self.padding_mode = padding_mode down = [] conv1d = [] self.dtcwt = murenn.DTCWT( @@ -88,6 +89,7 @@ def forward(self, x): UWx = torch.cat(UWx, dim=2) return UWx + def to_conv1d(self): """ Compute the single-resolution equivalent impulse response of the MuReNN layer. diff --git a/murenn/dtcwt/transform1d.py b/murenn/dtcwt/transform1d.py index d179fe8..d11d5f1 100644 --- a/murenn/dtcwt/transform1d.py +++ b/murenn/dtcwt/transform1d.py @@ -1,6 +1,7 @@ import numpy as np import dtcwt import torch.nn +import bisect from murenn.dtcwt.lowlevel import prep_filt from murenn.dtcwt.transform_funcs import FWD_J1, FWD_J2PLUS, INV_J1, INV_J2PLUS @@ -188,6 +189,72 @@ def forward(self, x): return yl, yh + @property + def subbands(self): + """ + Return the subbands boundaries. + """ + + N = 2 ** (self.J + 4) + x = torch.zeros(1, 1, N) + x[0, 0, N//2] = 1 + + idtcwt = DTCWTInverse( + J = self.J, + alternate_gh=self.alternate_gh, + ) + # Compute the DTCWT of the impulse signal + x_phi, x_psis = self(x) + ys = [] + + for j in range(self.J): + y_phi = x_phi * 0 + y_psis = [x_psis[k] * (j==k) for k in range(self.J)] + y_j_hat = torch.abs(torch.fft.fft(idtcwt(y_phi, y_psis).squeeze())) + ys.append(y_j_hat) + + lp_psis = [x_psis[k] * 0 for k in range(self.J)] + y_lp_hat = torch.abs(torch.fft.fft(idtcwt(x_phi, lp_psis).squeeze())) + ys.append(y_lp_hat) + + # Stack tensors to create a 2D tensor where each row is a tensor from the list + ys = torch.stack(ys)[:, :N//2] + # Define the threshold + threshold = 0.2 + # Apply the threshold + valid_mask = ys >= threshold + ys = ys * valid_mask.float() + # Find the subbands of each frequency + max_values, max_indices = torch.max(ys, dim=0) + # Find the boundaries of the subbands + boundaries = torch.where(max_indices[:-1] != max_indices[1:])[0] + 1 + boundaries = boundaries / N + boundaries = torch.cat((torch.tensor([0.]), boundaries, torch.tensor([0.5]))).flip(dims=(0,)) + return boundaries.tolist() + + + def hz_to_octs(self, frequencies, sr=1.0): + """ + Convert a list of frequencies to their corresponding octave subband indices. + + Parameters: + frequencies (list of float): List of frequencies to convert. + sr (float): Sampling rate, default is 1.0. + + Returns: + list of int: List of octave subband indices corresponding to the input frequencies + -1 indicates out of range. + """ + subbands = [boundary * sr for boundary in self.subbands] + subbands.reverse() + js = [] + for freq in frequencies: + i = bisect.bisect_left(subbands, freq) + j = len(subbands) - i - 1 if i > 0 else -1 + js.append(j) + return js + + class DTCWTInverse(DTCWT): """Performs a DTCWT reconstruction of a sequence of 1-D signals. DTCWTInverse should be initialized in the same manner as DTCWTDirect. diff --git a/tests/test_dtcwt.py b/tests/test_dtcwt.py index ded75f2..68074a6 100644 --- a/tests/test_dtcwt.py +++ b/tests/test_dtcwt.py @@ -134,3 +134,30 @@ def test_avrg_energy(alternate_gh): P_Ux = P_Ux + Ppsi_j ratio = P_Ux / P_x assert torch.abs(ratio - 1) <= 0.01 + + +@pytest.mark.parametrize("J", list(range(1, 10))) +def test_subbands(J): + tfm = murenn.DTCWT(J=J) + subbands = tfm.subbands + # Test the number of subbands + # There are J band-pass subbands and 1 low-pass subband, so J+2 subbands boundaries in total. + assert len(subbands) == J + 2 + # Test the min/max value + assert min(subbands) == 0. + assert max(subbands) == 0.5 + # Check that it's sorted + assert all(subbands[i] > subbands[i+1] for i in range(len(subbands)-1) + ) + +@pytest.mark.parametrize("J", list(range(1, 10))) +def test_hz_to_octs(J): + sr = 16000 + nyquist = 8000 + dtcwt = murenn.DTCWT(J = J) + # Test with a very small frequency, expecting it to map to the highest subband index + assert dtcwt.hz_to_octs([1e-5], sr) == [J] + # Test with a frequency just above the Nyquist frequency, expecting it to map to -1 (out of range) + assert dtcwt.hz_to_octs([nyquist+1e-5], sr) == [-1] + # Test with a frequency just below the Nyquist frequency, expecting it to map to the lowest subband index + assert dtcwt.hz_to_octs([nyquist-1e-5], sr) == [0] \ No newline at end of file