diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 777988ddb44..331b59add35 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -36,6 +36,9 @@ * Added `ops.functions.assert_valid` for checking if an `Operator` class is defined correctly. [(#4764)](https://github.com/PennyLaneAI/pennylane/pull/4764) +* Simplified the logic for re-arranging states before returning. + [(#4817)](https://github.com/PennyLaneAI/pennylane/pull/4817) +

Breaking changes 💔

* The `prep` keyword argument has been removed from `QuantumScript` and `QuantumTape`. @@ -119,6 +122,7 @@ * Parametrized circuits whose operators do not act on all wires return pennylane tensors as expected, instead of numpy arrays. [(#4811)](https://github.com/PennyLaneAI/pennylane/pull/4811) + [(#4817)](https://github.com/PennyLaneAI/pennylane/pull/4817)

Contributors ✍️

diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 7a4bb6c5732..670b2a77a96 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -18,7 +18,6 @@ import pennylane as qml from pennylane.typing import Result -from pennylane.wires import Wires from .initialize_state import create_initial_state from .apply_operation import apply_operation @@ -61,44 +60,6 @@ def __init__(self, shots=None): self._frozen = True -def expand_state_over_wires(state, state_wires, all_wires, is_state_batched): - """ - Expand and re-order a state given some initial and target wire orders, setting - all additional wires to the 0 state. - - Args: - state (~pennylane.typing.TensorLike): The state to re-order and expand - state_wires (.Wires): The wire order of the inputted state - all_wires (.Wires): The desired wire order - is_state_batched (bool): Whether the state has a batch dimension or not - - Returns: - TensorLike: The state in the new desired size and order - """ - interface = qml.math.get_interface(state) - pad_width = 2 ** len(all_wires) - 2 ** len(state_wires) - pad = (pad_width, 0) if interface == "torch" else (0, pad_width) - shape = (2,) * len(all_wires) - if is_state_batched: - pad = ((0, 0), pad) - batch_size = qml.math.shape(state)[0] - shape = (batch_size,) + shape - state = qml.math.reshape(state, (batch_size, -1)) - else: - pad = (pad,) - state = qml.math.flatten(state) - - state = qml.math.pad(state, pad, mode="constant", like=interface) - state = qml.math.reshape(state, shape) - - # re-order - new_wire_order = Wires.unique_wires([all_wires, state_wires]) + state_wires - desired_axes = [new_wire_order.index(w) for w in all_wires] - if is_state_batched: - desired_axes = [0] + [i + 1 for i in desired_axes] - return qml.math.transpose(state, desired_axes) - - def _postselection_postprocess(state, is_state_batched, shots): """Update state after projector is applied.""" if is_state_batched: @@ -170,13 +131,10 @@ def get_final_state(circuit, debugger=None, interface=None): # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim is_state_batched = is_state_batched or op.batch_size is not None - if set(circuit.op_wires) < set(circuit.wires): - state = expand_state_over_wires( - state, - Wires(range(len(circuit.op_wires))), - Wires(range(circuit.num_wires)), - is_state_batched, - ) + for _ in range(len(circuit.wires) - len(circuit.op_wires)): + # if any measured wires are not operated on, we pad the state with zeros. + # We know they belong at the end because the circuit is in standard wire-order + state = qml.math.stack([state, qml.math.zeros_like(state)], axis=-1) return state, is_state_batched diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py index 78dda1cc47f..4793213a47e 100644 --- a/pennylane/measurements/state.py +++ b/pennylane/measurements/state.py @@ -160,33 +160,22 @@ def process_state(self, state: Sequence[complex], wire_order: Wires): if not wires or wire_order == wires: return qml.math.cast(state, "complex128") - if not wires.contains_wires(wire_order): + if set(wires) != set(wire_order): raise WireError( - f"Unexpected wires {set(wire_order) - set(wires)} found in wire order. Expected wire order to be a subset of {wires}" + f"Unexpected unique wires {Wires.unique_wires([wires, wire_order])} found. " + f"Expected wire order {wire_order} to be a rearrangement of {wires}" ) - # pad with zeros, put existing wires last - is_state_batched = qml.math.ndim(state) == 2 - pad_width = 2 ** len(wires) - 2 ** len(wire_order) - pad = (pad_width, 0) if qml.math.get_interface(state) == "torch" else (0, pad_width) shape = (2,) * len(wires) flat_shape = (2 ** len(wires),) - if is_state_batched: + desired_axes = [wire_order.index(w) for w in wires] + if qml.math.ndim(state) == 2: # batched state batch_size = qml.math.shape(state)[0] - pad = ((0, 0), pad) shape = (batch_size,) + shape flat_shape = (batch_size,) + flat_shape - else: - pad = (pad,) + desired_axes = [0] + [i + 1 for i in desired_axes] - state = qml.math.pad(state, pad, mode="constant") state = qml.math.reshape(state, shape) - - # re-order - new_wire_order = Wires.unique_wires([wires, wire_order]) + wire_order - desired_axes = [new_wire_order.index(w) for w in wires] - if is_state_batched: - desired_axes = [0] + [i + 1 for i in desired_axes] state = qml.math.transpose(state, desired_axes) state = qml.math.reshape(state, flat_shape) return qml.math.cast(state, "complex128") diff --git a/pennylane/tape/qscript.py b/pennylane/tape/qscript.py index 695ad65c8b4..e18f5e11a33 100644 --- a/pennylane/tape/qscript.py +++ b/pennylane/tape/qscript.py @@ -1187,19 +1187,19 @@ def map_to_standard_wires(self): **Example:** >>> circuit = qml.tape.QuantumScript([qml.PauliX("a")], [qml.expval(qml.PauliZ("b"))]) - >>> map_circuit_to_standard_wires(circuit).circuit + >>> circuit.map_to_standard_wires().circuit [PauliX(wires=[0]), expval(PauliZ(wires=[1]))] If any measured wires are not in any operations, they will be mapped last: >>> circuit = qml.tape.QuantumScript([qml.PauliX(1)], [qml.probs(wires=[0, 1])]) - >>> qml.devices.qubit.map_circuit_to_standard_wires(circuit).circuit + >>> circuit.map_to_standard_wires().circuit [PauliX(wires=[0]), probs(wires=[1, 0])] If no wire-mapping is needed, then the returned circuit *is* the inputted circuit: >>> circuit = qml.tape.QuantumScript([qml.PauliX(0)], [qml.expval(qml.PauliZ(1))]) - >>> qml.devices.qubit.map_circuit_to_standard_wires(circuit) is circuit + >>> circuit.map_to_standard_wires() is circuit True """ diff --git a/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py index c193dd3feeb..b6dff017aa3 100644 --- a/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_torch_default_qubit_2.py @@ -30,9 +30,9 @@ @pytest.fixture(autouse=True) def run_before_and_after_tests(): - torch.set_default_tensor_type(torch.DoubleTensor) + torch.set_default_dtype(torch.float64) yield - torch.set_default_tensor_type(torch.FloatTensor) + torch.set_default_dtype(torch.float32) # pylint: disable=too-few-public-methods diff --git a/tests/measurements/test_state.py b/tests/measurements/test_state.py index ab306b5b01e..8a726072904 100644 --- a/tests/measurements/test_state.py +++ b/tests/measurements/test_state.py @@ -72,119 +72,62 @@ def test_reorder_state(self, interface, wires, wire_order): assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36])) assert qml.math.get_interface(ket) == interface - @pytest.mark.parametrize( - "mp_wires, expected_state", - [ - ([0, 1, 2], [1, 0, 2, 0, 3, 0, 4, 0]), - ([2, 0, 1], [1, 2, 3, 4, 0, 0, 0, 0]), - ([1, 0, 2], [1, 0, 3, 0, 2, 0, 4, 0]), - ([1, 2, 0], [1, 3, 0, 0, 2, 4, 0, 0]), - ], - ) - @pytest.mark.parametrize("custom_wire_labels", [False, True]) - def test_expand_state_over_wires(self, mp_wires, expected_state, custom_wire_labels): - """Test the expanded state is correctly ordered with extra wires being zero.""" - wire_order = [0, 1] - if custom_wire_labels: - # non-lexicographical-ordered - wire_map = {0: "b", 1: "c", 2: "a"} - mp_wires = [wire_map[w] for w in mp_wires] - wire_order = ["b", "c"] - mp = StateMP(wires=mp_wires) - ket = np.arange(1, 5) - result = mp.process_state(ket, wire_order=Wires(wire_order)) - assert qml.math.get_dtype_name(result) == "complex128" - assert np.array_equal(result, expected_state) - @pytest.mark.all_interfaces @pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"]) - def test_expand_state_all_interfaces(self, interface): - """Test that expanding the state over wires preserves interface.""" - mp = StateMP(wires=[4, 2, 0, 1]) - ket = qml.math.array([0.48j, 0.48, -0.64j, 0.36], like=interface) - result = mp.process_state(ket, wire_order=Wires([1, 2])) - reshaped = qml.math.reshape(result, (2, 2, 2, 2)) - assert qml.math.all(reshaped[1, :, 1, :] == 0) - assert qml.math.allclose(reshaped[0, :, 0, :], np.array([[0.48j, -0.64j], [0.48, 0.36]])) - if interface != "autograd": - # autograd.numpy.pad drops pennylane tensor for some reason - assert qml.math.get_interface(result) == interface + def test_reorder_state_three_wires(self, interface): + """Test that a 3-qubit state can be re-ordered.""" + input_wires = Wires([2, 0, 1]) + output_wires = Wires([1, 2, 0]) + ket = qml.math.arange(8, like=interface) + expected = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + result = StateMP(wires=output_wires).process_state(ket, wire_order=Wires(input_wires)) + assert qml.math.allclose(result, expected) + assert qml.math.get_interface(ket) == interface @pytest.mark.all_interfaces @pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"]) - def test_expand_state_batched_all_interfaces(self, interface): - """Test that expanding the state over wires preserves interface.""" - mp = StateMP(wires=[4, 2, 0, 1]) - ket = qml.math.array( - [ - [0.48j, 0.48, -0.64j, 0.36], - [0.3, 0.4, 0.5, 1 / np.sqrt(2)], - [-0.3, -0.4, -0.5, -1 / np.sqrt(2)], - ], - like=interface, - ) - result = mp.process_state(ket, wire_order=Wires([1, 2])) - assert qml.math.shape(result) == (3, 16) - reshaped = qml.math.reshape(result, (3, 2, 2, 2, 2)) - assert qml.math.all(reshaped[:, 1, :, 1, :] == 0) - assert qml.math.allclose( - reshaped[:, 0, :, 0, :], - np.array( - [ - [[0.48j, -0.64j], [0.48, 0.36]], - [[0.3, 0.5], [0.4, 1 / np.sqrt(2)]], - [[-0.3, -0.5], [-0.4, -1 / np.sqrt(2)]], - ], - ), - ) - if interface != "autograd": - # autograd.numpy.pad drops pennylane tensor for some reason - assert qml.math.get_interface(result) == interface + def test_reorder_state_three_wires_batched(self, interface): + """Test that a batched, 3-qubit state can be re-ordered.""" + input_wires = Wires([2, 0, 1]) + output_wires = Wires([1, 2, 0]) + ket = qml.math.reshape(qml.math.arange(16, like=interface), (2, 8)) + expected = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + expected = np.array([expected, expected + 8]) + result = StateMP(wires=output_wires).process_state(ket, wire_order=Wires(input_wires)) + assert qml.math.allclose(result, expected) + assert qml.math.shape(result) == (2, 8) + assert qml.math.get_interface(ket) == interface @pytest.mark.jax - @pytest.mark.parametrize( - "wires,expected", - [ - ([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])), - ([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])), - ], - ) - def test_state_jax_jit(self, wires, expected): - """Test that re-ordering and expanding works with jax-jit.""" + def test_state_jax_jit(self): + """Test that re-ordering works with jax-jit.""" import jax @jax.jit def get_state(ket): - return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1])) + return StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([0, 1])) result = get_state(jax.numpy.array([0.48j, 0.48, -0.64j, 0.36])) - assert qml.math.allclose(result, expected) + assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36])) assert isinstance(result, jax.Array) @pytest.mark.tf - @pytest.mark.parametrize( - "wires,expected", - [ - ([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])), - ([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])), - ], - ) - def test_state_tf_function(self, wires, expected): - """Test that re-ordering and expanding works with tf.function.""" + def test_state_tf_function(self): + """Test that re-ordering works with tf.function.""" import tensorflow as tf @tf.function def get_state(ket): - return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1])) + return StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([0, 1])) result = get_state(tf.Variable([0.48j, 0.48, -0.64j, 0.36])) - assert qml.math.allclose(result, expected) + assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36])) assert isinstance(result, tf.Tensor) def test_wire_ordering_error(self): """Test that a wire order error is raised when unknown wires are given.""" - with pytest.raises(WireError, match=r"Unexpected wires \{2\} found in wire order"): - StateMP(wires=[0, 1]).process_state([1, 0], wire_order=[2]) + with pytest.raises(WireError, match=r"Unexpected unique wires found"): + StateMP(wires=[0, 1]).process_state([1, 0], wire_order=Wires(2)) class TestDensityMatrixMP: