Skip to content

Commit

Permalink
FIX order of operators in MuReNNDirect (#38)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Lostanlen <[email protected]>
  • Loading branch information
xir4n and lostanlen authored Apr 4, 2024
1 parent 9172bac commit e03dedb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def forward(self, x):
ys = []

for j in range(self.dtcwt.J):
x_j = torch.abs(bps[j])
x_j = self.conv1d[j](x_j)
y_j, _ = self.down[j](x_j)
Wx_r = self.conv1d[j](bps[j].real)
Wx_i = self.conv1d[j](bps[j].imag)
Ux = Wx_r ** 2 + Wx_i ** 2
y_j, _ = self.down[j](Ux)

B, _, N = y_j.shape
# reshape from (B, C*Q, N) to (B, C, Q, N)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_graph.py → tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@pytest.mark.parametrize("Q", [3, 4])
@pytest.mark.parametrize("T", [8, 16])
@pytest.mark.parametrize("padding_mode", ["symmetric", "zeros"])
def test_shape(J, Q, T, padding_mode):
def test_direct_shape(J, Q, T, padding_mode):
B, C, N = 2, 3, 2**(J+4)
x = torch.zeros(B, C, N)
graph = murenn.MuReNNDirect(
Expand All @@ -19,8 +19,9 @@ def test_shape(J, Q, T, padding_mode):
)
y = graph(x)
assert y.shape == (B, C, Q, J, 2**4)


def test_diff():
def test_direct_diff():
J, Q, T = 3, 4, 4
B, C, N = 2, 3, 2**(J+4)
x = torch.zeros(B, C, N)
Expand Down

0 comments on commit e03dedb

Please sign in to comment.