Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH+FIX to_conv1d() and energy bugfix in MuReNNDirect #53

Closed
wants to merge 14 commits into from
Closed

Conversation

xir4n
Copy link
Collaborator

@xir4n xir4n commented Jun 10, 2024

closes #45, resolves #47 , resolves #52

Update:

  • fix bug: normalization behavior inside dtcwt.transform1d
  • new feature: MuReNNDirect.to_conv1d
  • fix bug: normalization setting in MuReNNDirect

@xir4n xir4n requested a review from lostanlen June 10, 2024 15:00
@xir4n xir4n removed the request for review from lostanlen June 11, 2024 18:23
@lostanlen lostanlen changed the title ENH MuReNNDirect.to_con1d() ENH MuReNNDirect.to_conv1d() Jun 12, 2024
Copy link
Contributor

@lostanlen lostanlen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you but this implementation returns an array. I would prefer if to_conv1d() could return a fully functional Conv1D module whose weight.data contains the array you have. The padding should be the same.

Then in tests/ you can check that conv1d(x) is almost equal to murenn(x)

OK?

@xir4n
Copy link
Collaborator Author

xir4n commented Jun 12, 2024

ok, I think I've got what you mean

@lostanlen lostanlen changed the title ENH MuReNNDirect.to_conv1d() ENH+FIX to_conv1d() and energy bugfix in MuReNNDirect Jun 17, 2024
@lostanlen
Copy link
Contributor

Following our conversation today, this is both ENH and FIX now. Let me know when it's ready so I can review.

@xir4n xir4n requested a review from lostanlen June 17, 2024 17:17
@xir4n
Copy link
Collaborator Author

xir4n commented Jun 18, 2024

@lostanlen What do you think I should do? Can I specify the numpy version in the continuous integration script to make this PR pass the checks?

@lostanlen
Copy link
Contributor

tests are passing on this environment

(murenn) ➜  murenn git:(fafe691) conda list
# packages in environment at /Users/user/miniconda3/envs/murenn:
#
# Name                    Version                   Build  Channel
bzip2                     1.0.8                h6c40b1e_6
ca-certificates           2024.3.11            hecd8cb5_0
dtcwt                     0.14.0                   pypi_0    pypi
exceptiongroup            1.2.1                    pypi_0    pypi
filelock                  3.15.3                   pypi_0    pypi
fsspec                    2024.6.0                 pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
libffi                    3.4.4                hecd8cb5_1
markupsafe                2.1.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
murenn                    0.3.dev0                  dev_0    <develop>
ncurses                   6.4                  hcec6c5f_0
networkx                  3.3                      pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
openssl                   3.0.14               h46256e1_0
packaging                 24.1                     pypi_0    pypi
pip                       24.0                     pypi_0    pypi
pluggy                    1.5.0                    pypi_0    pypi
pytest                    8.2.2                    pypi_0    pypi
python                    3.10.14              h5ee71fb_1
readline                  8.2                  hca72f7f_0
scipy                     1.13.1                   pypi_0    pypi
setuptools                69.5.1                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sqlite                    3.45.3               h6c40b1e_0
sympy                     1.12.1                   pypi_0    pypi
tk                        8.6.14               h4d00af3_0
tomli                     2.0.1                    pypi_0    pypi
torch                     2.2.2                    pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2024a                h04d1e81_0
wheel                     0.43.0                   pypi_0    pypi
xz                        5.4.6                h6c40b1e_1
zlib                      1.2.13               h4b97444_1

@lostanlen
Copy link
Contributor

Collecting numpy<2.0.0 (from dtcwt>=0.13.0->murenn==0.3.dev0)
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)

😻

Copy link
Contributor

@lostanlen lostanlen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks. i just have a question on the zero-ing of complex-valued coefficients

Wpsi_ji = self.conv1d[j](psis[j].imag)
# Set the coefficients besides this scale to zero
Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)]
Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why convert to complex? and why multiply by 1j?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I converted them to complex because you wanted the to_conv1d method to return a complex-valued PyTorch tensor in #45.

Besides, in dtcwt, we have: (I've uploaded a test to show this. )

$$ \begin{align} idtcwt(x) &= idtcwt.tree_a(x) + idtcwt.tree_b(x)\\ &= idtcwt(x.real) + idtcwt(x.imag) \end{align} $$

I think in the returned Conv1d instance, conv1d.weight.real+conv1d.weight.imag is equivalent to the case if we don't convert Wpsi_j to complex

there is a bug: maybe i should use $x=(1-1j)*\delta$ instead of $x=(1+0j) * \delta$ to convolve with conv1d to visualize the filter

I can change this into a real-valued version if this is too complicated

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, i don't think $x = (1 - 1j) \delta$ is necessary. $x = \delta$ should suffice.

My question was not a mathematical question but more like a CS question. Usually when you do
z = torch.complex(re, im)
re and im are both real-valued. Here this is not the case, since you have multiplied im by 1j.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your advice, I've made the change.
I was not clear in my previous answer. In the algorithm, only the real part of a complex tensor would be passed through Tree A, and only the imaginary part would be passed through Tree B. That's why I was converging z into a complex-valued tensor.

@xir4n xir4n closed this Jul 23, 2024
@lostanlen
Copy link
Contributor

closed because of too high complexity
converted into new PR's: --> #55 and #56

@xir4n xir4n deleted the to_con1d branch August 18, 2024 20:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Energy issue in MuReNNDirect normalization issue for DTCWTDirect MuReNNDirect.to_conv1d()
2 participants