-
Notifications
You must be signed in to change notification settings - Fork 666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for torch STFT & spectrogram #1824
Conversation
nikalra
commented
Apr 10, 2023
- Adds support for torch.stft and its various options for both complex and real inputs
- Adds support for torchaudio.functional.Spectrogram (and MelSpectrogram) via STFT support and complex support for pad/reshape/abs.
view = mb.reshape(x=x, shape=shape, name=node.name) | ||
|
||
if types.is_complex(x.dtype): | ||
real, imag = (mb.reshape(x=x, shape=shape, name=node.name) for x in (mb.complex_real(data=x), mb.complex_imag(data=x))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I opted not to create complex dialect ops for reshape and pad (below) because their behavior doesn't change as a result of the inputs being complex.
I'm more than happy to create a complex dialect op for these if that's the preferred approach, but figured that this might be a better route to avoid duplicating each built-in op as a complex dialect op.
If in the future, there's support for something like a lowering pass where all non-complex dialect ops with complex support in their type domain can be duplicated across the real and imaginary components of the input, this would probably be easier to get rid of and restore to just the code in the else block.
@@ -285,7 +285,7 @@ def type_domain(self): | |||
|
|||
@type_domain.setter | |||
def type_domain(self, val): | |||
msg = "type_domain must be a tuple of builtin types" | |||
msg = f"type_domain {val} must be a tuple of builtin types" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an unrelated change but made errors during the complex lowering passes a little easier to debug.
@@ -732,13 +732,128 @@ class complex_shape(Operation): | |||
def type_inference(self): | |||
if not isinstance(self.x, ComplexVar): | |||
raise ValueError("x must be a ComplexVar.") | |||
input_rank = self.x.real.rank | |||
input_rank = self.x.rank |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If type_inference or value_inference is invoked when the graph is being constructed, x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! Could you add this as a comment here?
:O It's happening!!! :D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really nice PR! Added several comments. After those comments are addressed, I will kick off a CI run and merge it. Thanks a lot!
@@ -8099,6 +8100,113 @@ def forward(self, x): | |||
(2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit | |||
) | |||
|
|||
class TestSTFT(TorchBaseTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Very nice and comprehensive tests!
atol=1e-4, | ||
) | ||
|
||
class TestComplex(TorchBaseTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have a TestComplex
test class in this file. Could you merge this test_abs
into that test class?
@@ -732,13 +732,128 @@ class complex_shape(Operation): | |||
def type_inference(self): | |||
if not isinstance(self.x, ComplexVar): | |||
raise ValueError("x must be a ComplexVar.") | |||
input_rank = self.x.real.rank | |||
input_rank = self.x.rank |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! Could you add this as a comment here?
Done! Let me know if I missed anything! |
CI Run: https://gitlab.com/coremltools1/coremltools/-/pipelines/835422539 The test for |
As this test behavior is machine-dependent and failed on Intel Mac, we xfail it for now and will debug later. Already filed an internal tracking radar for it.
I also tested it on an Intel Mac, and it also failed (but not the same error message). So the conclusion is that the Starting a new CI run after marking xfail: https://gitlab.com/coremltools1/coremltools/-/pipelines/836614946 |
Very excited to see |