Skip to content

Commit

Permalink
Update nn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n committed Jan 14, 2025
1 parent 8c9c418 commit 5091e8f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def forward(self, x):
lp, bps = self.dtcwt(x)
UWx = []
for j in range(self.dtcwt.J):
Wx_j_r = self.conv1d[j](bps[j].real)
Wx_j_i = self.conv1d[j](bps[j].imag)
UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i) + ModulusStable.apply(bps[j].real, bps[j].imag).repeat(1, self.Q[j], 1)
Wx_j_r = self.conv1d[j](bps[j].real) + bps[j].real.repeat(1, self.Q[j], 1)
Wx_j_i = self.conv1d[j](bps[j].imag) + bps[j].imag.repeat(1, self.Q[j], 1)
UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i)
UWx_j = self.down[j](UWx_j)
B, _, N = UWx_j.shape
UWx_j = UWx_j.view(B, self.in_channels, self.Q[j], N)
Expand Down

0 comments on commit 5091e8f

Please sign in to comment.