
[export] Capture tensor.to() under export. #123732

Closed
wants to merge 1 commit

Conversation

zhxchen17
Contributor

Summary: We used to skip tensor.to() during tracing when the device is the same. This brings a small performance improvement in eager mode, but it makes graph capture lose the semantics of the original model. In this diff, we add an additional condition that skips the fast path when a tensor has no actual data, which is the case when we're tracing the model with FakeTensor / FunctionalTensor. This has no perf impact on existing eager models while ensuring we can capture the _to_copy() node in the graph.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r device_to

Differential Revision: D55969674
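The gist of the change can be sketched in pure Python. This is a toy model, not the actual ATen code; names like `has_storage` are illustrative, standing in for the `self.const_data_ptr()` check, and the `captured` list stands in for nodes recorded by the tracer.

```python
from dataclasses import dataclass

@dataclass
class Tensor:
    """Toy stand-in for a tensor: just a device tag and a storage flag."""
    device: str
    has_storage: bool = True  # False models FakeTensor / FunctionalTensor

captured = []  # ops that would be recorded in the traced graph

def to_copy(t, device):
    captured.append("_to_copy")
    return Tensor(device, t.has_storage)

def to_impl(t, device):
    # Old fast path: `if device == t.device: return t`
    # This diff additionally requires real storage (`self.const_data_ptr()`).
    if device == t.device and t.has_storage:
        return t               # eager: same-device .to() stays a no-op alias
    return to_copy(t, device)  # tracing: the _to_copy node is captured

eager = Tensor("cpu")
assert to_impl(eager, "cpu") is eager       # eager perf path unchanged

fake = Tensor("cpu", has_storage=False)
assert to_impl(fake, "cpu") is not fake     # under fake tracing the op is kept
assert captured == ["_to_copy"]
```

Because real eager tensors always have storage, the extra condition only changes behavior under fake/functional tracing, which is why the eager fast path keeps its perf.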


pytorch-bot bot commented Apr 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123732

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 24014a8 with merge base 674e15a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D55969674

zhxchen17 added a commit to zhxchen17/pytorch that referenced this pull request Apr 10, 2024

Contributor

@angelayi angelayi left a comment


ooh very cool

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Apr 10, 2024
@zhxchen17
Contributor Author

Test failures look fine to fix; will update this later.

@SherlockNoMad SherlockNoMad requested a review from zou3519 April 10, 2024 16:37
@@ -388,7 +388,7 @@ static inline Tensor to_impl(
     c10::optional<c10::MemoryFormat> optional_memory_format) {

   // fast path
-  if (to_will_alias(self, dtype, layout, device, copy, optional_memory_format)) {
+  if (to_will_alias(self, dtype, layout, device, copy, optional_memory_format) && self.const_data_ptr()) {
Contributor


Add a comment for const_data_ptr(); it's not obvious that this check is for FakeTensor/FunctionalTensor.

Contributor


Also, isn't this a proxy check for being under FakeMode?
Maybe we can turn it into an explicit check for FakeMode?

Contributor

@zou3519 zou3519 left a comment


This changes the semantics of user code to make Tensor.to always copy under torch.export (and also torch.compile, it looks like?). I don't think that's what we want. A test case is the following (replace torch.compile with torch.export too)

@torch.compile
def f(x):
    y = x.clone()
    z = y.to("cpu")
    y.sin_()
    return z

x = torch.randn(3)
out = f(x)
assert torch.allclose(out, x.sin())

@zhxchen17
Contributor Author

zhxchen17 commented Apr 10, 2024

This changes the semantics of user code to make Tensor.to always copy under torch.export (and also torch.compile, it looks like?). I don't think that's what we want. A test case is the following (replace torch.compile with torch.export too)

@torch.compile
def f(x):
    y = x.clone()
    z = y.to("cpu")
    y.sin_()
    return z

x = torch.randn(3)
out = f(x)
assert torch.allclose(out, x.sin())

@zou3519 I understand that in this particular case we want to preserve the old behavior, so I don't think we necessarily want to always dispatch to _to_copy. Instead, we just want to preserve the device conversion semantics. I saw there's an aten::to overload which does this: https://fburl.com/code/8ax41xzx. Could we dispatch to these ops instead?

@zou3519
Contributor

zou3519 commented Apr 11, 2024

Is your eventual goal that you'd like to export a graph that is device agnostic, and Tensor.to is one of the things preventing this from working?

@zhxchen17
Contributor Author

zhxchen17 commented Apr 11, 2024

Is your eventual goal that you'd like to export a graph that is device agnostic, and Tensor.to is one of the things preventing this from working?

Yes.

Also, I want to provide a bit more context on the device issue: the export team concluded in the design meeting that we will bake in devices from time to time due to how pt2 works, BUT for a case like tensor.to(), which is authored by the user, we want to preserve it in the captured graph, which means at least 1 op should be captured by torch.export() in this case; otherwise it'd be a big surprise to people who rely on the graph. For device bake-ins caused by other reasons, e.g. op decomposition and so on, we don't care that much right now; the biggest thing we want to fix is tensor.to(device) from user code.

@zou3519
Contributor

zou3519 commented Apr 11, 2024

Our constraints are that we need to preserve the semantics of the user program. This is made complicated by how PT2 IR is functional -- a call to something like aten::to isn't functional and therefore it needs to be decomposed.

One middle ground is, when out = tensor.to(...) is called and is a no-op:

  • IF we detect that it is safe to turn it into a copy, then we can emit a call to aten::_to_copy. It is safe to turn into a copy if out isn't being mutated or returned from the program and tensor isn't mutated. I'm not sure if this is practical enough for your use cases.

@zou3519 zou3519 requested a review from bdhirsh April 11, 2024 19:20
@zhxchen17
Contributor Author

zhxchen17 commented Apr 11, 2024

Our constraints are that we need to preserve the semantics of the user program. This is made complicated by how PT2 IR is functional -- a call to something like aten::to isn't functional and therefore it needs to be decomposed.

One middle ground is, when out = tensor.to(...) is called and is a no-op:

  • IF we detect that it is safe to turn it into a copy, then we can emit a call to aten::_to_copy. It is safe to turn into a copy if out isn't being mutated or returned from the program and tensor isn't mutated. I'm not sure if this is practical enough for your use cases.

@zou3519 If we could do the middle ground approach you proposed, it would be helpful as well.

I don't understand why capturing the graph as to_dtype_layout would break the constraint of preserving user semantics in the presence of functionalization, though, especially in pre-dispatch mode. It seems we can still preserve the aliasing relationship, unlike _to_copy, if we only reason about the pre-condition and post-condition of this call.

nvm I think I understand it now, sorry.

@tugsbayasgalan
Contributor

@bdhirsh, @zhxchen17, and I talked offline. In general, it seems like wrong behaviour to support mutation on the result of tensor.to, because depending on whether you used a CUDA tensor or a CPU tensor, the behaviour of the exported program will differ. Consider the following example:

y = x.to("cpu")
y.add_(5)
return x.sin()

In the above code, we will mutate the input if x is a CPU tensor, while we won't if it is a CUDA tensor. As a result, I think it makes sense to always decompose aten.to to aten._to_copy and ban mutation on the output of aten.to.
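The device dependence above can be demonstrated with a small pure-Python model (a hypothetical `ToyTensor`, not real PyTorch): same-device `.to()` aliases, cross-device `.to()` copies.

```python
class ToyTensor:
    """Toy model of tensor aliasing semantics; not real torch."""
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        if device == self.device:
            return self                            # alias: mutations are shared
        return ToyTensor(list(self.data), device)  # copy: independent storage

    def add_(self, v):
        self.data = [x + v for x in self.data]
        return self

x_cpu = ToyTensor([1.0], device="cpu")
x_cpu.to("cpu").add_(5)
assert x_cpu.data == [6.0]   # CPU input was mutated through the alias

x_cuda = ToyTensor([1.0], device="cuda")
x_cuda.to("cpu").add_(5)
assert x_cuda.data == [1.0]  # CUDA input untouched: .to() made a copy
```

Under this model, the example mutates its input only when x is a CPU tensor, which is why always decomposing to aten._to_copy and banning mutation of the output yields device-independent behaviour.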

@zou3519
Contributor

zou3519 commented Apr 11, 2024

Banning the mutation seems good as long as we can raise an error on it. The tricky thing is being sure that we will actually raise the exception if a mutation happens.


@zhxchen17
Contributor Author

@pytorchbot rebase


@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Tried to rebase and push PR #123732, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main


Summary:
Pull Request resolved: pytorch#123732


Test Plan: buck test mode/opt caffe2/test:test_export -- -r device_to

Reviewed By: tugsbayasgalan, angelayi

Differential Revision: D55969674
@zhxchen17
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

alat-rights pushed a commit to alat-rights/pytorch that referenced this pull request Apr 26, 2024

Pull Request resolved: pytorch#123732
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024

Pull Request resolved: #123732
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
Labels: ciflow/trunk, fb-exported, Merged, topic: not user facing

7 participants