Change cached seq_len to int to enable compilation #38
First of all, thanks for the nicely reusable package!

With recent versions of PyTorch, `torch.compile()` fails with the error "RuntimeError: aten::copy() Expected a value of type 'Tensor' for argument 'src' but instead found type 'int'." on the line that does `self.cached_freqs_seq_len.copy_(seq_len)`, where `seq_len` is an int. It also warns about a "Graph break from Tensor.item()" on the line that has `(offset + seq_len) <= self.cached_freqs_seq_len.item()`.

This PR fixes both by changing `cached_freqs_seq_len` and `cached_scales_seq_len` from singleton int tensors to plain Python integers. Forgive me if I overlooked anything, but it seems to me that there is no benefit to keeping these values on the GPU.
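For illustration, here is a minimal sketch of the caching pattern after the change (class and method names are hypothetical, and a plain list stands in for the actual frequency tensor). The cached length becomes a plain `int`, so the comparison needs no `.item()` (no graph break) and the update is a plain assignment rather than `Tensor.copy_()`:

```python
class FreqCache:
    """Hypothetical sketch: cache frequencies keyed by a plain-int length."""

    def __init__(self):
        # Plain Python int, not a singleton tensor: comparing and
        # reassigning it never touches the GPU or breaks the graph.
        self.cached_freqs_seq_len = 0
        self.cached_freqs = None

    def get_freqs(self, seq_len, offset=0):
        # Recompute only when the request exceeds the cached length.
        if self.cached_freqs is None or (offset + seq_len) > self.cached_freqs_seq_len:
            # Stand-in for the real frequency computation.
            self.cached_freqs = list(range(offset + seq_len))
            # Plain int assignment replaces Tensor.copy_(seq_len).
            self.cached_freqs_seq_len = offset + seq_len
        return self.cached_freqs[offset:offset + seq_len]
```

A second call with a shorter sequence reuses the cache without recomputing, since the int comparison alone decides the cache hit.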