Skip to content

Commit

Permalink
Introduce observable key for Estimator (#8837)
Browse files Browse the repository at this point in the history
* Introduce observable key

* cast to SparsePauliOp
  • Loading branch information
ikkoham authored Oct 5, 2022
1 parent e4029f6 commit 875b646
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 16 deletions.
9 changes: 5 additions & 4 deletions qiskit/primitives/backend_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from qiskit.tools.monitor import job_monitor
from qiskit.transpiler import PassManager

from .utils import _circuit_key, init_observable
from .utils import _circuit_key, _observable_key, init_observable


def _run_circuits(
Expand Down Expand Up @@ -245,13 +245,14 @@ def _run(
self._parameters.append(circuit.parameters)
observable_indices = []
for observable in observables:
index = self._observable_ids.get(id(observable))
observable = init_observable(observable)
index = self._observable_ids.get(_observable_key(observable))
if index is not None:
observable_indices.append(index)
else:
observable_indices.append(len(self._observables))
self._observable_ids[id(observable)] = len(self._observables)
self._observables.append(init_observable(observable))
self._observable_ids[_observable_key(observable)] = len(self._observables)
self._observables.append(observable)
job = PrimitiveJob(
self._call, circuit_indices, observable_indices, parameter_values, **run_options
)
Expand Down
13 changes: 8 additions & 5 deletions qiskit/primitives/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
from qiskit.utils.deprecation import deprecate_arguments, deprecate_function

from .estimator_result import EstimatorResult
from .utils import _circuit_key
from .utils import _circuit_key, _observable_key, init_observable


class BaseEstimator(ABC):
Expand Down Expand Up @@ -149,7 +149,7 @@ def __init__(
# To guarantee that they exist as instance variable.
# With only dynamic set, the python will not know if the attribute exists or not.
self._circuit_ids: dict[tuple, int] = self._circuit_ids
self._observable_ids: dict[int, int] = self._observable_ids
self._observable_ids: dict[tuple, int] = self._observable_ids

if parameters is None:
self._parameters = [circ.parameters for circ in self._circuits]
Expand Down Expand Up @@ -190,9 +190,12 @@ def __new__(
self._observable_ids = {}
elif isinstance(observables, Iterable):
observables = copy(observables)
self._observable_ids = {id(observable): i for i, observable in enumerate(observables)}
self._observable_ids = {
_observable_key(init_observable(observable)): i
for i, observable in enumerate(observables)
}
else:
self._observable_ids = {id(observables): 0}
self._observable_ids = {_observable_key(init_observable(observables)): 0}
return self

@deprecate_function(
Expand Down Expand Up @@ -324,7 +327,7 @@ def __call__(
"initialize the session."
)
observables = [
self._observable_ids.get(id(observable))
self._observable_ids.get(_observable_key(observable))
if not isinstance(observable, (int, np.integer))
else observable
for observable in observables
Expand Down
15 changes: 11 additions & 4 deletions qiskit/primitives/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
from .base_estimator import BaseEstimator
from .estimator_result import EstimatorResult
from .primitive_job import PrimitiveJob
from .utils import _circuit_key, bound_circuit_to_instruction, init_circuit, init_observable
from .utils import (
_circuit_key,
_observable_key,
bound_circuit_to_instruction,
init_circuit,
init_observable,
)


class Estimator(BaseEstimator):
Expand Down Expand Up @@ -166,13 +172,14 @@ def _run(
self._parameters.append(circuit.parameters)
observable_indices = []
for observable in observables:
index = self._observable_ids.get(id(observable))
observable = init_observable(observable)
index = self._observable_ids.get(_observable_key(observable))
if index is not None:
observable_indices.append(index)
else:
observable_indices.append(len(self._observables))
self._observable_ids[id(observable)] = len(self._observables)
self._observables.append(init_observable(observable))
self._observable_ids[_observable_key(observable)] = len(self._observables)
self._observables.append(observable)
job = PrimitiveJob(
self._call, circuit_indices, observable_indices, parameter_values, **run_options
)
Expand Down
11 changes: 11 additions & 0 deletions qiskit/primitives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,17 @@ def _circuit_key(circuit: QuantumCircuit, functional: bool = True) -> tuple:
)


def _observable_key(observable: SparsePauliOp) -> tuple:
"""Private key function for SparsePauliOp.
Args:
observable: Input operator.
Returns:
Key for observables.
"""
return tuple(observable.to_list())


def bound_circuit_to_instruction(circuit: QuantumCircuit) -> Instruction:
"""Build an :class:`~qiskit.circuit.Instruction` object from
a :class:`~qiskit.circuit.QuantumCircuit`
Expand Down
3 changes: 1 addition & 2 deletions test/python/primitives/test_backend_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from qiskit.circuit import QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
from qiskit.opflow import PauliSumOp
from qiskit.primitives import BackendEstimator, EstimatorResult
from qiskit.providers import JobV1
from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2
Expand All @@ -37,7 +36,7 @@ class TestBackendEstimator(QiskitTestCase):
def setUp(self):
super().setUp()
self.ansatz = RealAmplitudes(num_qubits=2, reps=2)
self.observable = PauliSumOp.from_list(
self.observable = SparsePauliOp.from_list(
[
("II", -1.052373245772859),
("IZ", 0.39793742484318045),
Expand Down
13 changes: 12 additions & 1 deletion test/python/primitives/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from qiskit.circuit.library import RealAmplitudes
from qiskit.opflow import PauliSumOp
from qiskit.primitives import Estimator, EstimatorResult
from qiskit.primitives.utils import _observable_key
from qiskit.providers import JobV1
from qiskit.quantum_info import Operator, SparsePauliOp
from qiskit.test import QiskitTestCase
Expand All @@ -31,7 +32,7 @@ class TestEstimator(QiskitTestCase):
def setUp(self):
super().setUp()
self.ansatz = RealAmplitudes(num_qubits=2, reps=2)
self.observable = PauliSumOp.from_list(
self.observable = SparsePauliOp.from_list(
[
("II", -1.052373245772859),
("IZ", 0.39793742484318045),
Expand Down Expand Up @@ -637,6 +638,16 @@ def test_options(self):
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [-1.307397243478641])

def test_different_circuits(self):
"""Test collision of quantum observables."""

def get_op(i):
op = SparsePauliOp.from_list([("IXIX", i)])
return op

keys = [_observable_key(get_op(i)) for i in range(5)]
self.assertEqual(len(keys), len(set(keys)))


if __name__ == "__main__":
unittest.main()

0 comments on commit 875b646

Please sign in to comment.