From fe0dda4aaa913665508b83c7a6020d87c4f64fe5 Mon Sep 17 00:00:00 2001 From: Xiran Date: Sat, 22 Jun 2024 12:28:20 +0200 Subject: [PATCH] Update test_dtcwt.py --- tests/test_dtcwt.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_dtcwt.py b/tests/test_dtcwt.py index 873b5d3..78f6d15 100644 --- a/tests/test_dtcwt.py +++ b/tests/test_dtcwt.py @@ -39,7 +39,7 @@ def test_fwd_same(J): @pytest.mark.parametrize("normalize", [True, False]) @pytest.mark.parametrize("J", list(range(1, 5))) @pytest.mark.parametrize("T", [44099, 44100]) -def test_inv(level1, qshift, J, T, alternate_gh, normalize): +def test_pr(level1, qshift, J, T, alternate_gh, normalize): Xt = torch.randn(2, 2, T) xfm_murenn = murenn.DTCWTDirect( J=J, @@ -75,3 +75,17 @@ def test_skip_hps(skip_hps, include_scale): inv = murenn.DTCWTInverse(J=J, skip_hps=skip_hps, include_scale=include_scale) X_rec = inv(lp, bp) assert X_rec.shape == Xt.shape + +def test_inv(): + T = 2**10 + Xt = torch.randn(2, 2, T) + dtcwt = murenn.DTCWTDirect() + idtcwt = murenn.DTCWTInverse() + lp, bp = dtcwt(Xt) + lp = lp.new_zeros(lp.shape) + X_rec = idtcwt(lp, bp) + bp_r = [(bp[j].real)*(1+0j) for j in range(8)] + bp_i = [(bp[j].imag)*(0+1j) for j in range(8)] + X_rec_r = idtcwt(lp, bp_r) + X_rec_i = idtcwt(lp, bp_i) + assert torch.allclose((X_rec_r+X_rec_i), X_rec, atol=1e-3) \ No newline at end of file