Skip to content

Commit

Permalink
fix to_conv1d
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n committed Nov 29, 2024
1 parent fc4b95e commit 74d7b17
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
35 changes: 18 additions & 17 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def to_conv1d(self):
Compute the single-resolution equivalent impulse response of the MuReNN layer.
This would be helpful for visualization in Fourier domain, for receptive fields,
and for comparing computational costs.
DTCWT conv1d IDTCWT
δ -------> ψ_j --------> w_jq -------> y_jq
conv1d IDTCWT
δ --------> w_jq -------> y_jq
-------
Return:
conv1ds: A dictionary containing PyTorch Conv1d instances with weights initialized to y_jq.
Expand All @@ -99,32 +99,33 @@ def to_conv1d(self):
device = self.conv1d[0].weight.data.device
T = self.T # Filter length
J = self.dtcwt.J # Number of levels of decomposition
N = 2 ** J * max(T, len(self.dtcwt.g0a)) # Length of the impulse signal
N = 2 ** J * max(T, len(self.dtcwt.g0a)) # Hybrid filter length

# Generate the impulse signal
# Generate a zero signal
x = torch.zeros(1, self.in_channels, N).to(device)
x[:, :, N//2] = 1

# Initialize the inverse DTCWT
inv = murenn.IDTCWT(J=J, alternate_gh=False).to(device)

# Get DTCWT impulse reponses
# Obtain two dual-tree response of the zero signal
phi, psis = self.dtcwt(x)
zeros_phi = phi.new_zeros(1, 1, phi.shape[-1])
phi = phi[0,0,:].reshape(1,1,-1) # We only need the first channel

ws_r, ws_i = [], []
for j in range(J):
Wpsi_jr = self.conv1d[j](psis[j].real).reshape(self.in_channels, self.Q[j], -1)
Wpsi_ji = self.conv1d[j](psis[j].imag).reshape(self.in_channels, self.Q[j], -1)
# Set the level-j response to a impulse signal
psi_j = psis[j].real
psi_j[:, :, psi_j.shape[2]//2] = 1 / math.sqrt(2) ** j # The energy gain
# Convolve the impulse signal with the conv1d filter
Wpsi_j = self.conv1d[j](psi_j).reshape(self.in_channels, self.Q[j], -1)
# Apply dual-tree invert transform to obtain the hybrid wavelets.
for q in range(self.Q[j]):
Wpsi_jqr = Wpsi_jr[0, q, :].reshape(1,1,-1)
Wpsi_jqi = Wpsi_ji[0, q, :].reshape(1,1,-1)
Wpsi_jq = Wpsi_j[0, q, :].reshape(1,1,-1)
Wpsis_r = [Wpsi_jq * (1+0j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]
Wpsis_i = [Wpsi_jq * (0+1j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]

Wpsis_r = [Wpsi_jqr * (1+0j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]
Wpsis_i = [Wpsi_jqi * (0+1j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]

ws_r.append(inv(zeros_phi, Wpsis_r))
ws_i.append(inv(zeros_phi, Wpsis_i))
ws_r.append(inv(phi, Wpsis_r))
ws_i.append(inv(phi, Wpsis_i))

ws_r = torch.cat(ws_r, dim=0)
ws_i = torch.cat(ws_i, dim=0)
Expand All @@ -141,7 +142,7 @@ def create_conv1d(weight):
return conv1d

return {
"complex": create_conv1d(ws_r+ws_i),
"complex": create_conv1d(ws_r+1j*ws_i),
"real": create_conv1d(ws_r),
"imag": create_conv1d(ws_i),
}
Expand Down
25 changes: 22 additions & 3 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import torch
import murenn
import math

from murenn.dtcwt.nn import ModulusStable, Downsampling

Expand Down Expand Up @@ -90,7 +91,7 @@ def test_modulus():
assert torch.max(torch.abs(x0i.grad)) <= 1e-7

@pytest.mark.parametrize("Q", [1, 2])
@pytest.mark.parametrize("T", [1, 2])
@pytest.mark.parametrize("T", [2, 3])
def test_toconv1d(Q, T):
J = 4
tfm = murenn.MuReNNDirect(
Expand All @@ -100,11 +101,29 @@ def test_toconv1d(Q, T):
in_channels=2,
)
N = 2**J*16
x = torch.randn(1, 1, N)
x = torch.zeros(1, 1, N)
x[:,:,N//2] = 1
conv1ds = tfm.to_conv1d()
for conv1d in conv1ds.values():
for conv1d in [conv1ds["real"], conv1ds["imag"]]:
assert isinstance(conv1d, torch.nn.Conv1d)
y = conv1d(x)
assert isinstance(y, torch.Tensor)
assert y.dtype == x.dtype
assert y.shape == (1, J*Q, N)

# Test the energy gain
tfm = murenn.MuReNNDirect(
J=J,
Q=1,
T=1,
in_channels=1,
)
# Initialize the learnable filters with dirac
for conv1d in tfm.conv1d:
torch.nn.init.dirac_(conv1d.weight)
# Get the dtcwt filters
psis = tfm.to_conv1d()
# Test the energy crossing subbands
y = psis["complex"](torch.complex(x, x)/math.sqrt(2))
energy = torch.linalg.norm(y, dim=-1)
assert torch.allclose(energy, torch.ones(1, J), atol=0.1)

0 comments on commit 74d7b17

Please sign in to comment.