Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow device to configure conversion to numpy and use of pure_callback #6788

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jan 8, 2025

Context:

While we have logic for sampling with jax, it does not really integrate very well into the workflow. While you can technically set diff_method=None right now and jit the execution end-to-end, trying to jit diff_method=None will cause incomprehensible error messages on non-DQ devices.

We want to forbid differentiation diff_method=None, but keep a way to jit a finite shot execution.

Description of the Change:

In order to allow jitting finite shot executions, we need a way for the device to be able to configure whether or not the data is converted to numpy. To do so, we simply add another property to the ExecutionConfig, convert_to_numpy. If False, then we will not use a pure_callback to convert the parameters to numpy. If True, we use a pure_callback and convert the parameters to numpy.

Benefits:

Speed ups due to being able to jit the entire execution.

image

Possible Drawbacks:

ExecutionConfig gets an addtional property, making it more complicated.

Related GitHub Issues:

Fixes #6054 Fixes #3259 Blocks #6770

Copy link
Contributor

github-actions bot commented Jan 8, 2025

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

Copy link

codecov bot commented Jan 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.60%. Comparing base (9350cb2) to head (9c7ea17).
Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6788   +/-   ##
=======================================
  Coverage   99.60%   99.60%           
=======================================
  Files         476      476           
  Lines       45222    45248   +26     
=======================================
+ Hits        45045    45071   +26     
  Misses        177      177           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

tests/test_qnode.py Outdated Show resolved Hide resolved
Comment on lines 9 to 13
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
* Devices can now configure whether or not the data is converted to numpy enabling `jax.pure_callback` to be used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions on `default.qubit` can now be jitted end-to-end leading to performance improvements, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the new option sounds a little confusing to me. What was the problem with the original note? I'll see if I can find a better way to phrase it. Maybe just leave out the reference to jax.pure_callback?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this improvement, I recommend swapping the sentences around. We should always assume a reader is less likely to keep reading -- what is the first thing we want them to takeaway?

In this case, the first sentence should likely speak to what the user-facing improvement is (and how a user can do it, if relevant), followed by the actual implementation change.

Copy link
Contributor

@andrijapau andrijapau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some comments on my initial pass through.

pennylane/devices/default_qubit.py Outdated Show resolved Hide resolved
Copy link
Contributor

@astralcai astralcai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +155 to +168
def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
"""Test that other interfaces are still converted to numpy."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include "jax" as an interface in the testing?

Also, I'm curious if converting to numpy with adjoint has negative effects on performance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think we could make adjoint jittable and get some nice speed boosts, but i think that would need to be a follow on task.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the jax interface is tested in test_not_convert_to_numpy_with_jax.

@@ -417,7 +417,7 @@ def circuit(state):


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of shots...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're using a seed, maybe it's worth reducing the number of shots.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the tolerance is still almost to high... the problem is that this test is now doing finite shot finite differences with float32. Im not sure we can get any accuracy out of that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we even need a finite shot test here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely seems like the wrong file for this type of test...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just going to remove the finite shot version.

tests/workflow/interfaces/execute/test_jax.py Show resolved Hide resolved
@albi3ro albi3ro requested a review from andrijapau January 13, 2025 17:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
5 participants