Skip to content

Commit

Permalink
Update test_dtcwt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n committed Jun 22, 2024
1 parent b5e49e6 commit fe0dda4
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_fwd_same(J):
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("J", list(range(1, 5)))
@pytest.mark.parametrize("T", [44099, 44100])
def test_inv(level1, qshift, J, T, alternate_gh, normalize):
def test_pr(level1, qshift, J, T, alternate_gh, normalize):
Xt = torch.randn(2, 2, T)
xfm_murenn = murenn.DTCWTDirect(
J=J,
Expand Down Expand Up @@ -75,3 +75,17 @@ def test_skip_hps(skip_hps, include_scale):
inv = murenn.DTCWTInverse(J=J, skip_hps=skip_hps, include_scale=include_scale)
X_rec = inv(lp, bp)
assert X_rec.shape == Xt.shape

def test_inv():
T = 2**10
Xt = torch.randn(2, 2, T)
dtcwt = murenn.DTCWTDirect()
idtcwt = murenn.DTCWTInverse()
lp, bp = dtcwt(Xt)
lp = lp.new_zeros(lp.shape)
X_rec = idtcwt(lp, bp)
bp_r = [(bp[j].real)*(1+0j) for j in range(8)]
bp_i = [(bp[j].imag)*(0+1j) for j in range(8)]
X_rec_r = idtcwt(lp, bp_r)
X_rec_i = idtcwt(lp, bp_i)
assert torch.allclose((X_rec_r+X_rec_i), X_rec, atol=1e-3)

0 comments on commit fe0dda4

Please sign in to comment.