
[Frontend][PyTorch] Add: Relay stft operator #11190

Merged: 4 commits merged into apache:main on May 7, 2022

Conversation

@ghost commented Apr 29, 2022

This PR adds the stft, amax, and amin operators.

torch.stft benchmark:

# Average inference time per frame over 100 iterations, run with 4 CPU cores
pt Elapsed time: 0.1826983118057251
graph_runtime Elapsed time: 0.11588997602462768
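
For context, a minimal sketch of how a torch.stft module could be traced and compiled through the TVM PyTorch frontend to reproduce a comparison like the one above; the model, shapes, and STFT parameters below are illustrative assumptions, not taken from the PR:

import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

class StftModel(torch.nn.Module):
    def forward(self, x, window):
        # center=False keeps the frame count simple; return_complex=False yields a (..., 2) real/imag tensor
        return torch.stft(x, n_fft=512, hop_length=128, win_length=512,
                          window=window, center=False, normalized=False,
                          onesided=True, return_complex=False)

x = torch.randn(1, 16000)
window = torch.hann_window(512)
scripted = torch.jit.trace(StftModel(), (x, window))

# Convert the traced module to Relay and build for CPU
mod, params = relay.frontend.from_pytorch(scripted, [("x", x.shape), ("window", window.shape)])
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

rt = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
rt.set_input("x", x.numpy())
rt.set_input("window", window.numpy())
rt.run()
tvm_out = rt.get_output(0).numpy()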

ghost changed the title from "Add: Relay stft operator" to "[Frontend][PyTorch] Add: Relay stft operator" on Apr 29, 2022
masahi (Member) commented Apr 29, 2022

@jsheng-jian Thanks, looks like topi/stft.py and topi/cuda/stft.py are missing?

I'm very curious how you implemented stft (without fft in TVM)!
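
(The TOPI code quoted further down answers this: it follows the librosa-style direct formula, computing a per-frame DFT rather than an FFT. A rough NumPy sketch of that approach, with illustrative names:)

import numpy as np

def stft_direct(x, n_fft, hop_length, window, onesided=True):
    # Direct DFT of each windowed frame (O(n_fft^2) per frame) -- no FFT needed
    n_freqs = n_fft // 2 + 1 if onesided else n_fft
    n_frames = (len(x) - n_fft) // hop_length + 1
    out = np.zeros((n_freqs, n_frames, 2))
    for f in range(n_freqs):
        for m in range(n_frames):
            frame = window * x[m * hop_length : m * hop_length + n_fft]
            angle = -2.0 * np.pi * f * np.arange(n_fft) / n_fft
            out[f, m, 0] = np.sum(frame * np.cos(angle))  # real part
            out[f, m, 1] = np.sum(frame * np.sin(angle))  # imaginary part
    return out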

(A review comment on python/setup.py was marked outdated and resolved.)
ghost (Author) commented May 1, 2022

topi/stft.py and topi/cuda/stft.py are added; could you take another look? @masahi

masahi (Member) commented May 2, 2022

Please address the CI issue; there is a doc warning that is treated as an error.

win_length : int
    The size of window frame and STFT filter
window : relay.Expr
    A 1-D tensor window frame
masahi (Member): In PyTorch, the window argument is optional, so shouldn't we support that too?

ghost (Author): Added an optional window argument.

Returns
-------
output : relay.Expr
    Tensor containing the STFT result
masahi (Member): Document the output shape. I had to read the type relation to see what the output shape looks like.

ghost (Author): Updated.
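
For reference, a small sketch of the shape being documented, matching torch.stft with center=False and return_complex=False; the helper name is mine, not the PR's:

def stft_output_shape(batch, signal_len, n_fft, hop_length, onesided):
    # (batch, n_freqs, n_frames, 2) -- the trailing 2 holds the real/imaginary parts
    n_freqs = n_fft // 2 + 1 if onesided else n_fft
    n_frames = (signal_len - n_fft) // hop_length + 1
    return (batch, n_freqs, n_frames, 2)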


with ib.for_range(0, output_ptr.shape[0]) as batch:
with ib.for_range(0, output_ptr.shape[1]) as row:
with ib.for_range(0, output_ptr.shape[2]) as col:
masahi (Member): This looks weird. You try to parallelize over the batch dim, but nothing is parallelized.

ghost (Author): Updated the CUDA implementation.
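
For readers following along, a hedged sketch of the usual ir_builder pattern for a CUDA kernel: flatten the parallel dimensions into one index and bind it to blocks and threads. The buffer layout and the loop body are placeholders, not the PR's code:

from tvm import te, tir

def gen_gpu_ir_sketch(output_ptr):
    ib = tir.ir_builder.create()
    output = ib.buffer_ptr(output_ptr)
    batch, rows, cols = output_ptr.shape[0], output_ptr.shape[1], output_ptr.shape[2]
    total = batch * rows

    max_threads = 256
    bx = te.thread_axis("blockIdx.x")
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(bx, "thread_extent", tir.indexdiv(total + max_threads - 1, max_threads))
    ib.scope_attr(tx, "thread_extent", max_threads)

    tid = bx * max_threads + tx
    with ib.if_scope(tid < total):
        b = tir.indexdiv(tid, rows)
        r = tir.indexmod(tid, rows)
        with ib.for_range(0, cols) as c:
            output[b, r, c] = tir.const(0.0, "float32")  # placeholder; the real kernel accumulates the DFT here
    return ib.get()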


with ib.for_range(0, output_ptr.shape[0]) as batch:
# https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
with ib.for_range(0, output_ptr.shape[1], kind="parallel") as row:
masahi (Member): Fuse the outer loop to have one big parallel loop.

ghost (Author): Done.
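
A hedged sketch of the shape this suggestion points at: iterate once over batch * rows with kind="parallel" and recover the two indices with floordiv/floormod. The buffer layout and loop body are illustrative only:

from tvm import tir

def gen_cpu_ir_sketch(output_ptr):
    ib = tir.ir_builder.create()
    output = ib.buffer_ptr(output_ptr)
    batch_dim, row_dim, col_dim = output_ptr.shape[0], output_ptr.shape[1], output_ptr.shape[2]
    # One fused parallel loop instead of a serial batch loop around a parallel row loop
    with ib.for_range(0, batch_dim * row_dim, kind="parallel") as fused:
        batch = tir.floordiv(fused, row_dim)
        row = tir.floormod(fused, row_dim)
        with ib.for_range(0, col_dim) as col:
            output[batch, row, col] = tir.const(0.0, "float32")  # placeholder for the per-frame DFT
    return ib.get()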


output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf")
loop_kind = "vectorize"
if hasattr(output_shape[2], "name") and output_shape[2].name == "any_dim":
masahi (Member): if isinstance(output_shape[2], tir.expr.Any)

ghost (Author): The type is tir.expr.SizeVar; updated.
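
For readers unfamiliar with the distinction: dynamic dimensions can show up as different tir node types (tir.Any, tir.SizeVar), so a type check is more robust than matching the "any_dim" name string. A small hedged sketch:

from tvm import tir

def pick_loop_kind(dim):
    # Vectorize only when the extent is a compile-time constant; dynamic extents
    # may appear as tir.Any or tir.SizeVar depending on how the shape was built.
    if isinstance(dim, (int, tir.IntImm)):
        return "vectorize"
    return "serial"

loop_kind = pick_loop_kind(output_shape[2])  # output_shape as in the snippet above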

verify_trace_model(test_fn(3, 3, 3, False, "reflect", False, True), [input, window], targets)
window = torch.tensor([1, 3], dtype=torch.int32)
verify_trace_model(test_fn(2, 1, 2, False, "reflect", False, True), [input, window], targets)

masahi (Member): Please add a test for the window=None case.

ghost (Author): Added.
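
A hypothetical sketch of such a test, reusing the verify_trace_model helper from the snippet above; the wrapper below and its parameters are illustrative, not the PR's actual test code:

import torch

def stft_no_window(x):
    # window=None exercises the default rectangular-window path
    return torch.stft(x, n_fft=2, hop_length=1, win_length=2, window=None,
                      center=False, normalized=False, onesided=True,
                      return_complex=False)

verify_trace_model(stft_no_window, [input], targets)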

tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
relay.backend.te_compiler.get().clear()


masahi (Member): Remove this diff.

ghost (Author): Done.

def stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
"""
The STFT computes the Fourier transform of short overlapping windows of the input.
This giving frequency components of the signal as they change over time.
masahi (Member): giving -> gives. Fix the same typo in other files too.

ghost (Author): Fixed.
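
For completeness, the quantity the docstring describes, in the usual notation (frame index m, frequency bin k, hop length h, window w of length N = n_fft):

X[m, k] = \sum_{n=0}^{N-1} w[n] \, x[m h + n] \, e^{-2 \pi i k n / N}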

)


def verify_func2(target, dev, func, data, ref_res, rtol=1e-5, atol=1e-7, kinds=["vm"]):
masahi (Member): Why do you need this? It looks identical to verify_func.

ghost (Author): I want to expose rtol, atol, and kinds. Should I update the original function instead of creating a new one?

masahi (Member): Yeah, that would be better.
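
A hedged sketch of that change, extending the existing helper with the extra keyword arguments instead of duplicating it; the body is a plausible reconstruction of verify_func, not a verbatim copy:

import tvm
import tvm.testing
from tvm import relay

def verify_func(target, dev, func, data, ref_res, rtol=1e-5, atol=1e-7, kinds=None):
    # Run the Relay function with each executor kind and compare against the reference
    if kinds is None:
        kinds = ["vm"]
    if not isinstance(data, list):
        data = [data]
    for kind in kinds:
        result = relay.create_executor(kind, device=dev, target=target).evaluate(func)(*data)
        tvm.testing.assert_allclose(result.numpy(), ref_res, rtol=rtol, atol=atol)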

win_length,
window,
normalized,
onesided, # pylint: disable=unused-argument
masahi (Member): You can remove # pylint: disable=unused-argument and add unused-argument to L17.

ghost (Author): Done.
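
I.e. one module-level directive at the top of the TOPI file covers every function below it (hedged sketch; the exact set of checks disabled on that line in the PR may differ):

# pylint: disable=invalid-name, unused-argument
"""STFT operator"""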

win_length,
window_ptr,
normalized,
onesided, # pylint: disable=unused-argument
masahi (Member): Same.

ghost (Author): Done.

win_length,
window,
normalized,
onesided, # pylint: disable=unused-argument
masahi (Member): Same as the comment in topi/cuda/stft.py.

ghost (Author): Done.

win_length,
window_ptr,
normalized,
onesided, # pylint: disable=unused-argument
masahi (Member): Same.

ghost (Author): Done.

ghost (Author) commented May 5, 2022

The CI failure seems to be unrelated.

ghost (Author) commented May 6, 2022

It looks like the CI failure is resolved in the main branch; do I need to rebase my changes?

masahi (Member) commented May 6, 2022

Yes, please trigger another CI job.

masahi merged commit a3d75ae into apache:main on May 7, 2022
shtinsa pushed a commit to Deelvin/tvm that referenced this pull request May 17, 2022
* Add: Relay stft operator

* fix doc

* address PR comments

* address additional comments
SebastianBoblest pushed a commit to SebastianBoblest/tvm that referenced this pull request May 27, 2022
* Add: Relay stft operator

* fix doc

* address PR comments

* address additional comments