Skip to content

Commit

Permalink
simplify state reordering logic (#4817)
Browse files Browse the repository at this point in the history
**Context:**
I wrote the same function twice, differing only by state flattening, to
get the DQ upgrade done. It's starting to cause trouble.

**Description of the Change:**
Greatly simplified the state re-arrangement logic. There used to be a
whole mess of things happening, but now things are much more
straightforward.
1. `simulate` first puts things in our "standard" order, and this means
that if any measured wires are not also operator wires, they are put to
the _end_ of our tape wires. Therefore, for each measured-only wire, we
just have to stack a `zeros_like(state)` to the last axis of our final
state! `simulate` never tried to transpose wires back to a different
ordering, so that was always wasted work.
2. `StateMP.process_state` _always_ receives the full state, and never
needed to pad. No other device has done this optimization (the function
used to literally just `return state` before DQ2 migration), and
`simulate` already ensures that the final state has all wires in it -
they just might be out of order. The only thing we might need from
`process_state` is a transposition to the correct wire order. The
inputted `wire_order` _should_ always be `range(len(wires))`, but
whatever, we don't need to assume that.

I'll paint a picture for a normal scenario:

```python
@qml.qnode(qml.device("default.qubit", wires=3))
def circuit(x):
    qml.RX(x, 0)
    qml.CNOT([0, 2])
    return qml.state()
```

What happens with this QNode?
1. Device preprocessing sticks the device wires (`[0, 1, 2]`) onto the
`StateMP`
2. `simulate` maps the wires to our standard order. I'll demonstrate
(with `probs` so I can specify wires):

```pycon
>>> qs = qml.tape.QuantumScript([qml.RX(1.1, 0), qml.CNOT([0, 2])], [qml.probs(wires=[0, 1, 2])])
>>> qs.map_to_standard_wires().circuit
[RX(1.1, wires=[0]), CNOT(wires=[0, 1]), probs(wires=[0, 2, 1])]
```

3. Operate on the 2-qubit state, then stack another `[[0, 0], [0, 0]]`
on the end of it (wire "1")
4. `StateMP(wires=[0, 1, 2]).process_state(state, wire_order=[0, 2, 1])`
transposes the result to the correct order

I also changed the torch tests to stop using a deprecated setter for
default float types.

**Benefits:**
Duplicate code is cleaned up, existing code is simplified, no
unnecessary call to transpose.

**Possible Drawbacks:**
- Have to call `qml.math.stack` for every wire that was not operated on.
Hopefully this is usually not a lot, and it's not that costly anyway
- functions now do less than they used to (I see this as a perk - they
now do _exactly_ what they're supposed to)
  • Loading branch information
timmysilv authored Nov 16, 2023
1 parent 93a8716 commit 47e74e1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 155 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
* Added `ops.functions.assert_valid` for checking if an `Operator` class is defined correctly.
[(#4764)](https://github.com/PennyLaneAI/pennylane/pull/4764)

* Simplified the logic for re-arranging states before returning.
[(#4817)](https://github.com/PennyLaneAI/pennylane/pull/4817)

<h3>Breaking changes 💔</h3>

* The `prep` keyword argument has been removed from `QuantumScript` and `QuantumTape`.
Expand Down Expand Up @@ -119,6 +122,7 @@
* Parametrized circuits whose operators do not act on all wires return pennylane tensors as
expected, instead of numpy arrays.
[(#4811)](https://github.com/PennyLaneAI/pennylane/pull/4811)
[(#4817)](https://github.com/PennyLaneAI/pennylane/pull/4817)

<h3>Contributors ✍️</h3>

Expand Down
50 changes: 4 additions & 46 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import pennylane as qml
from pennylane.typing import Result
from pennylane.wires import Wires

from .initialize_state import create_initial_state
from .apply_operation import apply_operation
Expand Down Expand Up @@ -61,44 +60,6 @@ def __init__(self, shots=None):
self._frozen = True


def expand_state_over_wires(state, state_wires, all_wires, is_state_batched):
"""
Expand and re-order a state given some initial and target wire orders, setting
all additional wires to the 0 state.
Args:
state (~pennylane.typing.TensorLike): The state to re-order and expand
state_wires (.Wires): The wire order of the inputted state
all_wires (.Wires): The desired wire order
is_state_batched (bool): Whether the state has a batch dimension or not
Returns:
TensorLike: The state in the new desired size and order
"""
interface = qml.math.get_interface(state)
pad_width = 2 ** len(all_wires) - 2 ** len(state_wires)
pad = (pad_width, 0) if interface == "torch" else (0, pad_width)
shape = (2,) * len(all_wires)
if is_state_batched:
pad = ((0, 0), pad)
batch_size = qml.math.shape(state)[0]
shape = (batch_size,) + shape
state = qml.math.reshape(state, (batch_size, -1))
else:
pad = (pad,)
state = qml.math.flatten(state)

state = qml.math.pad(state, pad, mode="constant", like=interface)
state = qml.math.reshape(state, shape)

# re-order
new_wire_order = Wires.unique_wires([all_wires, state_wires]) + state_wires
desired_axes = [new_wire_order.index(w) for w in all_wires]
if is_state_batched:
desired_axes = [0] + [i + 1 for i in desired_axes]
return qml.math.transpose(state, desired_axes)


def _postselection_postprocess(state, is_state_batched, shots):
"""Update state after projector is applied."""
if is_state_batched:
Expand Down Expand Up @@ -170,13 +131,10 @@ def get_final_state(circuit, debugger=None, interface=None):
# new state is batched if i) the old state is batched, or ii) the new op adds a batch dim
is_state_batched = is_state_batched or op.batch_size is not None

if set(circuit.op_wires) < set(circuit.wires):
state = expand_state_over_wires(
state,
Wires(range(len(circuit.op_wires))),
Wires(range(circuit.num_wires)),
is_state_batched,
)
for _ in range(len(circuit.wires) - len(circuit.op_wires)):
# if any measured wires are not operated on, we pad the state with zeros.
# We know they belong at the end because the circuit is in standard wire-order
state = qml.math.stack([state, qml.math.zeros_like(state)], axis=-1)

return state, is_state_batched

Expand Down
23 changes: 6 additions & 17 deletions pennylane/measurements/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,33 +160,22 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
if not wires or wire_order == wires:
return qml.math.cast(state, "complex128")

if not wires.contains_wires(wire_order):
if set(wires) != set(wire_order):
raise WireError(
f"Unexpected wires {set(wire_order) - set(wires)} found in wire order. Expected wire order to be a subset of {wires}"
f"Unexpected unique wires {Wires.unique_wires([wires, wire_order])} found. "
f"Expected wire order {wire_order} to be a rearrangement of {wires}"
)

# pad with zeros, put existing wires last
is_state_batched = qml.math.ndim(state) == 2
pad_width = 2 ** len(wires) - 2 ** len(wire_order)
pad = (pad_width, 0) if qml.math.get_interface(state) == "torch" else (0, pad_width)
shape = (2,) * len(wires)
flat_shape = (2 ** len(wires),)
if is_state_batched:
desired_axes = [wire_order.index(w) for w in wires]
if qml.math.ndim(state) == 2: # batched state
batch_size = qml.math.shape(state)[0]
pad = ((0, 0), pad)
shape = (batch_size,) + shape
flat_shape = (batch_size,) + flat_shape
else:
pad = (pad,)
desired_axes = [0] + [i + 1 for i in desired_axes]

state = qml.math.pad(state, pad, mode="constant")
state = qml.math.reshape(state, shape)

# re-order
new_wire_order = Wires.unique_wires([wires, wire_order]) + wire_order
desired_axes = [new_wire_order.index(w) for w in wires]
if is_state_batched:
desired_axes = [0] + [i + 1 for i in desired_axes]
state = qml.math.transpose(state, desired_axes)
state = qml.math.reshape(state, flat_shape)
return qml.math.cast(state, "complex128")
Expand Down
6 changes: 3 additions & 3 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,19 +1187,19 @@ def map_to_standard_wires(self):
**Example:**
>>> circuit = qml.tape.QuantumScript([qml.PauliX("a")], [qml.expval(qml.PauliZ("b"))])
>>> map_circuit_to_standard_wires(circuit).circuit
>>> circuit.map_to_standard_wires().circuit
[PauliX(wires=[0]), expval(PauliZ(wires=[1]))]
If any measured wires are not in any operations, they will be mapped last:
>>> circuit = qml.tape.QuantumScript([qml.PauliX(1)], [qml.probs(wires=[0, 1])])
>>> qml.devices.qubit.map_circuit_to_standard_wires(circuit).circuit
>>> circuit.map_to_standard_wires().circuit
[PauliX(wires=[0]), probs(wires=[1, 0])]
If no wire-mapping is needed, then the returned circuit *is* the inputted circuit:
>>> circuit = qml.tape.QuantumScript([qml.PauliX(0)], [qml.expval(qml.PauliZ(1))])
>>> qml.devices.qubit.map_circuit_to_standard_wires(circuit) is circuit
>>> circuit.map_to_standard_wires() is circuit
True
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@

@pytest.fixture(autouse=True)
def run_before_and_after_tests():
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.float64)
yield
torch.set_default_tensor_type(torch.FloatTensor)
torch.set_default_dtype(torch.float32)


# pylint: disable=too-few-public-methods
Expand Down
117 changes: 30 additions & 87 deletions tests/measurements/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,119 +72,62 @@ def test_reorder_state(self, interface, wires, wire_order):
assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert qml.math.get_interface(ket) == interface

@pytest.mark.parametrize(
"mp_wires, expected_state",
[
([0, 1, 2], [1, 0, 2, 0, 3, 0, 4, 0]),
([2, 0, 1], [1, 2, 3, 4, 0, 0, 0, 0]),
([1, 0, 2], [1, 0, 3, 0, 2, 0, 4, 0]),
([1, 2, 0], [1, 3, 0, 0, 2, 4, 0, 0]),
],
)
@pytest.mark.parametrize("custom_wire_labels", [False, True])
def test_expand_state_over_wires(self, mp_wires, expected_state, custom_wire_labels):
"""Test the expanded state is correctly ordered with extra wires being zero."""
wire_order = [0, 1]
if custom_wire_labels:
# non-lexicographical-ordered
wire_map = {0: "b", 1: "c", 2: "a"}
mp_wires = [wire_map[w] for w in mp_wires]
wire_order = ["b", "c"]
mp = StateMP(wires=mp_wires)
ket = np.arange(1, 5)
result = mp.process_state(ket, wire_order=Wires(wire_order))
assert qml.math.get_dtype_name(result) == "complex128"
assert np.array_equal(result, expected_state)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
def test_expand_state_all_interfaces(self, interface):
"""Test that expanding the state over wires preserves interface."""
mp = StateMP(wires=[4, 2, 0, 1])
ket = qml.math.array([0.48j, 0.48, -0.64j, 0.36], like=interface)
result = mp.process_state(ket, wire_order=Wires([1, 2]))
reshaped = qml.math.reshape(result, (2, 2, 2, 2))
assert qml.math.all(reshaped[1, :, 1, :] == 0)
assert qml.math.allclose(reshaped[0, :, 0, :], np.array([[0.48j, -0.64j], [0.48, 0.36]]))
if interface != "autograd":
# autograd.numpy.pad drops pennylane tensor for some reason
assert qml.math.get_interface(result) == interface
def test_reorder_state_three_wires(self, interface):
"""Test that a 3-qubit state can be re-ordered."""
input_wires = Wires([2, 0, 1])
output_wires = Wires([1, 2, 0])
ket = qml.math.arange(8, like=interface)
expected = np.array([0, 2, 4, 6, 1, 3, 5, 7])
result = StateMP(wires=output_wires).process_state(ket, wire_order=Wires(input_wires))
assert qml.math.allclose(result, expected)
assert qml.math.get_interface(ket) == interface

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
def test_expand_state_batched_all_interfaces(self, interface):
"""Test that expanding the state over wires preserves interface."""
mp = StateMP(wires=[4, 2, 0, 1])
ket = qml.math.array(
[
[0.48j, 0.48, -0.64j, 0.36],
[0.3, 0.4, 0.5, 1 / np.sqrt(2)],
[-0.3, -0.4, -0.5, -1 / np.sqrt(2)],
],
like=interface,
)
result = mp.process_state(ket, wire_order=Wires([1, 2]))
assert qml.math.shape(result) == (3, 16)
reshaped = qml.math.reshape(result, (3, 2, 2, 2, 2))
assert qml.math.all(reshaped[:, 1, :, 1, :] == 0)
assert qml.math.allclose(
reshaped[:, 0, :, 0, :],
np.array(
[
[[0.48j, -0.64j], [0.48, 0.36]],
[[0.3, 0.5], [0.4, 1 / np.sqrt(2)]],
[[-0.3, -0.5], [-0.4, -1 / np.sqrt(2)]],
],
),
)
if interface != "autograd":
# autograd.numpy.pad drops pennylane tensor for some reason
assert qml.math.get_interface(result) == interface
def test_reorder_state_three_wires_batched(self, interface):
"""Test that a batched, 3-qubit state can be re-ordered."""
input_wires = Wires([2, 0, 1])
output_wires = Wires([1, 2, 0])
ket = qml.math.reshape(qml.math.arange(16, like=interface), (2, 8))
expected = np.array([0, 2, 4, 6, 1, 3, 5, 7])
expected = np.array([expected, expected + 8])
result = StateMP(wires=output_wires).process_state(ket, wire_order=Wires(input_wires))
assert qml.math.allclose(result, expected)
assert qml.math.shape(result) == (2, 8)
assert qml.math.get_interface(ket) == interface

@pytest.mark.jax
@pytest.mark.parametrize(
"wires,expected",
[
([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])),
([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])),
],
)
def test_state_jax_jit(self, wires, expected):
"""Test that re-ordering and expanding works with jax-jit."""
def test_state_jax_jit(self):
"""Test that re-ordering works with jax-jit."""
import jax

@jax.jit
def get_state(ket):
return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1]))
return StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([0, 1]))

result = get_state(jax.numpy.array([0.48j, 0.48, -0.64j, 0.36]))
assert qml.math.allclose(result, expected)
assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert isinstance(result, jax.Array)

@pytest.mark.tf
@pytest.mark.parametrize(
"wires,expected",
[
([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])),
([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])),
],
)
def test_state_tf_function(self, wires, expected):
"""Test that re-ordering and expanding works with tf.function."""
def test_state_tf_function(self):
"""Test that re-ordering works with tf.function."""
import tensorflow as tf

@tf.function
def get_state(ket):
return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1]))
return StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([0, 1]))

result = get_state(tf.Variable([0.48j, 0.48, -0.64j, 0.36]))
assert qml.math.allclose(result, expected)
assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert isinstance(result, tf.Tensor)

def test_wire_ordering_error(self):
"""Test that a wire order error is raised when unknown wires are given."""
with pytest.raises(WireError, match=r"Unexpected wires \{2\} found in wire order"):
StateMP(wires=[0, 1]).process_state([1, 0], wire_order=[2])
with pytest.raises(WireError, match=r"Unexpected unique wires <Wires = \[0, 1, 2\]> found"):
StateMP(wires=[0, 1]).process_state([1, 0], wire_order=Wires(2))


class TestDensityMatrixMP:
Expand Down

0 comments on commit 47e74e1

Please sign in to comment.