[Frontend][PyTorch] Add: Relay stft operator #11190
Conversation
@jsheng-jian Thanks, I'm very curious how you implemented stft (without fft in TVM)!
Please address the CI issue; there is a warning error from the doc build.
python/tvm/relay/op/transform.py (outdated):

    win_length : int
        The size of window frame and STFT filter
    window : relay.Expr
        A 1-D tensor window frame
In PyTorch, the window argument is optional. So shouldn't we support that too?
Added optional window argument.
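For context, a minimal sketch of one way an optional window can be defaulted in a converter; the relay.ones fallback and the helper name here are assumptions for illustration, not necessarily what this PR does:

    from tvm import relay

    def stft_with_optional_window(data, n_fft, hop_length, win_length,
                                  window=None, normalized=False, onesided=True):
        # Hypothetical fallback: torch.stft treats window=None as an all-ones
        # window of length win_length, so build one explicitly.
        if window is None:
            window = relay.ones((win_length,), dtype="float32")
        return relay.stft(data, n_fft, hop_length, win_length,
                          window, normalized, onesided)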
python/tvm/relay/op/transform.py (outdated):

    Returns
    -------
    output : relay.Expr
        Tensor containing the STFT result
Document the output shape. I had to read the type rel to see what the output shape looks like.
Updated.
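For reference, a sketch of the shape the docs now describe; the formula follows torch.stft with return_complex=False and center=False, and should be treated as an assumption to be checked against the type relation:

    def stft_output_shape(batch, signal_length, n_fft, hop_length, onesided):
        # rows of frequency bins, number of frames, and a trailing (real, imag) pair
        rows = n_fft // 2 + 1 if onesided else n_fft
        frames = (signal_length - n_fft) // hop_length + 1
        return (batch, rows, frames, 2)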
python/tvm/topi/cuda/stft.py (outdated):

    with ib.for_range(0, output_ptr.shape[0]) as batch:
        with ib.for_range(0, output_ptr.shape[1]) as row:
            with ib.for_range(0, output_ptr.shape[2]) as col:
This looks weird. You try to parallelize over the batch dim but nothing is parallelized.
Updated the CUDA implementation.
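Not necessarily what the updated code does, but a common ir_builder pattern for this is to flatten the output iteration space and bind it to CUDA threads; a sketch with assumed names (output_ptr, emit_body, a fixed max_threads):

    from tvm import te

    def emit_gpu_loops(ib, output_ptr, emit_body):
        # Flatten (batch, row, col) into one index space and give each output
        # element to one CUDA thread instead of nested serial loops.
        n_batch, n_row, n_col = output_ptr.shape[0], output_ptr.shape[1], output_ptr.shape[2]
        total = n_batch * n_row * n_col
        max_threads = 256  # assumption; usually read from the current target
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", max_threads)
        ib.scope_attr(bx, "thread_extent", (total + max_threads - 1) // max_threads)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < total):
            col = tid % n_col
            row = (tid // n_col) % n_row
            batch = tid // (n_col * n_row)
            emit_body(batch, row, col)  # caller emits the per-element computation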
python/tvm/topi/stft.py (outdated):

    with ib.for_range(0, output_ptr.shape[0]) as batch:
        # https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
        with ib.for_range(0, output_ptr.shape[1], kind="parallel") as row:
Fuse the outer loops to have one big parallel loop.
Done.
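A sketch of the fused loop, reusing the names from the snippet above:

    # Fuse batch and row into a single parallel loop; recover the indices inside.
    with ib.for_range(0, output_ptr.shape[0] * output_ptr.shape[1], kind="parallel") as fused:
        batch = fused // output_ptr.shape[1]
        row = fused % output_ptr.shape[1]
        with ib.for_range(0, output_ptr.shape[2]) as col:
            pass  # per-element STFT computation goes here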
python/tvm/topi/stft.py (outdated):

    output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf")
    loop_kind = "vectorize"
    if hasattr(output_shape[2], "name") and output_shape[2].name == "any_dim":
if isinstance(output_shape[2], tir.expr.Any)
The type is tir.expr.SizeVar; updated.
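A sketch of the resulting check, assuming the dynamic last dimension can show up as either tir.expr.Any or tir.expr.SizeVar:

    from tvm import tir

    loop_kind = "vectorize"
    # Fall back to a serial loop when the last output dim is not a static constant.
    if isinstance(output_shape[2], (tir.expr.Any, tir.expr.SizeVar)):
        loop_kind = "serial"

This keeps vectorization for static shapes while falling back cleanly when the column count is symbolic.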
    verify_trace_model(test_fn(3, 3, 3, False, "reflect", False, True), [input, window], targets)
    window = torch.tensor([1, 3], dtype=torch.int32)
    verify_trace_model(test_fn(2, 1, 2, False, "reflect", False, True), [input, window], targets)
Please add a test for the window=None case.
Added.
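A sketch of what such a test could look like, reusing the helpers already in this file (the argument order of test_fn is assumed to match the existing calls):

    # Omit the window tensor from the trace inputs so stft runs with window=None.
    input = torch.rand([1, 12]).float()
    verify_trace_model(test_fn(3, 3, 3, False, "reflect", False, True), [input], targets)

With no window in the trace inputs, the converter has to handle the window=None path explicitly.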
    tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
    relay.backend.te_compiler.get().clear()
Remove this diff
Done
python/tvm/relay/op/transform.py (outdated):

    def stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
        """
        The STFT computes the Fourier transform of short overlapping windows of the input.
        This giving frequency components of the signal as they change over time.
giving -> gives
Fix the same typo in other files too.
Fixed.
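For context, a minimal usage sketch of the new operator (assuming it is re-exported as relay.stft like the other transform ops; the window shape and dtype here are illustrative):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 16), dtype="float32")
    window = relay.var("window", shape=(4,), dtype="float32")
    out = relay.stft(data, 4, 1, 4, window, False, True)
    mod = tvm.IRModule.from_expr(relay.Function([data, window], out))
    # InferType makes the documented output shape visible in the printed module.
    print(relay.transform.InferType()(mod))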
tests/python/relay/test_op_level3.py (outdated):

    )


    def verify_func2(target, dev, func, data, ref_res, rtol=1e-5, atol=1e-7, kinds=["vm"]):
Why do you need this? It looks identical to verify_func.
I want to expose rtol, atol, and kinds. Should I update the original function instead of creating a new one?
Yeah, that would be better.
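A sketch of extending the existing helper instead (signature based on verify_func2 above, with the mutable default for kinds avoided; the body is simplified and omits the ADT handling of the real helper):

    import tvm
    import tvm.testing
    from tvm import relay

    def verify_func(target, dev, func, data, ref_res, rtol=1e-5, atol=1e-7, kinds=None):
        kinds = ["vm"] if kinds is None else kinds  # avoid a mutable default argument
        assert isinstance(data, list)
        for kind in kinds:
            mod = tvm.ir.IRModule.from_expr(func)
            op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(*data)
            tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol)
        relay.backend.te_compiler.get().clear()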
python/tvm/topi/cuda/stft.py (outdated):

    win_length,
    window,
    normalized,
    onesided,  # pylint: disable=unused-argument
You can remove # pylint: disable=unused-argument and add unused-argument to L17.
Done.
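That is, something like extending the module-level directive at the top of the file (the checks other than unused-argument listed here are assumptions about what is already disabled there):

    # pylint: disable=invalid-name, too-many-arguments, too-many-locals, unused-argument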
python/tvm/topi/cuda/stft.py (outdated):

    win_length,
    window_ptr,
    normalized,
    onesided,  # pylint: disable=unused-argument
same
Done.
python/tvm/topi/stft.py (outdated):

    win_length,
    window,
    normalized,
    onesided,  # pylint: disable=unused-argument
same as the comment in topi/cuda/stft.py
Done.
python/tvm/topi/stft.py (outdated):

    win_length,
    window_ptr,
    normalized,
    onesided,  # pylint: disable=unused-argument
same
Done.
The CI failure seems to be unrelated.
It looks like the CI failure has been resolved in the main branch; do I need to rebase my changes?
Yes, please send another job.
* Add: Relay stft operator
* fix doc
* address PR comments
* address additional comments
This PR adds the stft, amax, amin
torch.stft