-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH+FIX to_conv1d()
and energy bugfix in MuReNNDirect
#53
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you but this implementation returns an array. I would prefer if to_conv1d()
could return a fully functional Conv1D
module whose weight.data
contains the array you have. The padding should be the same.
Then in tests/
you can check that conv1d(x)
is almost equal to murenn(x)
OK?
ok, I think I've got what you mean |
MuReNNDirect.to_conv1d()
to_conv1d()
and energy bugfix in MuReNNDirect
Following our conversation today, this is both ENH and FIX now. Let me know when it's ready so I can review. |
@lostanlen What do you think I should do? Can I specify the |
tests are passing on this environment
|
😻 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks. i just have a question on the zero-ing of complex-valued coefficients
murenn/dtcwt/nn.py
Outdated
Wpsi_ji = self.conv1d[j](psis[j].imag) | ||
# Set the coefficients besides this scale to zero | ||
Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] | ||
Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why convert to complex? and why multiply by 1j
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I converted them to complex because you wanted the to_conv1d method to return a complex-valued PyTorch tensor in #45.
Besides, in dtcwt, we have: (I've uploaded a test to show this. )
I think in the returned Conv1d
instance, conv1d.weight.real+conv1d.weight.imag
is equivalent to the case if we don't convert Wpsi_j
to complex
there is a bug: maybe i should use conv1d
to visualize the filter
I can change this into a real-valued version if this is too complicated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello, i don't think
My question was not a mathematical question but more like a CS question. Usually when you do
z = torch.complex(re, im)
re
and im
are both real-valued. Here this is not the case, since you have multiplied im
by 1j
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your advice, I've made the change.
I was not clear in my previous answer. In the algorithm, only the real part of a complex tensor would be passed through Tree A, and only the imaginary part would be passed through Tree B. That's why I was converging z
into a complex-valued tensor.
closes #45, resolves #47 , resolves #52
Update:
dtcwt.transform1d
MuReNNDirect