Change cached seq_len to int to enable compilation #38

Merged
merged 1 commit into from
Nov 27, 2024

Conversation

f0k (Contributor) commented Nov 27, 2024

First of all, thanks for the nicely reusable package!

With recent versions of PyTorch, torch.compile() fails with the error "RuntimeError: aten::copy() Expected a value of type 'Tensor' for argument 'src' but instead found type 'int'." on the line that does self.cached_freqs_seq_len.copy_(seq_len), where seq_len is an int.
It also warns about a "Graph break from Tensor.item()" on the line that checks (offset + seq_len) <= self.cached_freqs_seq_len.item().

This PR fixes both by changing cached_freqs_seq_len and cached_scales_seq_len from singleton int tensors to plain Python integers. Forgive me if I overlooked anything, but it seems to me that there is no benefit to keeping these values on the GPU?
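A minimal sketch of the caching pattern before and after the change (class and method names here are illustrative stand-ins, not the actual rotary-embedding-torch internals):

```python
import torch


class FreqsCache(torch.nn.Module):
    """Illustrative cache for precomputed frequencies, keyed by sequence length."""

    def __init__(self):
        super().__init__()
        # Before the PR, the cached length was a singleton tensor buffer, e.g.
        #   self.register_buffer('cached_freqs_seq_len', torch.tensor(0))
        # Under torch.compile, comparing via .item() causes a graph break, and
        # .copy_(seq_len) with a plain int raises the aten::copy_ RuntimeError.
        # After the PR, it is a plain Python int:
        self.cached_freqs_seq_len = 0
        self.register_buffer('cached_freqs', torch.empty(0))

    def get_freqs(self, seq_len: int) -> torch.Tensor:
        # Plain int comparison: no Tensor.item(), so no graph break.
        if seq_len <= self.cached_freqs_seq_len:
            return self.cached_freqs[:seq_len]
        freqs = torch.arange(seq_len, dtype=torch.float32)  # stand-in computation
        self.cached_freqs = freqs
        self.cached_freqs_seq_len = seq_len  # plain int assignment, no .copy_()
        return freqs
```

Since the cached length is only ever compared against Python ints, storing it as a tensor (on the GPU or otherwise) buys nothing and only trips up the compiler.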

lucidrains (Owner) commented:

@f0k thank you for the PR Jan! yes indeed, not sure why it was stored that way

@lucidrains lucidrains merged commit 8f2ccce into lucidrains:main Nov 27, 2024
@f0k f0k deleted the int-seq-len branch November 28, 2024 09:59