FIX nn ModuleList and ParameterList #44

Merged · 23 commits · Apr 18, 2024
16 changes: 9 additions & 7 deletions murenn/dtcwt/nn.py
@@ -15,31 +15,33 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"):
super().__init__()
self.Q = Q
self.C = in_channels
- self.down = []
- self.conv1d = []
+ down = []
+ conv1d = []
self.dtcwt = murenn.DTCWT(
J=J,
padding_mode=padding_mode,
)

for j in range(J):
- down = murenn.DTCWT(
+ down_j = murenn.DTCWT(
J=J-j,
padding_mode=padding_mode,
skip_hps=True,
)
- self.down.append(down)
+ down.append(down_j)

- conv1d = torch.nn.Conv1d(
+ conv1d_j = torch.nn.Conv1d(
in_channels=in_channels,
out_channels=Q*in_channels,
kernel_size=T,
bias=False,
groups=in_channels,
padding="same",
)
- torch.nn.init.normal_(conv1d.weight)
- self.conv1d.append(conv1d)
+ torch.nn.init.normal_(conv1d_j.weight)
+ conv1d.append(conv1d_j)
+ self.down = torch.nn.ModuleList(down)
+ self.conv1d = torch.nn.ParameterList(conv1d)


def forward(self, x):
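The substance of this fix: submodules kept in a plain Python list are never registered with the parent `torch.nn.Module`, so their weights are invisible to `parameters()` and `state_dict()`, and hence to the optimizer and to `.to(device)`. A minimal sketch (not part of the PR) of the failure mode and the fix:

```python
# Minimal sketch (not part of this PR): parameters of modules held in a
# plain Python list are not registered; torch.nn.ModuleList registers them.
import torch

class Broken(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list: the Conv1d weights never reach parameters()/state_dict().
        self.convs = [torch.nn.Conv1d(1, 1, 3) for _ in range(2)]

class Fixed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList: each Conv1d is registered as a proper submodule.
        self.convs = torch.nn.ModuleList(
            [torch.nn.Conv1d(1, 1, 3) for _ in range(2)]
        )

print(len(list(Broken().parameters())))  # 0 -- an optimizer would train nothing
print(len(list(Fixed().parameters())))   # 4 -- two weights and two biases
```

Note that `torch.nn.ModuleList` is the conventional container for submodules, while `torch.nn.ParameterList` is aimed at raw `torch.nn.Parameter` objects; the PR stores the `Conv1d` blocks in a `ParameterList`, which may be what the reviewer's closing comment is getting at.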
28 changes: 17 additions & 11 deletions murenn/dtcwt/transform1d.py
@@ -113,26 +113,25 @@ def forward(self, x):
x (PyTorch tensor): Input data. Should be a tensor of shape
`(B, C, T)` where B is the batch size, C is the number of
channels and T is the number of time samples.
- Note that T must be a multiple of 2**J, where J is the number
- of wavelet scales (see documentation of DTCWTDirect constructor).
+ Note that T must be of even length.

Returns:
yl: low-pass coefficients. If include_scale is True (see DTCWTDirect
constructor), yl is a list of low-pass coefficients at all wavelet
scales 1 to (J-1). Otherwise (default), yl is a real-valued PyTorch
- tensor of shape `(B, C, T/2**(J-1))`.
+ tensor.
yh: band-pass coefficients. A list of PyTorch tensors with J elements,
containing the band-pass coefficients at all wavelet scales 1 to
- (J-1). These tensors are complex-valued and have shapes:
- `(B, C, T)`, `(B, C, T/2)`, `(B, C, T/4)`, etc."""
+ (J-1). These tensors are complex-valued."""

# Initialize lists of empty arrays with same dtype as input
x_phis = []
x_psis = []

- # Assert that the length of x is a multiple of 2**J
+ # Extend if the length of x is not even
T = x.shape[-1]
- assert T % (2**self.J) == 0
+ if T % 2 != 0:
+ x = torch.cat((x, x[:,:,-1:]), dim=-1)

## LEVEL 1 ##
x_phi, x_psi_r, x_psi_i = FWD_J1.apply(
@@ -153,6 +152,10 @@ def forward(self, x):
else:
h0a, h1a, h0b, h1b = self.h0a, self.h1a, self.h0b, self.h1b

+ # Ensure the lowpass length is divisible by 4
+ if x_phi.shape[-1] % 4 != 0:
+ x_phi = torch.cat((x_phi[:,:,0:1], x_phi, x_phi[:,:,-1:]), dim=-1)

x_phi, x_psi_r, x_psi_i = FWD_J2PLUS.apply(
x_phi,
h0a,
@@ -245,11 +248,10 @@ def forward(self, yl, yh):
yl: low-pass coefficients for the DTCWT reconstruction. If include_scale
is True (see DTCWTInverse constructor), yl should be a list of low-pass
coefficients at all wavelet scales 1 to (J-1). Otherwise (default),
- yl should be a real-valued PyTorch tensor of shape `(B, C, T/2**(J-1))`.
+ yl should be a real-valued PyTorch tensor.
yh: band-pass coefficients for the DTCWT reconstruction. A list of PyTorch
tensors with J elements, containing the band-pass coefficients at all
- wavelet scales 1 to (J-1). These tensors are complex-valued and must
- have shapes: `(B, C, T)`, `(B, C, T/2)`, `(B, C, T/4)`, etc.
+ wavelet scales 1 to (J-1). These tensors are complex-valued.
"""

# x_phi the low-pass, x_psis the band-pass
@@ -267,9 +269,13 @@ def forward(self, yl, yh):
# The band-pass coefficients at level j
# Check the length of the band-pass, low-pass input coefficients
x_psi = x_psis[j]

+ if x_phi.shape[-1] != x_psi.shape[-1] * 2:
+ x_phi = x_phi[:,:,1:-1]
assert (
x_psi.shape[-1] * 2 == x_phi.shape[-1]
), f"J={j}\n{x_psi.shape[-1]*2}\n{x_phi.shape[-1]}"

if (j % 2 == 1) and self.alternate_gh:
x_psi = torch.conj(x_psi)
g0a, g1a, g0b, g1b = self.h0a, self.h1a, self.h0b, self.h1b
@@ -291,8 +297,8 @@

## LEVEL 1 ##
x_psi_r, x_psi_i = x_psis[0].real, x_psis[0].imag

x_phi = INV_J1.apply(
x_phi, x_psi_r, x_psi_i, self.g0o, self.g1o, self.padding_mode
)

return x_phi
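Taken together, the forward and inverse changes replace the old power-of-two length requirement with boundary extension: the forward pass replicates edge samples so every level sees an even (and, from level 2 on, divisible-by-4) length, and the inverse trims those samples back off. A self-contained sketch of the length bookkeeping, using the same slicing idiom as the diff (illustration only, not the library's code path):

```python
# Standalone sketch of the length handling above (illustration only).
import torch

# Forward: replicate the last sample so the input length is even.
x = torch.randn(1, 1, 44099)                  # odd-length input
if x.shape[-1] % 2 != 0:
    x = torch.cat((x, x[:, :, -1:]), dim=-1)  # 44099 -> 44100

# Forward, levels >= 2: extend the lowpass so its length is divisible by 4.
x_phi = torch.randn(1, 1, 22050)              # lowpass after one decimation; 22050 % 4 == 2
if x_phi.shape[-1] % 4 != 0:
    x_phi = torch.cat((x_phi[:, :, 0:1], x_phi, x_phi[:, :, -1:]), dim=-1)  # -> 22052

# Inverse: trim the replicated edge samples when lengths disagree.
x_psi = torch.randn(1, 1, 11025, dtype=torch.complex64)  # band-pass at the next level
if x_phi.shape[-1] != x_psi.shape[-1] * 2:
    x_phi = x_phi[:, :, 1:-1]                 # 22052 -> 22050 == 2 * 11025
assert x_phi.shape[-1] == 2 * x_psi.shape[-1]
```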
13 changes: 8 additions & 5 deletions tests/test_dtcwt.py
@@ -8,8 +8,8 @@
@pytest.mark.parametrize("J", list(range(1, 10)))
def test_fwd_same(J):
decimal = 4
- X = np.random.rand(2**J)
- Xt = torch.tensor(X, dtype=torch.get_default_dtype()).view(1, 1, 2**J)
+ X = np.random.rand(44100)
+ Xt = torch.tensor(X, dtype=torch.get_default_dtype()).view(1, 1, 44100)
xfm_murenn = murenn.DTCWTDirect(
J=J,
alternate_gh=False,
@@ -38,8 +38,9 @@ def test_fwd_same(J):
@pytest.mark.parametrize("alternate_gh", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("J", list(range(1, 5)))
- def test_inv(level1, qshift, J, alternate_gh, normalize):
- Xt = torch.randn(2, 2, 2**J)
+ @pytest.mark.parametrize("T", [44099, 44100])
+ def test_inv(level1, qshift, J, T, alternate_gh, normalize):
+ Xt = torch.randn(2, 2, T)
xfm_murenn = murenn.DTCWTDirect(
J=J,
level1=level1,
@@ -60,14 +61,16 @@ def test_inv(level1, qshift, J, alternate_gh, normalize):
normalize=normalize,
)
X_rec = inv(lp, bp)
+ if T % 2 != 0:
+ X_rec = X_rec[:, :, :-1]
torch.testing.assert_close(Xt, X_rec)


@pytest.mark.parametrize("include_scale", [False, [0, 0, 1]])
@pytest.mark.parametrize("skip_hps", [False, [0, 1, 0]])
def test_skip_hps(skip_hps, include_scale):
J = 3
- Xt = torch.randn(2, 2, 2**J)
+ Xt = torch.randn(2, 2, 44100)
xfm_murenn = murenn.DTCWTDirect(J=J, skip_hps=skip_hps, include_scale=include_scale)
lp, bp = xfm_murenn(Xt)
inv = murenn.DTCWTInverse(J=J, skip_hps=skip_hps, include_scale=include_scale)
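These tests now exercise arbitrary signal lengths (44100 and 44099 samples) instead of lengths tied to powers of two. A hedged round-trip usage sketch, assuming only the murenn API already exercised by the tests above:

```python
# Round trip on an odd-length signal (sketch; API as used in the tests above).
import torch
import murenn

J = 4
x = torch.randn(2, 2, 44099)            # no multiple-of-2**J requirement
fwd = murenn.DTCWTDirect(J=J)
inv = murenn.DTCWTInverse(J=J)

lp, bp = fwd(x)
x_rec = inv(lp, bp)
if x.shape[-1] % 2 != 0:                # odd input: drop the replicated sample
    x_rec = x_rec[:, :, :-1]
torch.testing.assert_close(x, x_rec)
```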
29 changes: 24 additions & 5 deletions tests/test_nn.py
@@ -7,9 +7,10 @@
@pytest.mark.parametrize("Q", [3, 4])
@pytest.mark.parametrize("T", [8, 16])
@pytest.mark.parametrize("padding_mode", ["symmetric", "zeros"])
- def test_direct_shape(J, Q, T, padding_mode):
- B, C, N = 2, 3, 2**(J+4)
- x = torch.zeros(B, C, N)
+ @pytest.mark.parametrize("N", list(range(10)))
+ def test_direct_shape(J, Q, T, N, padding_mode):
+ B, C, L = 2, 3, 2**J+N
+ x = torch.zeros(B, C, L)
graph = murenn.MuReNNDirect(
J=J,
Q=Q,
@@ -18,7 +19,7 @@ def test_direct_shape(J, Q, T, padding_mode):
padding_mode=padding_mode,
)
y = graph(x)
- assert y.shape == (B, C, Q, J, 2**4)
+ assert y.shape[:4] == (B, C, Q, J)


def test_direct_diff():
@@ -34,4 +35,22 @@ def test_direct_diff():
y = graph(x)
y.mean().backward()
for conv1d in graph.conv1d:
- assert conv1d.weight.grad != None
+ assert conv1d.weight.grad != None


@pytest.mark.parametrize("J", list(range(1, 3)))
@pytest.mark.parametrize("Q", [3, 4])
@pytest.mark.parametrize("T", [8, 16])
@pytest.mark.parametrize("N", list(range(5)))
def test_multi_layers(J, Q, T, N):
B, C, L = 2, 3, 2**J+N
x = torch.zeros(B, C, L)
for i in range(3):
x = x.view(B, -1, x.shape[-1])
layer_i = murenn.MuReNNDirect(
J=J,
Q=Q,
T=T,
in_channels=x.shape[1],
)
x = layer_i(x)
Review comment (Contributor): assert that the layer is of the right type?
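The new test relies on `MuReNNDirect` returning a tensor of shape `(B, C, Q, J, T')`, which `x.view(B, -1, x.shape[-1])` folds into `(B, C*Q*J, T')` for the next layer. One hedged way to act on the reviewer's suggestion, assuming the same API as the tests above (hypothetical test, not in the PR):

```python
# Sketch of the type/shape assertions the reviewer suggests (hypothetical).
import torch
import murenn

def test_multi_layers_typed(J=2, Q=3, T=8, N=1):
    B, C, L = 2, 3, 2**J + N
    x = torch.zeros(B, C, L)
    for _ in range(3):
        x = x.view(B, -1, x.shape[-1])      # fold (C, Q, J) back into channels
        layer = murenn.MuReNNDirect(J=J, Q=Q, T=T, in_channels=x.shape[1])
        assert isinstance(layer, murenn.MuReNNDirect)  # layer has the right type
        x = layer(x)
        assert isinstance(x, torch.Tensor)
        assert x.shape[:4] == (B, layer.C, Q, J)       # as in test_direct_shape
```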