Allow device to configure conversion to numpy and use of pure_callback
#6788
base: master
Conversation
Hello. You may have forgotten to update the changelog!
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff            @@
##           master    #6788   +/-   ##
=======================================
  Coverage   99.60%   99.60%
=======================================
  Files         476      476
  Lines       45222    45248      +26
=======================================
+ Hits        45045    45071      +26
  Misses        177      177

View full report in Codecov by Sentry.
doc/releases/changelog-dev.md
Outdated
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
  is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
  on `default.qubit` can now be jitted end-to-end, even with parameter shift.
  [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
Suggested change:

* Devices can now configure whether or not the data is converted to numpy, enabling `jax.pure_callback` to be used, by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions on `default.qubit` can now be jitted end-to-end, leading to performance improvements, even with parameter shift.
  [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
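To make the entry concrete, here is a minimal sketch of the kind of end-to-end jitted, finite-shot execution it describes (the circuit, shot count, and PRNG-key seeding are illustrative assumptions rather than code from this PR):

```python
import jax
import jax.numpy as jnp
import pennylane as qml

# Finite-shot device; seeding with a JAX PRNG key is assumed here so that
# sampling stays traceable inside jax.jit.
dev = qml.device("default.qubit", shots=1000, seed=jax.random.PRNGKey(0))

@qml.qnode(dev, interface="jax-jit", diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# With the device opting out of numpy conversion, both the forward pass and
# the parameter-shift gradient can be traced end-to-end by jax.jit.
value = jax.jit(circuit)(jnp.array(0.5))
grad = jax.jit(jax.grad(circuit))(jnp.array(0.5))
```

Without the new flag, the execution would instead go through a `jax.pure_callback` boundary rather than being traced directly.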
I think the new option sounds a little confusing to me. What was the problem with the original note? I'll see if I can find a better way to phrase it. Maybe just leave out the reference to `jax.pure_callback`?
For this improvement, I recommend swapping the sentences around. We should always assume a reader is less likely to keep reading: what is the first thing we want them to take away?
In this case, the first sentence should likely speak to what the user-facing improvement is (and how a user can do it, if relevant), followed by the actual implementation change.
Just some comments on my initial pass through.
LGTM
def test_convert_to_numpy_with_adjoint(self):
    """Test that we will convert to numpy with adjoint."""
    config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
    dev = qml.device("default.qubit")
    processed = dev.setup_execution_config(config)
    assert processed.convert_to_numpy
@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
    """Test that other interfaces are still converted to numpy."""
    config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
    dev = qml.device("default.qubit")
    processed = dev.setup_execution_config(config)
    assert processed.convert_to_numpy
Can you include "jax" as an interface in the testing?
Also, I'm curious if converting to numpy with adjoint has negative effects on performance.
I do think we could make adjoint jittable and get some nice speed boosts, but I think that would need to be a follow-on task.
Also, the jax interface is tested in `test_not_convert_to_numpy_with_jax`.
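For context, a hedged sketch of what a test along those lines might look like, assuming it mirrors the adjoint tests above (the body and the exact condition checked are assumptions, not the actual test from this PR):

```python
def test_not_convert_to_numpy_with_jax(self):
    """Test that a jax-jit execution does not convert to numpy when adjoint is not used."""
    config = qml.devices.ExecutionConfig(
        gradient_method="parameter-shift", interface="jax-jit"
    )
    dev = qml.device("default.qubit")
    processed = dev.setup_execution_config(config)
    # Assumed behaviour: default.qubit keeps jax data (no pure_callback) for
    # jax-jit executions that do not require the adjoint pipeline.
    assert not processed.convert_to_numpy
```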
@@ -417,7 +417,7 @@ def circuit(state):

     @pytest.mark.jax
-    @pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
+    @pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)])
This is a lot of shots...
If we're using a seed, maybe it's worth reducing the number of shots.
Given that the tolerance is still almost too high... the problem is that this test is now doing finite-shot finite differences with float32. I'm not sure we can get any accuracy out of that.
Do we even need a finite shot test here?
Definitely seems like the wrong file for this type of test...
I'm just going to remove the finite shot version.
Context:
While we have logic for sampling with jax, it does not integrate very well into the workflow. While you can technically set `diff_method=None` right now and jit the execution end-to-end, trying to jit `diff_method=None` will cause incomprehensible error messages on non-DQ devices. We want to forbid differentiation with `diff_method=None`, but keep a way to jit a finite-shot execution.

Description of the Change:
In order to allow jitting finite-shot executions, we need a way for the device to be able to configure whether or not the data is converted to numpy. To do so, we simply add another property to the `ExecutionConfig`, `convert_to_numpy`. If `False`, then we will not use a `pure_callback` to convert the parameters to numpy. If `True`, we use a `pure_callback` and convert the parameters to numpy.

Benefits:
Speed-ups due to being able to jit the entire execution.
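To make the change concrete, here is a minimal sketch of how a device could opt out of the conversion via the new property (the `MyDevice` subclass below is a hypothetical illustration, not code from this PR):

```python
from dataclasses import replace

from pennylane.devices import DefaultQubit, ExecutionConfig


class MyDevice(DefaultQubit):
    """Hypothetical device that keeps parameters as jax arrays."""

    def setup_execution_config(self, config: ExecutionConfig) -> ExecutionConfig:
        config = super().setup_execution_config(config)
        # With convert_to_numpy=False, the workflow is assumed to skip the
        # jax.pure_callback boundary so the execution stays jit-traceable.
        return replace(config, convert_to_numpy=False)
```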
Possible Drawbacks:
`ExecutionConfig` gets an additional property, making it more complicated.

Related GitHub Issues:
Fixes #6054 Fixes #3259 Blocks #6770