diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 777988ddb44..331b59add35 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -36,6 +36,9 @@
* Added `ops.functions.assert_valid` for checking if an `Operator` class is defined correctly.
[(#4764)](https://github.com/PennyLaneAI/pennylane/pull/4764)
+* Simplified the logic for re-arranging states before returning.
+ [(#4817)](https://github.com/PennyLaneAI/pennylane/pull/4817)
+
Breaking changes 💔
* The `prep` keyword argument has been removed from `QuantumScript` and `QuantumTape`.
@@ -119,6 +122,7 @@
* Parametrized circuits whose operators do not act on all wires return pennylane tensors as
expected, instead of numpy arrays.
[(#4811)](https://github.com/PennyLaneAI/pennylane/pull/4811)
+ [(#4817)](https://github.com/PennyLaneAI/pennylane/pull/4817)
Contributors ✍️
diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py
index 7a4bb6c5732..670b2a77a96 100644
--- a/pennylane/devices/qubit/simulate.py
+++ b/pennylane/devices/qubit/simulate.py
@@ -18,7 +18,6 @@
import pennylane as qml
from pennylane.typing import Result
-from pennylane.wires import Wires
from .initialize_state import create_initial_state
from .apply_operation import apply_operation
@@ -61,44 +60,6 @@ def __init__(self, shots=None):
self._frozen = True
-def expand_state_over_wires(state, state_wires, all_wires, is_state_batched):
- """
- Expand and re-order a state given some initial and target wire orders, setting
- all additional wires to the 0 state.
-
- Args:
- state (~pennylane.typing.TensorLike): The state to re-order and expand
- state_wires (.Wires): The wire order of the inputted state
- all_wires (.Wires): The desired wire order
- is_state_batched (bool): Whether the state has a batch dimension or not
-
- Returns:
- TensorLike: The state in the new desired size and order
- """
- interface = qml.math.get_interface(state)
- pad_width = 2 ** len(all_wires) - 2 ** len(state_wires)
- pad = (pad_width, 0) if interface == "torch" else (0, pad_width)
- shape = (2,) * len(all_wires)
- if is_state_batched:
- pad = ((0, 0), pad)
- batch_size = qml.math.shape(state)[0]
- shape = (batch_size,) + shape
- state = qml.math.reshape(state, (batch_size, -1))
- else:
- pad = (pad,)
- state = qml.math.flatten(state)
-
- state = qml.math.pad(state, pad, mode="constant", like=interface)
- state = qml.math.reshape(state, shape)
-
- # re-order
- new_wire_order = Wires.unique_wires([all_wires, state_wires]) + state_wires
- desired_axes = [new_wire_order.index(w) for w in all_wires]
- if is_state_batched:
- desired_axes = [0] + [i + 1 for i in desired_axes]
- return qml.math.transpose(state, desired_axes)
-
-
def _postselection_postprocess(state, is_state_batched, shots):
"""Update state after projector is applied."""
if is_state_batched:
@@ -170,13 +131,10 @@ def get_final_state(circuit, debugger=None, interface=None):
# new state is batched if i) the old state is batched, or ii) the new op adds a batch dim
is_state_batched = is_state_batched or op.batch_size is not None
- if set(circuit.op_wires) < set(circuit.wires):
- state = expand_state_over_wires(
- state,
- Wires(range(len(circuit.op_wires))),
- Wires(range(circuit.num_wires)),
- is_state_batched,
- )
+ for _ in range(len(circuit.wires) - len(circuit.op_wires)):
+ # if any measured wires are not operated on, we pad the state with zeros.
+ # We know they belong at the end because the circuit is in standard wire-order
+ state = qml.math.stack([state, qml.math.zeros_like(state)], axis=-1)
return state, is_state_batched
diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py
index 78dda1cc47f..4793213a47e 100644
--- a/pennylane/measurements/state.py
+++ b/pennylane/measurements/state.py
@@ -160,33 +160,22 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
if not wires or wire_order == wires:
return qml.math.cast(state, "complex128")
- if not wires.contains_wires(wire_order):
+ if set(wires) != set(wire_order):
raise WireError(
- f"Unexpected wires {set(wire_order) - set(wires)} found in wire order. Expected wire order to be a subset of {wires}"
+ f"Unexpected unique wires {Wires.unique_wires([wires, wire_order])} found. "
+ f"Expected wire order {wire_order} to be a rearrangement of {wires}"
)
- # pad with zeros, put existing wires last
- is_state_batched = qml.math.ndim(state) == 2
- pad_width = 2 ** len(wires) - 2 ** len(wire_order)
- pad = (pad_width, 0) if qml.math.get_interface(state) == "torch" else (0, pad_width)
shape = (2,) * len(wires)
flat_shape = (2 ** len(wires),)
- if is_state_batched:
+ desired_axes = [wire_order.index(w) for w in wires]
+ if qml.math.ndim(state) == 2: # batched state
batch_size = qml.math.shape(state)[0]
- pad = ((0, 0), pad)
shape = (batch_size,) + shape
flat_shape = (batch_size,) + flat_shape
- else:
- pad = (pad,)
+ desired_axes = [0] + [i + 1 for i in desired_axes]
- state = qml.math.pad(state, pad, mode="constant")
state = qml.math.reshape(state, shape)
-
- # re-order
- new_wire_order = Wires.unique_wires([wires, wire_order]) + wire_order
- desired_axes = [new_wire_order.index(w) for w in wires]
- if is_state_batched:
- desired_axes = [0] + [i + 1 for i in desired_axes]
state = qml.math.transpose(state, desired_axes)
state = qml.math.reshape(state, flat_shape)
return qml.math.cast(state, "complex128")
diff --git a/pennylane/tape/qscript.py b/pennylane/tape/qscript.py
index 695ad65c8b4..e18f5e11a33 100644
--- a/pennylane/tape/qscript.py
+++ b/pennylane/tape/qscript.py
@@ -1187,19 +1187,19 @@ def map_to_standard_wires(self):
**Example:**
>>> circuit = qml.tape.QuantumScript([qml.PauliX("a")], [qml.expval(qml.PauliZ("b"))])
- >>> map_circuit_to_standard_wires(circuit).circuit
+ >>> circuit.map_to_standard_wires().circuit
[PauliX(wires=[0]), expval(PauliZ(wires=[1]))]
If any measured wires are not in any operations, they will be mapped last:
>>> circuit = qml.tape.QuantumScript([qml.PauliX(1)], [qml.probs(wires=[0, 1])])
- >>> qml.devices.qubit.map_circuit_to_standard_wires(circuit).circuit
+ >>> circuit.map_to_standard_wires().circuit
[PauliX(wires=[0]), probs(wires=[1, 0])]
If no wire-mapping is needed, then the returned circuit *is* the inputted circuit:
>>> circuit = qml.tape.QuantumScript([qml.PauliX(0)], [qml.expval(qml.PauliZ(1))])
- >>> qml.devices.qubit.map_circuit_to_standard_wires(circuit) is circuit
+ >>> circuit.map_to_standard_wires() is circuit
True
"""
diff --git a/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py
index c193dd3feeb..b6dff017aa3 100644
--- a/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py
+++ b/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py
@@ -30,9 +30,9 @@
@pytest.fixture(autouse=True)
def run_before_and_after_tests():
- torch.set_default_tensor_type(torch.DoubleTensor)
+ torch.set_default_dtype(torch.float64)
yield
- torch.set_default_tensor_type(torch.FloatTensor)
+ torch.set_default_dtype(torch.float32)
# pylint: disable=too-few-public-methods
diff --git a/tests/measurements/test_state.py b/tests/measurements/test_state.py
index ab306b5b01e..8a726072904 100644
--- a/tests/measurements/test_state.py
+++ b/tests/measurements/test_state.py
@@ -72,119 +72,62 @@ def test_reorder_state(self, interface, wires, wire_order):
assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert qml.math.get_interface(ket) == interface
- @pytest.mark.parametrize(
- "mp_wires, expected_state",
- [
- ([0, 1, 2], [1, 0, 2, 0, 3, 0, 4, 0]),
- ([2, 0, 1], [1, 2, 3, 4, 0, 0, 0, 0]),
- ([1, 0, 2], [1, 0, 3, 0, 2, 0, 4, 0]),
- ([1, 2, 0], [1, 3, 0, 0, 2, 4, 0, 0]),
- ],
- )
- @pytest.mark.parametrize("custom_wire_labels", [False, True])
- def test_expand_state_over_wires(self, mp_wires, expected_state, custom_wire_labels):
- """Test the expanded state is correctly ordered with extra wires being zero."""
- wire_order = [0, 1]
- if custom_wire_labels:
- # non-lexicographical-ordered
- wire_map = {0: "b", 1: "c", 2: "a"}
- mp_wires = [wire_map[w] for w in mp_wires]
- wire_order = ["b", "c"]
- mp = StateMP(wires=mp_wires)
- ket = np.arange(1, 5)
- result = mp.process_state(ket, wire_order=Wires(wire_order))
- assert qml.math.get_dtype_name(result) == "complex128"
- assert np.array_equal(result, expected_state)
-
@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
- def test_expand_state_all_interfaces(self, interface):
- """Test that expanding the state over wires preserves interface."""
- mp = StateMP(wires=[4, 2, 0, 1])
- ket = qml.math.array([0.48j, 0.48, -0.64j, 0.36], like=interface)
- result = mp.process_state(ket, wire_order=Wires([1, 2]))
- reshaped = qml.math.reshape(result, (2, 2, 2, 2))
- assert qml.math.all(reshaped[1, :, 1, :] == 0)
- assert qml.math.allclose(reshaped[0, :, 0, :], np.array([[0.48j, -0.64j], [0.48, 0.36]]))
- if interface != "autograd":
- # autograd.numpy.pad drops pennylane tensor for some reason
- assert qml.math.get_interface(result) == interface
+ def test_reorder_state_three_wires(self, interface):
+ """Test that a 3-qubit state can be re-ordered."""
+ input_wires = Wires([2, 0, 1])
+ output_wires = Wires([1, 2, 0])
+ ket = qml.math.arange(8, like=interface)
+ expected = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+ result = StateMP(wires=output_wires).process_state(ket, wire_order=Wires(input_wires))
+ assert qml.math.allclose(result, expected)
+ assert qml.math.get_interface(ket) == interface
@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
- def test_expand_state_batched_all_interfaces(self, interface):
- """Test that expanding the state over wires preserves interface."""
- mp = StateMP(wires=[4, 2, 0, 1])
- ket = qml.math.array(
- [
- [0.48j, 0.48, -0.64j, 0.36],
- [0.3, 0.4, 0.5, 1 / np.sqrt(2)],
- [-0.3, -0.4, -0.5, -1 / np.sqrt(2)],
- ],
- like=interface,
- )
- result = mp.process_state(ket, wire_order=Wires([1, 2]))
- assert qml.math.shape(result) == (3, 16)
- reshaped = qml.math.reshape(result, (3, 2, 2, 2, 2))
- assert qml.math.all(reshaped[:, 1, :, 1, :] == 0)
- assert qml.math.allclose(
- reshaped[:, 0, :, 0, :],
- np.array(
- [
- [[0.48j, -0.64j], [0.48, 0.36]],
- [[0.3, 0.5], [0.4, 1 / np.sqrt(2)]],
- [[-0.3, -0.5], [-0.4, -1 / np.sqrt(2)]],
- ],
- ),
- )
- if interface != "autograd":
- # autograd.numpy.pad drops pennylane tensor for some reason
- assert qml.math.get_interface(result) == interface
+ def test_reorder_state_three_wires_batched(self, interface):
+ """Test that a batched, 3-qubit state can be re-ordered."""
+ input_wires = Wires([2, 0, 1])
+ output_wires = Wires([1, 2, 0])
+ ket = qml.math.reshape(qml.math.arange(16, like=interface), (2, 8))
+ expected = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+ expected = np.array([expected, expected + 8])
+ result = StateMP(wires=output_wires).process_state(ket, wire_order=Wires(input_wires))
+ assert qml.math.allclose(result, expected)
+ assert qml.math.shape(result) == (2, 8)
+ assert qml.math.get_interface(ket) == interface
@pytest.mark.jax
- @pytest.mark.parametrize(
- "wires,expected",
- [
- ([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])),
- ([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])),
- ],
- )
- def test_state_jax_jit(self, wires, expected):
- """Test that re-ordering and expanding works with jax-jit."""
+ def test_state_jax_jit(self):
+ """Test that re-ordering works with jax-jit."""
import jax
@jax.jit
def get_state(ket):
- return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1]))
+ return StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([0, 1]))
result = get_state(jax.numpy.array([0.48j, 0.48, -0.64j, 0.36]))
- assert qml.math.allclose(result, expected)
+ assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert isinstance(result, jax.Array)
@pytest.mark.tf
- @pytest.mark.parametrize(
- "wires,expected",
- [
- ([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])),
- ([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])),
- ],
- )
- def test_state_tf_function(self, wires, expected):
- """Test that re-ordering and expanding works with tf.function."""
+ def test_state_tf_function(self):
+ """Test that re-ordering works with tf.function."""
import tensorflow as tf
@tf.function
def get_state(ket):
- return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1]))
+ return StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([0, 1]))
result = get_state(tf.Variable([0.48j, 0.48, -0.64j, 0.36]))
- assert qml.math.allclose(result, expected)
+ assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert isinstance(result, tf.Tensor)
def test_wire_ordering_error(self):
"""Test that a wire order error is raised when unknown wires are given."""
- with pytest.raises(WireError, match=r"Unexpected wires \{2\} found in wire order"):
- StateMP(wires=[0, 1]).process_state([1, 0], wire_order=[2])
+ with pytest.raises(WireError, match=r"Unexpected unique wires found"):
+ StateMP(wires=[0, 1]).process_state([1, 0], wire_order=Wires(2))
class TestDensityMatrixMP: