From e03dedb7b19ae2f49e703a6f3e526d27a8771fc6 Mon Sep 17 00:00:00 2001 From: Xiran <97965893+xir4n@users.noreply.github.com> Date: Thu, 4 Apr 2024 13:27:31 +0200 Subject: [PATCH] FIX order of operators in MuReNNDirect (#38) Co-authored-by: Vincent Lostanlen --- murenn/dtcwt/nn.py | 7 ++++--- tests/{test_graph.py => test_nn.py} | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) rename tests/{test_graph.py => test_nn.py} (91%) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 211c99e..8e8df26 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -59,9 +59,10 @@ def forward(self, x): ys = [] for j in range(self.dtcwt.J): - x_j = torch.abs(bps[j]) - x_j = self.conv1d[j](x_j) - y_j, _ = self.down[j](x_j) + Wx_r = self.conv1d[j](bps[j].real) + Wx_i = self.conv1d[j](bps[j].imag) + Ux = Wx_r ** 2 + Wx_i ** 2 + y_j, _ = self.down[j](Ux) B, _, N = y_j.shape # reshape from (B, C*Q, N) to (B, C, Q, N) diff --git a/tests/test_graph.py b/tests/test_nn.py similarity index 91% rename from tests/test_graph.py rename to tests/test_nn.py index 6fee18e..9aacb2c 100644 --- a/tests/test_graph.py +++ b/tests/test_nn.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("Q", [3, 4]) @pytest.mark.parametrize("T", [8, 16]) @pytest.mark.parametrize("padding_mode", ["symmetric", "zeros"]) -def test_shape(J, Q, T, padding_mode): +def test_direct_shape(J, Q, T, padding_mode): B, C, N = 2, 3, 2**(J+4) x = torch.zeros(B, C, N) graph = murenn.MuReNNDirect( @@ -19,8 +19,9 @@ def test_shape(J, Q, T, padding_mode): ) y = graph(x) assert y.shape == (B, C, Q, J, 2**4) + -def test_diff(): +def test_direct_diff(): J, Q, T = 3, 4, 4 B, C, N = 2, 3, 2**(J+4) x = torch.zeros(B, C, N)