From 5cb3c2b4a644669acb8a7999cbf63ecf15f1cbec Mon Sep 17 00:00:00 2001 From: Xiran Date: Mon, 10 Jun 2024 12:31:22 +0200 Subject: [PATCH 01/18] MuReNNDirect.to_conv1d --- murenn/dtcwt/nn.py | 56 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 2bb8771..76bd113 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -51,10 +51,11 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"): def forward(self, x): """ Args: - x (PyTorch tensor): A tensor of shape `(B, C, T)`. B is a batch size, - C denotes a number of channels, T is a length of signal sequence. + x (PyTorch tensor): A tensor of shape `(B, in_channels, T)`. B is a batch size. + in_channels is the number of channels in the input tensor, this should match + the in_channels attribute of the class instance. T is a length of signal sequence. Returns: - y (PyTorch tensor): A tensor of shape `(B, C, Q, J, T_out)` + y (PyTorch tensor): A tensor of shape `(B, in_channels, Q, J, T_out)` """ assert self.C == x.shape[1] lp, bps = self.dtcwt(x) @@ -69,6 +70,55 @@ def forward(self, x): UWx_j = UWx_j.view(B, self.C, self.Q, N) output.append(UWx_j) return torch.stack(output, dim=3) + + @property + def to_conv1d(self): + """ + Get per channel, per filter, per scale impulse responses + Returns: + y (PyTorch tensor): A tensor of shape `(B, C*Q*J, (2**J)*Q)` + """ + # T the filter length + T = self.conv1d[0].kernel_size[0] + # J the number of levels of decompostion + J = self.dtcwt.J + # Generate the impulse signal, this signal is zero padded to a length of (2**J)*T + N = 2**J * T + x = torch.zeros(1, self.C, N) + x[:, :, N//2] = 1 + # Get the padding mode + padding_mode = self.dtcwt.padding_mode + if padding_mode == "constant": + padding_mode = "zeros" + + inv = murenn.IDTCWT( + J=J, + padding_mode=padding_mode, + ) + # Get dtcwt's impulse reponse at each scale + phi, psis = self.dtcwt(x) + # Set phi to a zero valued tensor + zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) + # Create an empty list for {w_jq} + ys = [] + for j in range(J): + # Wpsi_jr = Re[psi_j] * w_jq + Wpsi_jr = self.conv1d[j](psis[j].real) + # W_ji = Im[psi_j] * w_jq + Wpsi_ji = self.conv1d[j](psis[j].imag) + # Set the bp coefficients besides this scale to zero + Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=(1, self.C*self.Q, psis[k].shape[-1])) for k in range(J)] + Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=(1, self.C*self.Q, psis[k].shape[-1])) for k in range(J)] + # Get the impulse response + y_jr = inv(zeros_phi, Wpsis_jr) + y_ji = inv(zeros_phi, Wpsis_ji) + y_j = torch.complex(y_jr, y_ji) + ys.append(y_j) + ys = torch.cat(ys, dim=1) + return ys + + + class ModulusStable(torch.autograd.Function): From 03b910b4c29bf8cc16fd3651363c3e7ec9dc6dec Mon Sep 17 00:00:00 2001 From: Xiran Date: Mon, 10 Jun 2024 16:43:14 +0200 Subject: [PATCH 02/18] plot spectrums demo --- docs/examples/plot_conv1d_spectrum.py | 53 +++++++++++++++++++++++++++ murenn/dtcwt/nn.py | 32 ++++++++-------- 2 files changed, 68 insertions(+), 17 deletions(-) create mode 100644 docs/examples/plot_conv1d_spectrum.py diff --git a/docs/examples/plot_conv1d_spectrum.py b/docs/examples/plot_conv1d_spectrum.py new file mode 100644 index 0000000..fe11801 --- /dev/null +++ b/docs/examples/plot_conv1d_spectrum.py @@ -0,0 +1,53 @@ +# coding: utf-8 +""" +========================= +Visualize Amplitude Spectrums +========================= + +This notebook demonstrates how to visualize the amplitude spectrums +of a 1-D convolutional layer (Conv1D). +""" + +######################### +# Standard imports +import torch +from matplotlib import pyplot as plt +import numpy as np + +import murenn +######################################### +# number of scales +J = 8 + +# number of conv1d filters per scale +Q = 10 + +# kernel size +T = 10 + +# input channel +C = 1 + +# output signal length +N = 2 ** J * T + +# MuReNN +tfm = murenn.MuReNNDirect(J=J, Q=Q, T=T, in_channels=C) +w = tfm.to_conv1d.detach() +w = w.view(1, J, Q, -1) + +################################# +# Plot the spectrum. +colors = [ + 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', + 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] +plt.figure(figsize=(10, 3)) +for q in range(Q): + for j in range(J): + w_hat = torch.fft.fft(w[0, j, q, :]) + plt.semilogx(torch.abs(w_hat), color=colors[j]) +plt.grid(linestyle='--', alpha=0.5) +plt.xlim(0, N//2) +plt.xlabel("Frequency") +plt.title(f'murenn v{murenn.__version__}. Amplitude Spectrums') +plt.show() \ No newline at end of file diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 76bd113..b1dd254 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -74,9 +74,10 @@ def forward(self, x): @property def to_conv1d(self): """ - Get per channel, per filter, per scale impulse responses - Returns: - y (PyTorch tensor): A tensor of shape `(B, C*Q*J, (2**J)*Q)` + Get the per scale, per filter, per input channel impulse responses. + ------- + Return: + y (PyTorch tensor): A complex-valued tensor of shape `(B, in_channels*J*Q, (2**J)*Q)` """ # T the filter length T = self.conv1d[0].kernel_size[0] @@ -95,30 +96,27 @@ def to_conv1d(self): J=J, padding_mode=padding_mode, ) - # Get dtcwt's impulse reponse at each scale + # Get DTCWT impulse reponses phi, psis = self.dtcwt(x) # Set phi to a zero valued tensor zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) # Create an empty list for {w_jq} - ys = [] + ws = [] for j in range(J): # Wpsi_jr = Re[psi_j] * w_jq Wpsi_jr = self.conv1d[j](psis[j].real) # W_ji = Im[psi_j] * w_jq Wpsi_ji = self.conv1d[j](psis[j].imag) - # Set the bp coefficients besides this scale to zero - Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=(1, self.C*self.Q, psis[k].shape[-1])) for k in range(J)] - Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=(1, self.C*self.Q, psis[k].shape[-1])) for k in range(J)] + # Set the coefficients besides this scale to zero .repeat(ch, 1, 1) + Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] + Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] # Get the impulse response - y_jr = inv(zeros_phi, Wpsis_jr) - y_ji = inv(zeros_phi, Wpsis_ji) - y_j = torch.complex(y_jr, y_ji) - ys.append(y_j) - ys = torch.cat(ys, dim=1) - return ys - - - + w_jr = inv(zeros_phi, Wpsis_jr) + w_ji = inv(zeros_phi, Wpsis_ji) + w_j = torch.complex(w_jr, w_ji) + ws.append(w_j) + ws = torch.cat(ws, dim=1) + return ws class ModulusStable(torch.autograd.Function): From 025bf7b2e75a2f07ff7f70abda43763463c5c82a Mon Sep 17 00:00:00 2001 From: Xiran Date: Mon, 10 Jun 2024 16:49:26 +0200 Subject: [PATCH 03/18] small changes --- docs/examples/plot_conv1d_spectrum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/plot_conv1d_spectrum.py b/docs/examples/plot_conv1d_spectrum.py index fe11801..dc0e590 100644 --- a/docs/examples/plot_conv1d_spectrum.py +++ b/docs/examples/plot_conv1d_spectrum.py @@ -36,8 +36,8 @@ w = tfm.to_conv1d.detach() w = w.view(1, J, Q, -1) -################################# -# Plot the spectrum. +######################################################################### +# Plot the spectrums per scale per filter. colors = [ 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] From 710d258a9e31c9144ebc3f8a198bb859800c3055 Mon Sep 17 00:00:00 2001 From: Xiran Date: Mon, 10 Jun 2024 16:57:20 +0200 Subject: [PATCH 04/18] test --- tests/test_nn.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_nn.py b/tests/test_nn.py index 69d3620..e308d6c 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -86,4 +86,18 @@ def test_modulus(): loss0 = torch.sum(Ux0) loss0.backward() assert torch.max(torch.abs(x0r.grad)) <= 1e-7 - assert torch.max(torch.abs(x0i.grad)) <= 1e-7 \ No newline at end of file + assert torch.max(torch.abs(x0i.grad)) <= 1e-7 + +@pytest.mark.parametrize("Q", [1, 2]) +@pytest.mark.parametrize("T", [1, 2]) +@pytest.mark.parametrize("C", [1, 2]) +def test_toconv1d_shape(Q, T, C): + J = 4 + tfm = murenn.MuReNNDirect( + J=J, + Q=Q, + T=T, + in_channels=C, + ) + y = tfm.to_conv1d + assert y.shape == (1, Q*J*C, 2**J*T) \ No newline at end of file From de87974350eb186828b5a1d3794295045ee0a850 Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 11 Jun 2024 18:40:51 +0200 Subject: [PATCH 05/18] fix normalization's bug --- murenn/dtcwt/transform1d.py | 9 ++++----- murenn/dtcwt/transform_funcs.py | 25 +++++-------------------- tests/test_grad.py | 14 ++++---------- 3 files changed, 13 insertions(+), 35 deletions(-) diff --git a/murenn/dtcwt/transform1d.py b/murenn/dtcwt/transform1d.py index b6424be..90c1f23 100644 --- a/murenn/dtcwt/transform1d.py +++ b/murenn/dtcwt/transform1d.py @@ -156,7 +156,8 @@ def forward(self, x): # Ensure the lowpass is divisible by 4 if x_phi.shape[-1] % 4 != 0: x_phi = torch.cat((x_phi[:,:,0:1], x_phi, x_phi[:,:,-1:]), dim=-1) - + if self.normalize: + x_phi = 1/np.sqrt(2) * x_phi x_phi, x_psi_r, x_psi_i = FWD_J2PLUS.apply( x_phi, h0a, @@ -165,15 +166,12 @@ def forward(self, x): h1b, self.skip_hps[j], self.padding_mode, - self.normalize, ) - if (j % 2 == 1) and self.alternate_gh: # The result is anti-analytic in the Hilbert sense. # We conjugate the result to bring the spectrum back to (0, pi). # This is purely by convention and for consistency through j. x_psi_i = -1 * x_psi_i - x_psis.append(x_psi_r + 1j * x_psi_i) if self.include_scale[j]: @@ -295,8 +293,9 @@ def forward(self, yl, yh): g0b, g1b, self.padding_mode, - self.normalize, ) + if self.normalize: + x_phi = np.sqrt(2) * x_phi ## LEVEL 1 ## x_psi_r, x_psi_i = x_psis[0].real, x_psis[0].imag diff --git a/murenn/dtcwt/transform_funcs.py b/murenn/dtcwt/transform_funcs.py index 0fa7bd4..9936b33 100644 --- a/murenn/dtcwt/transform_funcs.py +++ b/murenn/dtcwt/transform_funcs.py @@ -75,7 +75,7 @@ class FWD_J2PLUS(torch.autograd.Function): high-pass output of tree b.""" @staticmethod - def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode, normalize): + def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode): """ Forward dual-tree complex wavelet transform at levels 2 and coarser. @@ -104,7 +104,6 @@ def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode, normalize): ctx.save_for_backward(h0a_rep, h1a_rep, h0b_rep, h1b_rep) ctx.skip_hps = skip_hps ctx.mode = mode_to_int(padding_mode) - ctx.normalize = normalize # Apply low-pass filtering on trees a (real) and b (imaginary). lo = coldfilt(x_phi, h0a_rep, h0b_rep, padding_mode) @@ -122,17 +121,13 @@ def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode, normalize): # Return low-pass output, and band-pass output in conjunction: # real part for tree a and imaginary part for tree b. - if normalize: - return 1/np.sqrt(2) * lo, 1/np.sqrt(2) * bp_r, 1/np.sqrt(2) * bp_i - else: - return lo, bp_r, bp_i + return lo, bp_r, bp_i @staticmethod def backward(ctx, dx_phi, dx_psi_r, dx_psi_i): g0b, g1b, g0a, g1a = ctx.saved_tensors skip_hps = ctx.skip_hps padding_mode = int_to_mode(ctx.mode) - normalize = ctx.normalize b, ch, T = dx_phi.shape if not ctx.needs_input_grad[0]: dx = None @@ -141,8 +136,6 @@ def backward(ctx, dx_phi, dx_psi_r, dx_psi_i): if not skip_hps: dx_psi = torch.stack((dx_psi_i, dx_psi_r), dim=-1).view(b, ch, T) dx += colifilt(dx_psi, g1a, g1b, padding_mode) - if normalize: - dx *= 1/np.sqrt(2) return dx, None, None, None, None, None, None, None @@ -212,7 +205,7 @@ class INV_J2PLUS(torch.autograd.Function): """ @staticmethod - def forward(ctx, lo, bp_r, bp_i, g0a, g1a, g0b, g1b, padding_mode, normalize): + def forward(ctx, lo, bp_r, bp_i, g0a, g1a, g0b, g1b, padding_mode): """ Inverse dual-tree complex wavelet transform at levels 2 and coarser. @@ -237,33 +230,25 @@ def forward(ctx, lo, bp_r, bp_i, g0a, g1a, g0b, g1b, padding_mode, normalize): g0b_rep = g0b.repeat(ch, 1, 1) g1b_rep = g1b.repeat(ch, 1, 1) ctx.save_for_backward(g0a_rep, g1a_rep, g0b_rep, g1b_rep) - ctx.normalize = normalize ctx.mode = mode_to_int(padding_mode) bp = torch.stack((bp_i, bp_r), dim=-1).view(b, ch, T) lo = colifilt(lo, g0a_rep, g0b_rep, padding_mode) + colifilt(bp, g1a_rep, g1b_rep, padding_mode) - if normalize: - return np.sqrt(2) * lo - else: - return lo + return lo + @staticmethod def backward(ctx, dx): g0b, g1b, g0a, g1a = ctx.saved_tensors padding_mode = int_to_mode(ctx.mode) - normalize = ctx.normalize b, ch, T = dx.shape dlo, dbp = None, None if ctx.needs_input_grad[0]: dlo = coldfilt(dx, g0a, g0b, padding_mode) dlo = torch.stack([dlo[:,:ch], dlo[:,ch:2*ch]], dim=-1).view(b, ch, T//2) - if normalize: - dlo *= np.sqrt(2) if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: dbp = coldfilt(dx, g1a, g1b, padding_mode) - if normalize: - dbp *= np.sqrt(2) if ctx.needs_input_grad[1]: dbp_r = dbp[:,ch:2*ch] if ctx.needs_input_grad[2]: diff --git a/tests/test_grad.py b/tests/test_grad.py index c6af760..9af56a5 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -1,8 +1,6 @@ import pytest import torch from torch.autograd import gradcheck -import numpy as np -import dtcwt import murenn import murenn.dtcwt.transform_funcs as tf from contextlib import contextmanager @@ -36,15 +34,14 @@ def test_fwd_j1(skip_hps): gradcheck(tf.FWD_J1.apply, input, eps=eps, atol=atol) -@pytest.mark.parametrize("normalize", [True, False]) @pytest.mark.parametrize("skip_hps", [[0, 1], [1, 0]]) -def test_fwd_j2(skip_hps, normalize): +def test_fwd_j2(skip_hps): J = 2 eps = 1e-3 atol = 1e-4 with set_double_precision(): x = torch.randn(2, 2, 4, device=dev, requires_grad=True) - fwd = murenn.DTCWTDirect(J=J, skip_hps=skip_hps, normalize=normalize).to(dev) + fwd = murenn.DTCWTDirect(J=J, skip_hps=skip_hps).to(dev) input = ( x, fwd.h0a, @@ -53,7 +50,6 @@ def test_fwd_j2(skip_hps, normalize): fwd.h1b, fwd.skip_hps[1], fwd.padding_mode, - fwd.normalize, ) gradcheck(tf.FWD_J2PLUS.apply, input, eps=eps, atol=atol) @@ -71,8 +67,7 @@ def test_inv_j1(): gradcheck(tf.INV_J1.apply, input, eps=eps, atol=atol) -@pytest.mark.parametrize("normalize", [True, False]) -def test_inv_j2(normalize): +def test_inv_j2(): J = 2 eps = 1e-3 atol = 1e-4 @@ -80,7 +75,7 @@ def test_inv_j2(normalize): lo = torch.randn(2, 2, 8, device=dev, requires_grad=True) bp_r = torch.randn(2, 2, 4, device=dev, requires_grad=True) bp_i = torch.randn(2, 2, 4, device=dev, requires_grad=True) - inv = murenn.DTCWTInverse(J=J, normalize=normalize).to(dev) + inv = murenn.DTCWTInverse(J=J).to(dev) input = ( lo, @@ -91,7 +86,6 @@ def test_inv_j2(normalize): inv.g0b, inv.g1b, inv.padding_mode, - inv.normalize, ) gradcheck(tf.INV_J2PLUS.apply, input, eps=eps, atol=atol) From 8af592e8de90ac7d0413057b7d448ff777e64571 Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 11 Jun 2024 19:36:14 +0200 Subject: [PATCH 06/18] rename --- .../{plot_conv1d_spectrum.py => plot_frequency_response.py} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename docs/examples/{plot_conv1d_spectrum.py => plot_frequency_response.py} (86%) diff --git a/docs/examples/plot_conv1d_spectrum.py b/docs/examples/plot_frequency_response.py similarity index 86% rename from docs/examples/plot_conv1d_spectrum.py rename to docs/examples/plot_frequency_response.py index dc0e590..9542594 100644 --- a/docs/examples/plot_conv1d_spectrum.py +++ b/docs/examples/plot_frequency_response.py @@ -1,10 +1,10 @@ # coding: utf-8 """ ========================= -Visualize Amplitude Spectrums +Frequency Magnitude Response ========================= -This notebook demonstrates how to visualize the amplitude spectrums +This notebook demonstrates how to visualize the frequency response of a 1-D convolutional layer (Conv1D). """ @@ -49,5 +49,5 @@ plt.grid(linestyle='--', alpha=0.5) plt.xlim(0, N//2) plt.xlabel("Frequency") -plt.title(f'murenn v{murenn.__version__}. Amplitude Spectrums') +plt.title(f'murenn v{murenn.__version__}. Frequency Magnitude Response') plt.show() \ No newline at end of file From cb13c0a8aa8bbdc3044116a61ff55a507334da84 Mon Sep 17 00:00:00 2001 From: Xiran Date: Thu, 13 Jun 2024 13:37:40 +0200 Subject: [PATCH 07/18] Update nn.py --- murenn/dtcwt/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index b1dd254..57bc3a6 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -30,6 +30,7 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"): J=J-j, padding_mode=padding_mode, skip_hps=True, + normalize=False, ) down.append(down_j) From eb0003d9d78df71582acf16a28469622a49fd037 Mon Sep 17 00:00:00 2001 From: Xiran Date: Mon, 17 Jun 2024 18:45:22 +0200 Subject: [PATCH 08/18] Update nn.py --- murenn/dtcwt/nn.py | 43 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 57bc3a6..c6bf687 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -75,10 +75,30 @@ def forward(self, x): @property def to_conv1d(self): """ - Get the per scale, per filter, per input channel impulse responses. + Compute the single-resolution equivalent impulse response of the MuReNN layer. + This would be helpful for visualization in Fourier domain, for receptive fields, + and for comparing computational costs. + DTCWT conv1d IDTCWT + δ -------> ψ_j --------> w_jq -------> y_jq ------- Return: - y (PyTorch tensor): A complex-valued tensor of shape `(B, in_channels*J*Q, (2**J)*Q)` + conv1d (torch.nn.Conv1d): A Pytorch Conv1d instance with weights initialized to y_jq. + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> tfm = murenn.MuReNNDirect(J=8, Q=5, T=32, in_channels=1) + >>> conv1d = tfm.to_conv1d + >>> x = torch.zeros(1,1,2**10) + >>> x[0,0,N//2]=1 + >>> x = x*(1+0j) + >>> w = conv1d(x*(1+0j)).reshape(J,Q,-1).detach() + >>> colors = [ + >>> 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', + >>> 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] + >>> for j in range(J): + >>> for q in range(Q): + >>> plt.semilogx(torch.abs(torch.fft.fft(w[j,q,:])), color=colors[j]) + >>> plt.xlim(0, N//2) """ # T the filter length T = self.conv1d[0].kernel_size[0] @@ -100,7 +120,8 @@ def to_conv1d(self): # Get DTCWT impulse reponses phi, psis = self.dtcwt(x) # Set phi to a zero valued tensor - zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) + # zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) + zeros_phi = phi.new_zeros(size=(self.C, self.Q, phi.shape[-1])) # Create an empty list for {w_jq} ws = [] for j in range(J): @@ -108,16 +129,26 @@ def to_conv1d(self): Wpsi_jr = self.conv1d[j](psis[j].real) # W_ji = Im[psi_j] * w_jq Wpsi_ji = self.conv1d[j](psis[j].imag) - # Set the coefficients besides this scale to zero .repeat(ch, 1, 1) + # Set the coefficients besides this scale to zero Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] # Get the impulse response w_jr = inv(zeros_phi, Wpsis_jr) w_ji = inv(zeros_phi, Wpsis_ji) w_j = torch.complex(w_jr, w_ji) + # We only need data form one channel + w_j = w_j.reshape(self.C, self.Q, 1, N)[0,...] ws.append(w_j) - ws = torch.cat(ws, dim=1) - return ws + ws = torch.cat(ws, dim=0) # this tensor has a shape of J*Q, 1, N, + conv1d = torch.nn.Conv1d( + in_channels=1, + out_channels=J*self.Q, + kernel_size=N, + bias=False, + padding="same", + ) + conv1d.weight.data = torch.nn.parameter.Parameter(ws) + return conv1d class ModulusStable(torch.autograd.Function): From fafe691367748b377937c09d4ccce3c3659b4868 Mon Sep 17 00:00:00 2001 From: Xiran Date: Mon, 17 Jun 2024 18:53:10 +0200 Subject: [PATCH 09/18] update test --- docs/examples/plot_frequency_response.py | 53 ------------------------ tests/test_nn.py | 5 ++- 2 files changed, 3 insertions(+), 55 deletions(-) delete mode 100644 docs/examples/plot_frequency_response.py diff --git a/docs/examples/plot_frequency_response.py b/docs/examples/plot_frequency_response.py deleted file mode 100644 index 9542594..0000000 --- a/docs/examples/plot_frequency_response.py +++ /dev/null @@ -1,53 +0,0 @@ -# coding: utf-8 -""" -========================= -Frequency Magnitude Response -========================= - -This notebook demonstrates how to visualize the frequency response -of a 1-D convolutional layer (Conv1D). -""" - -######################### -# Standard imports -import torch -from matplotlib import pyplot as plt -import numpy as np - -import murenn -######################################### -# number of scales -J = 8 - -# number of conv1d filters per scale -Q = 10 - -# kernel size -T = 10 - -# input channel -C = 1 - -# output signal length -N = 2 ** J * T - -# MuReNN -tfm = murenn.MuReNNDirect(J=J, Q=Q, T=T, in_channels=C) -w = tfm.to_conv1d.detach() -w = w.view(1, J, Q, -1) - -######################################################################### -# Plot the spectrums per scale per filter. -colors = [ - 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', - 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] -plt.figure(figsize=(10, 3)) -for q in range(Q): - for j in range(J): - w_hat = torch.fft.fft(w[0, j, q, :]) - plt.semilogx(torch.abs(w_hat), color=colors[j]) -plt.grid(linestyle='--', alpha=0.5) -plt.xlim(0, N//2) -plt.xlabel("Frequency") -plt.title(f'murenn v{murenn.__version__}. Frequency Magnitude Response') -plt.show() \ No newline at end of file diff --git a/tests/test_nn.py b/tests/test_nn.py index e308d6c..21ec6a7 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -99,5 +99,6 @@ def test_toconv1d_shape(Q, T, C): T=T, in_channels=C, ) - y = tfm.to_conv1d - assert y.shape == (1, Q*J*C, 2**J*T) \ No newline at end of file + conv1d = tfm.to_conv1d + assert isinstance(conv1d, torch.nn.Conv1d) + assert conv1d.weight.data.shape == (J*Q, 1, 2**J*T) \ No newline at end of file From b5e49e680aa3f7299ecd127463bd7eb67fc165b7 Mon Sep 17 00:00:00 2001 From: Xiran Date: Thu, 20 Jun 2024 23:47:07 +0200 Subject: [PATCH 10/18] Update nn.py --- murenn/dtcwt/nn.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index c6bf687..51214da 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -23,6 +23,7 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"): self.dtcwt = murenn.DTCWT( J=J, padding_mode=padding_mode, + normalize=False, ) for j in range(J): @@ -91,7 +92,7 @@ def to_conv1d(self): >>> x = torch.zeros(1,1,2**10) >>> x[0,0,N//2]=1 >>> x = x*(1+0j) - >>> w = conv1d(x*(1+0j)).reshape(J,Q,-1).detach() + >>> w = conv1d(x).reshape(J,Q,-1).detach() >>> colors = [ >>> 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', >>> 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] @@ -104,7 +105,7 @@ def to_conv1d(self): T = self.conv1d[0].kernel_size[0] # J the number of levels of decompostion J = self.dtcwt.J - # Generate the impulse signal, this signal is zero padded to a length of (2**J)*T + # Generate the impulse signal N = 2**J * T x = torch.zeros(1, self.C, N) x[:, :, N//2] = 1 @@ -116,12 +117,12 @@ def to_conv1d(self): inv = murenn.IDTCWT( J=J, padding_mode=padding_mode, + normalize=False, ) # Get DTCWT impulse reponses phi, psis = self.dtcwt(x) # Set phi to a zero valued tensor - # zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) - zeros_phi = phi.new_zeros(size=(self.C, self.Q, phi.shape[-1])) + zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) # Create an empty list for {w_jq} ws = [] for j in range(J): From fe0dda4aaa913665508b83c7a6020d87c4f64fe5 Mon Sep 17 00:00:00 2001 From: Xiran Date: Sat, 22 Jun 2024 12:28:20 +0200 Subject: [PATCH 11/18] Update test_dtcwt.py --- tests/test_dtcwt.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_dtcwt.py b/tests/test_dtcwt.py index 873b5d3..78f6d15 100644 --- a/tests/test_dtcwt.py +++ b/tests/test_dtcwt.py @@ -39,7 +39,7 @@ def test_fwd_same(J): @pytest.mark.parametrize("normalize", [True, False]) @pytest.mark.parametrize("J", list(range(1, 5))) @pytest.mark.parametrize("T", [44099, 44100]) -def test_inv(level1, qshift, J, T, alternate_gh, normalize): +def test_pr(level1, qshift, J, T, alternate_gh, normalize): Xt = torch.randn(2, 2, T) xfm_murenn = murenn.DTCWTDirect( J=J, @@ -75,3 +75,17 @@ def test_skip_hps(skip_hps, include_scale): inv = murenn.DTCWTInverse(J=J, skip_hps=skip_hps, include_scale=include_scale) X_rec = inv(lp, bp) assert X_rec.shape == Xt.shape + +def test_inv(): + T = 2**10 + Xt = torch.randn(2, 2, T) + dtcwt = murenn.DTCWTDirect() + idtcwt = murenn.DTCWTInverse() + lp, bp = dtcwt(Xt) + lp = lp.new_zeros(lp.shape) + X_rec = idtcwt(lp, bp) + bp_r = [(bp[j].real)*(1+0j) for j in range(8)] + bp_i = [(bp[j].imag)*(0+1j) for j in range(8)] + X_rec_r = idtcwt(lp, bp_r) + X_rec_i = idtcwt(lp, bp_i) + assert torch.allclose((X_rec_r+X_rec_i), X_rec, atol=1e-3) \ No newline at end of file From d7b5744c954459219a7d1cb7ec99e418ebb6bfba Mon Sep 17 00:00:00 2001 From: Xiran Date: Sat, 22 Jun 2024 12:59:54 +0200 Subject: [PATCH 12/18] Update nn.py --- murenn/dtcwt/nn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 51214da..dcdf627 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -87,11 +87,14 @@ def to_conv1d(self): Examples -------- >>> import matplotlib.pyplot as plt + >>> J = 8 + >>> Q = 5 + >>> N = 2**10 >>> tfm = murenn.MuReNNDirect(J=8, Q=5, T=32, in_channels=1) >>> conv1d = tfm.to_conv1d - >>> x = torch.zeros(1,1,2**10) + >>> x = torch.zeros(1,1,N) >>> x[0,0,N//2]=1 - >>> x = x*(1+0j) + >>> x = x*(1-1j) >>> w = conv1d(x).reshape(J,Q,-1).detach() >>> colors = [ >>> 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', From dfeb404e5398f992a3783797599167adfc6369b5 Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 09:47:31 +0200 Subject: [PATCH 13/18] Squashed commit of the following: commit bb7241cd6e9668bd545bd4218ef60efb3a225746 Author: Xiran Date: Tue Jul 23 09:37:35 2024 +0200 enable gpu commit ef89866b7613d8669190e014105245164be04f0f Author: Xiran Date: Wed Jun 26 23:19:54 2024 +0200 FEAT nb of conv1d filters per scale commit d824ad9e6118f3d19b62212b69abbac5a6187b66 Author: Xiran Date: Sun Jun 23 23:11:50 2024 +0200 complex conv1d-->real conv1d --- murenn/dtcwt/nn.py | 116 ++++++++++++++++-------------------- murenn/dtcwt/transform1d.py | 5 +- tests/test_dtcwt.py | 4 +- tests/test_nn.py | 8 +-- 4 files changed, 61 insertions(+), 72 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index dcdf627..3623adb 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -9,26 +9,39 @@ class MuReNNDirect(torch.nn.Module): """ Args: J (int): Number of levels (octaves) in the DTCWT decomposition. - Q (int): Number of Conv1D filters per octave. + Q (int, dict or list): Number of Conv1D filters per octave. + T (int): Conv1D Kernel size multiplier + J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J. in_channels (int): Number of channels in the input signal. padding_mode (str): One of 'symmetric' (default), 'zeros', 'replicate', and 'circular'. Padding scheme for the DTCWT decomposition. """ - def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"): + def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric"): super().__init__() - self.Q = Q - self.C = in_channels + if isinstance(Q, int): + self.Q = [Q for j in range(J)] + elif isinstance(Q, (dict, list)): + assert len(Q) == J + self.Q = Q + else: + raise TypeError(f"Q must to be int, dict or list, got {type(Q)}") + if J_phi is None: + self.J_phi = J + if J_phi < J: + raise ValueError("J_phi must be greater or equal to J") + self.T = [T*self.Q[j] for j in range(J)] + self.in_channels = in_channels down = [] conv1d = [] self.dtcwt = murenn.DTCWT( J=J, padding_mode=padding_mode, - normalize=False, + normalize=True, ) for j in range(J): down_j = murenn.DTCWT( - J=J-j, + J=J_phi-j, padding_mode=padding_mode, skip_hps=True, normalize=False, @@ -37,8 +50,8 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"): conv1d_j = torch.nn.Conv1d( in_channels=in_channels, - out_channels=Q*in_channels, - kernel_size=T, + out_channels=self.Q[j]*in_channels, + kernel_size=self.T[j], bias=False, groups=in_channels, padding="same", @@ -59,19 +72,22 @@ def forward(self, x): Returns: y (PyTorch tensor): A tensor of shape `(B, in_channels, Q, J, T_out)` """ - assert self.C == x.shape[1] + assert self.in_channels == x.shape[1] lp, bps = self.dtcwt(x) - output = [] + UWx = [] for j in range(self.dtcwt.J): Wx_j_r = self.conv1d[j](bps[j].real) Wx_j_i = self.conv1d[j](bps[j].imag) - UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i) + # UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i) + UWx_j = Wx_j_r**2 + Wx_j_i**2 + # Avarange over time UWx_j, _ = self.down[j](UWx_j) B, _, N = UWx_j.shape - # reshape from (B, C*Q, N) to (B, C, Q, N) - UWx_j = UWx_j.view(B, self.C, self.Q, N) - output.append(UWx_j) - return torch.stack(output, dim=3) + UWx_j = UWx_j.view(B, self.in_channels, self.Q[j], N) + UWx.append(UWx_j) + # Frequency range from high to low + UWx = torch.cat(UWx, dim=2) + return UWx @property def to_conv1d(self): @@ -84,69 +100,41 @@ def to_conv1d(self): ------- Return: conv1d (torch.nn.Conv1d): A Pytorch Conv1d instance with weights initialized to y_jq. - Examples - -------- - >>> import matplotlib.pyplot as plt - >>> J = 8 - >>> Q = 5 - >>> N = 2**10 - >>> tfm = murenn.MuReNNDirect(J=8, Q=5, T=32, in_channels=1) - >>> conv1d = tfm.to_conv1d - >>> x = torch.zeros(1,1,N) - >>> x[0,0,N//2]=1 - >>> x = x*(1-1j) - >>> w = conv1d(x).reshape(J,Q,-1).detach() - >>> colors = [ - >>> 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', - >>> 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] - >>> for j in range(J): - >>> for q in range(Q): - >>> plt.semilogx(torch.abs(torch.fft.fft(w[j,q,:])), color=colors[j]) - >>> plt.xlim(0, N//2) """ + + device = self.conv1d[0].weight.data.device # T the filter length - T = self.conv1d[0].kernel_size[0] + # T = self.conv1d[0].kernel_size[0] + T = max(self.T) # J the number of levels of decompostion J = self.dtcwt.J # Generate the impulse signal - N = 2**J * T - x = torch.zeros(1, self.C, N) + N = 2 ** J * T + x = torch.zeros(1, self.in_channels, N).to(device) x[:, :, N//2] = 1 - # Get the padding mode - padding_mode = self.dtcwt.padding_mode - if padding_mode == "constant": - padding_mode = "zeros" - inv = murenn.IDTCWT( - J=J, - padding_mode=padding_mode, - normalize=False, - ) + J = J, + normalize=True, + ).to(device) # Get DTCWT impulse reponses phi, psis = self.dtcwt(x) # Set phi to a zero valued tensor - zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1])) - # Create an empty list for {w_jq} + zeros_phi = phi.new_zeros(1,1,phi.shape[-1]) ws = [] for j in range(J): - # Wpsi_jr = Re[psi_j] * w_jq - Wpsi_jr = self.conv1d[j](psis[j].real) - # W_ji = Im[psi_j] * w_jq - Wpsi_ji = self.conv1d[j](psis[j].imag) - # Set the coefficients besides this scale to zero - Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] - Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] - # Get the impulse response - w_jr = inv(zeros_phi, Wpsis_jr) - w_ji = inv(zeros_phi, Wpsis_ji) - w_j = torch.complex(w_jr, w_ji) - # We only need data form one channel - w_j = w_j.reshape(self.C, self.Q, 1, N)[0,...] - ws.append(w_j) - ws = torch.cat(ws, dim=0) # this tensor has a shape of J*Q, 1, N, + Wpsi_jr = self.conv1d[j](psis[j].real).reshape(self.in_channels, self.Q[j], -1) + Wpsi_ji = self.conv1d[j](psis[j].imag).reshape(self.in_channels, self.Q[j], -1) + for q in range(self.Q[j]): + Wpsi_jqr = Wpsi_jr[0, q, :].reshape(1,1,-1) + Wpsi_jqi = Wpsi_ji[0, q, :].reshape(1,1,-1) + # Wpsis_j = [torch.complex(Wpsi_jr, Wpsi_ji) if k == j else psis[k].new_zeros(psis[k].shape).repeat(1, self.Q[k], 1) for k in range(J)] + Wpsis_jq = [torch.complex(Wpsi_jqr, Wpsi_jqi) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)] + w_jq = inv(zeros_phi, Wpsis_jq) + ws.append(w_jq) + ws = torch.cat(ws, dim=0) conv1d = torch.nn.Conv1d( in_channels=1, - out_channels=J*self.Q, + out_channels=ws.shape[0], kernel_size=N, bias=False, padding="same", diff --git a/murenn/dtcwt/transform1d.py b/murenn/dtcwt/transform1d.py index 90c1f23..54774e5 100644 --- a/murenn/dtcwt/transform1d.py +++ b/murenn/dtcwt/transform1d.py @@ -297,7 +297,10 @@ def forward(self, yl, yh): if self.normalize: x_phi = np.sqrt(2) * x_phi - ## LEVEL 1 ## + # LEVEL 1 ## + if x_phi.shape[-1] != x_psis[0].shape[-1] * 2: + x_phi = x_phi[:,:,1:-1] + x_psi_r, x_psi_i = x_psis[0].real, x_psis[0].imag x_phi = INV_J1.apply( diff --git a/tests/test_dtcwt.py b/tests/test_dtcwt.py index 78f6d15..7629c9b 100644 --- a/tests/test_dtcwt.py +++ b/tests/test_dtcwt.py @@ -84,8 +84,8 @@ def test_inv(): lp, bp = dtcwt(Xt) lp = lp.new_zeros(lp.shape) X_rec = idtcwt(lp, bp) - bp_r = [(bp[j].real)*(1+0j) for j in range(8)] - bp_i = [(bp[j].imag)*(0+1j) for j in range(8)] + bp_r = [(bp[j].real)*(1+0j) for j in range(dtcwt.J)] + bp_i = [(bp[j].imag)*(0+1j) for j in range(dtcwt.J)] X_rec_r = idtcwt(lp, bp_r) X_rec_i = idtcwt(lp, bp_i) assert torch.allclose((X_rec_r+X_rec_i), X_rec, atol=1e-3) \ No newline at end of file diff --git a/tests/test_nn.py b/tests/test_nn.py index 21ec6a7..ce1d396 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -90,15 +90,13 @@ def test_modulus(): @pytest.mark.parametrize("Q", [1, 2]) @pytest.mark.parametrize("T", [1, 2]) -@pytest.mark.parametrize("C", [1, 2]) -def test_toconv1d_shape(Q, T, C): +def test_toconv1d_shape(Q, T): J = 4 tfm = murenn.MuReNNDirect( J=J, Q=Q, T=T, - in_channels=C, + in_channels=2, ) conv1d = tfm.to_conv1d - assert isinstance(conv1d, torch.nn.Conv1d) - assert conv1d.weight.data.shape == (J*Q, 1, 2**J*T) \ No newline at end of file + assert isinstance(conv1d, torch.nn.Conv1d) \ No newline at end of file From fdf681102a8be91285244cd1cce2452b051f0b3d Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 14:16:25 +0200 Subject: [PATCH 14/18] fix bug --- murenn/dtcwt/nn.py | 24 +++++++++++------------- tests/test_nn.py | 2 +- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 3623adb..5805eb5 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -26,7 +26,7 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" else: raise TypeError(f"Q must to be int, dict or list, got {type(Q)}") if J_phi is None: - self.J_phi = J + J_phi = J if J_phi < J: raise ValueError("J_phi must be greater or equal to J") self.T = [T*self.Q[j] for j in range(J)] @@ -36,7 +36,7 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" self.dtcwt = murenn.DTCWT( J=J, padding_mode=padding_mode, - normalize=True, + alternate_gh=False, ) for j in range(J): @@ -44,7 +44,7 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" J=J_phi-j, padding_mode=padding_mode, skip_hps=True, - normalize=False, + alternate_gh=False, ) down.append(down_j) @@ -78,14 +78,12 @@ def forward(self, x): for j in range(self.dtcwt.J): Wx_j_r = self.conv1d[j](bps[j].real) Wx_j_i = self.conv1d[j](bps[j].imag) - # UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i) - UWx_j = Wx_j_r**2 + Wx_j_i**2 + UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i) # Avarange over time UWx_j, _ = self.down[j](UWx_j) B, _, N = UWx_j.shape UWx_j = UWx_j.view(B, self.in_channels, self.Q[j], N) UWx.append(UWx_j) - # Frequency range from high to low UWx = torch.cat(UWx, dim=2) return UWx @@ -104,7 +102,6 @@ def to_conv1d(self): device = self.conv1d[0].weight.data.device # T the filter length - # T = self.conv1d[0].kernel_size[0] T = max(self.T) # J the number of levels of decompostion J = self.dtcwt.J @@ -114,12 +111,12 @@ def to_conv1d(self): x[:, :, N//2] = 1 inv = murenn.IDTCWT( J = J, - normalize=True, + alternate_gh=False ).to(device) # Get DTCWT impulse reponses phi, psis = self.dtcwt(x) # Set phi to a zero valued tensor - zeros_phi = phi.new_zeros(1,1,phi.shape[-1]) + zeros_phi = phi.new_zeros(1, 1, phi.shape[-1]) ws = [] for j in range(J): Wpsi_jr = self.conv1d[j](psis[j].real).reshape(self.in_channels, self.Q[j], -1) @@ -127,10 +124,11 @@ def to_conv1d(self): for q in range(self.Q[j]): Wpsi_jqr = Wpsi_jr[0, q, :].reshape(1,1,-1) Wpsi_jqi = Wpsi_ji[0, q, :].reshape(1,1,-1) - # Wpsis_j = [torch.complex(Wpsi_jr, Wpsi_ji) if k == j else psis[k].new_zeros(psis[k].shape).repeat(1, self.Q[k], 1) for k in range(J)] - Wpsis_jq = [torch.complex(Wpsi_jqr, Wpsi_jqi) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)] - w_jq = inv(zeros_phi, Wpsis_jq) - ws.append(w_jq) + Wpsis_r = [Wpsi_jqr * (1+0j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)] + Wpsis_i = [Wpsi_jqi * (0+1j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)] + w_r = inv(zeros_phi, Wpsis_r) + w_i = inv(zeros_phi, Wpsis_i) + ws.append(torch.complex(w_r, w_i)) ws = torch.cat(ws, dim=0) conv1d = torch.nn.Conv1d( in_channels=1, diff --git a/tests/test_nn.py b/tests/test_nn.py index ce1d396..e2a24b3 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -26,7 +26,7 @@ def test_direct_shape(J, Q, T, N, padding_mode): padding_mode=padding_mode, ) y = graph(x) - assert y.shape[:4] == (B, C, Q, J) + assert y.shape[:3] == (B, C, Q*J) def test_direct_diff(): From 9e176c5b00a092f3ffe506408502b415ec07c30b Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 17:50:23 +0200 Subject: [PATCH 15/18] remove property --- murenn/dtcwt/nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 5805eb5..2c1caa7 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -87,7 +87,6 @@ def forward(self, x): UWx = torch.cat(UWx, dim=2) return UWx - @property def to_conv1d(self): """ Compute the single-resolution equivalent impulse response of the MuReNN layer. From 567d5ae626dea75af5d777b8f96fc80c7d7f158e Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 17:55:18 +0200 Subject: [PATCH 16/18] Update test_nn.py --- tests/test_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_nn.py b/tests/test_nn.py index e2a24b3..4d01e11 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -98,5 +98,5 @@ def test_toconv1d_shape(Q, T): T=T, in_channels=2, ) - conv1d = tfm.to_conv1d + conv1d = tfm.to_conv1d() assert isinstance(conv1d, torch.nn.Conv1d) \ No newline at end of file From 001ad4b538f74ad61444d11a02d5fb031cefb23d Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 17:58:55 +0200 Subject: [PATCH 17/18] Update nn.py --- murenn/dtcwt/nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 2c1caa7..89fc9de 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -10,7 +10,8 @@ class MuReNNDirect(torch.nn.Module): Args: J (int): Number of levels (octaves) in the DTCWT decomposition. Q (int, dict or list): Number of Conv1D filters per octave. - T (int): Conv1D Kernel size multiplier + T (int): Conv1D Kernel size multiplier. The Conv1d kernel size at scale j is equal to + T * Q[j] where Q[j] is the number of filters. J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J. in_channels (int): Number of channels in the input signal. padding_mode (str): One of 'symmetric' (default), 'zeros', 'replicate', From 3b002fd684179bff2391975c8a935d6e9a6d23ee Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 18:09:44 +0200 Subject: [PATCH 18/18] Update nn.py --- murenn/dtcwt/nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 89fc9de..58b4154 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -9,7 +9,7 @@ class MuReNNDirect(torch.nn.Module): """ Args: J (int): Number of levels (octaves) in the DTCWT decomposition. - Q (int, dict or list): Number of Conv1D filters per octave. + Q (int or list): Number of Conv1D filters per octave. T (int): Conv1D Kernel size multiplier. The Conv1d kernel size at scale j is equal to T * Q[j] where Q[j] is the number of filters. J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J. @@ -21,11 +21,11 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" super().__init__() if isinstance(Q, int): self.Q = [Q for j in range(J)] - elif isinstance(Q, (dict, list)): + elif isinstance(Q, list): assert len(Q) == J self.Q = Q else: - raise TypeError(f"Q must to be int, dict or list, got {type(Q)}") + raise TypeError(f"Q must to be int or list, got {type(Q)}") if J_phi is None: J_phi = J if J_phi < J: