Skip to content

Commit

Permalink
Add option to allow passing unknown parameters to ParameterExpression…
Browse files Browse the repository at this point in the history
….bind (#9304)

* add allow_unknown_parameters argument to ParameterExprression.bind

* format docs

Co-authored-by: Julien Gacon <[email protected]>

* improve docstring

* also add the argument to subs

* update Parameter.subs

Co-authored-by: Julien Gacon <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 11, 2023
1 parent 396b06c commit 804665b
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 29 deletions.
13 changes: 11 additions & 2 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from uuid import uuid4

from qiskit.circuit.exceptions import CircuitError
from qiskit.utils import optionals as _optionals

from .parameterexpression import ParameterExpression


Expand Down Expand Up @@ -86,9 +88,16 @@ def __init__(self, name: str):
symbol = symengine.Symbol(name)
super().__init__(symbol_map={self: symbol}, expr=symbol)

def subs(self, parameter_map: dict):
def subs(self, parameter_map: dict, allow_unknown_parameters: bool = False):
"""Substitute self with the corresponding parameter in ``parameter_map``."""
return parameter_map[self]
if self in parameter_map:
return parameter_map[self]
if allow_unknown_parameters:
return self
raise CircuitError(
"Cannot bind Parameters ({}) not present in "
"expression.".format([str(p) for p in parameter_map])
)

@property
def name(self):
Expand Down
65 changes: 39 additions & 26 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,17 @@ def assign(self, parameter, value: ParameterValueType) -> "ParameterExpression":
return self.subs({parameter: value})
return self.bind({parameter: value})

def bind(self, parameter_values: Dict) -> "ParameterExpression":
def bind(
self, parameter_values: Dict, allow_unknown_parameters: bool = False
) -> "ParameterExpression":
"""Binds the provided set of parameters to their corresponding values.
Args:
parameter_values: Mapping of Parameter instances to the numeric value to which
they will be bound.
allow_unknown_parameters: If ``False``, raises an error if ``parameter_values``
contains Parameters in the keys outside those present in the expression.
If ``True``, any such parameters are simply ignored.
Raises:
CircuitError:
Expand All @@ -108,14 +113,15 @@ def bind(self, parameter_values: Dict) -> "ParameterExpression":
A new expression parameterized by any parameters which were not bound by
parameter_values.
"""

self._raise_if_passed_unknown_parameters(parameter_values.keys())
if not allow_unknown_parameters:
self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_nan(parameter_values)

symbol_values = {}
for parameter, value in parameter_values.items():
param_expr = self._parameter_symbols[parameter]
symbol_values[param_expr] = value
if parameter in self._parameters:
param_expr = self._parameter_symbols[parameter]
symbol_values[param_expr] = value

bound_symbol_expr = self._symbol_expr.subs(symbol_values)

Expand All @@ -140,12 +146,17 @@ def bind(self, parameter_values: Dict) -> "ParameterExpression":

return ParameterExpression(free_parameter_symbols, bound_symbol_expr)

def subs(self, parameter_map: Dict) -> "ParameterExpression":
def subs(
self, parameter_map: Dict, allow_unknown_parameters: bool = False
) -> "ParameterExpression":
"""Returns a new Expression with replacement Parameters.
Args:
parameter_map: Mapping from Parameters in self to the ParameterExpression
instances with which they should be replaced.
allow_unknown_parameters: If ``False``, raises an error if ``parameter_map``
contains Parameters in the keys outside those present in the expression.
If ``True``, any such parameters are simply ignored.
Raises:
CircuitError:
Expand All @@ -156,36 +167,38 @@ def subs(self, parameter_map: Dict) -> "ParameterExpression":
Returns:
A new expression with the specified parameters replaced.
"""
inbound_parameters = set()
inbound_names = {}
for replacement_expr in parameter_map.values():
for p in replacement_expr.parameters:
inbound_parameters.add(p)
inbound_names[p.name] = p

self._raise_if_passed_unknown_parameters(parameter_map.keys())
if not allow_unknown_parameters:
self._raise_if_passed_unknown_parameters(parameter_map.keys())

inbound_names = {
p.name: p
for replacement_expr in parameter_map.values()
for p in replacement_expr.parameters
}
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())

# Include existing parameters in self not set to be replaced.
new_parameter_symbols = {
p: s for p, s in self._parameter_symbols.items() if p not in parameter_map
}

if _optionals.HAS_SYMENGINE:
import symengine

new_parameter_symbols = {p: symengine.Symbol(p.name) for p in inbound_parameters}
symbol_type = symengine.Symbol
else:
from sympy import Symbol

new_parameter_symbols = {p: Symbol(p.name) for p in inbound_parameters}

# Include existing parameters in self not set to be replaced.
new_parameter_symbols.update(
{p: s for p, s in self._parameter_symbols.items() if p not in parameter_map}
)
symbol_type = Symbol

# If new_param is an expr, we'll need to construct a matching sympy expr
# but with our sympy symbols instead of theirs.

symbol_map = {
self._parameter_symbols[old_param]: new_param._symbol_expr
for old_param, new_param in parameter_map.items()
}
symbol_map = {}
for old_param, new_param in parameter_map.items():
if old_param in self._parameters:
symbol_map[self._parameter_symbols[old_param]] = new_param._symbol_expr
for p in new_param.parameters:
new_parameter_symbols[p] = symbol_type(p.name)

substituted_symbol_expr = self._symbol_expr.subs(symbol_map)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added `allow_unknown_parameters` argument to
:meth:`~.ParameterExpression.bind` and :meth:`~.ParameterExpression.subs`
to allow passing a dictionary containing unknown parameters without causing an error to be raised.
40 changes: 39 additions & 1 deletion test/python/circuit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def raise_if_parameter_table_invalid(circuit): # pylint: disable=invalid-name

@ddt
class TestParameters(QiskitTestCase):
"""QuantumCircuit Operations tests."""
"""Test Parameters."""

def test_gate(self):
"""Test instantiating gate with variable parameters"""
Expand Down Expand Up @@ -186,6 +186,13 @@ def test_bind_parameters_anonymously(self):
bqc_list = getattr(qc, assign_fun)(param_dict)
self.assertEqual(bqc_anonymous, bqc_list)

def test_bind_parameters_allow_unknown(self):
"""Test binding parameters allowing unknown parameters."""
a = Parameter("a")
b = Parameter("b")
c = a.bind({a: 1, b: 1}, allow_unknown_parameters=True)
self.assertEqual(c, a.bind({a: 1}))

def test_bind_half_single_precision(self):
"""Test binding with 16bit and 32bit floats."""
phase = Parameter("phase")
Expand Down Expand Up @@ -1155,6 +1162,26 @@ def test_parametervector_resize(self):
self.assertIs(element, vec[1])
self.assertListEqual([param.name for param in vec], _paramvec_names("x", 3))

def test_raise_if_sub_unknown_parameters(self):
"""Verify we raise if asked to sub a parameter not in self."""
x = Parameter("x")

y = Parameter("y")
z = Parameter("z")

with self.assertRaisesRegex(CircuitError, "not present"):
x.subs({y: z})

def test_sub_allow_unknown_parameters(self):
"""Verify we raise if asked to sub a parameter not in self."""
x = Parameter("x")

y = Parameter("y")
z = Parameter("z")

subbed = x.subs({y: z}, allow_unknown_parameters=True)
self.assertEqual(subbed, x)


def _construct_circuit(param, qr):
qc = QuantumCircuit(qr)
Expand Down Expand Up @@ -1246,6 +1273,17 @@ def test_raise_if_sub_unknown_parameters(self):
with self.assertRaisesRegex(CircuitError, "not present"):
expr.subs({y: z})

def test_sub_allow_unknown_parameters(self):
"""Verify we raise if asked to sub a parameter not in self."""
x = Parameter("x")
expr = x + 2

y = Parameter("y")
z = Parameter("z")

subbed = expr.subs({y: z}, allow_unknown_parameters=True)
self.assertEqual(subbed, expr)

def test_raise_if_subbing_in_parameter_name_conflict(self):
"""Verify we raise if substituting in conflicting parameter names."""
x = Parameter("x")
Expand Down

0 comments on commit 804665b

Please sign in to comment.