Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serializable parametric pulse #7821

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
afe4f96
Symbolic parametric pulse
nkanazawa1989 Feb 27, 2022
9a8a582
Improve performance of parametrized pulse evaluation
wshanks Mar 16, 2022
6f2496a
Implement Constant pulse using Piecewise
wshanks Mar 17, 2022
0424636
Fall back to sympy for parametric pulse evaluation
wshanks Mar 17, 2022
1bf11fa
Use sympy Symbol with sympy lambdify
wshanks Mar 17, 2022
2335f32
WIP: cache symbolic pulse lambda functions
wshanks Mar 18, 2022
2af235d
Merge pull request #48 from wshanks/upgrade/serializable-parametric-p…
nkanazawa1989 Mar 23, 2022
0cc7383
move lambdify to init subclass for called once
nkanazawa1989 Mar 23, 2022
a59af77
turn parameter properties into attribute and remove slots
nkanazawa1989 Mar 23, 2022
21ef762
remove attribute and add __getattr__ and readd slots for better perfo…
nkanazawa1989 Mar 23, 2022
509470a
move symbolic funcs directly to _define
nkanazawa1989 Mar 23, 2022
ed2f661
Convert numerical_func to staticmethod
nkanazawa1989 Mar 23, 2022
c4b41d5
remove symbolic add constraints, numerical func is renamed to definit…
nkanazawa1989 Mar 27, 2022
71ef841
revert changes to parametric pulse and create new symbolic pulse file
nkanazawa1989 Apr 12, 2022
cf7f302
add ITE-like program
nkanazawa1989 Apr 12, 2022
c2981d0
sympy runtime import
nkanazawa1989 Apr 12, 2022
cf17341
update test notebook
nkanazawa1989 Apr 12, 2022
ba685ef
Merge branch 'main' of github.com:Qiskit/qiskit-terra into upgrade/se…
nkanazawa1989 Apr 12, 2022
f2ef950
add attribute docs
nkanazawa1989 Apr 12, 2022
2ea6390
fix non-constraints
nkanazawa1989 Apr 12, 2022
f48796e
remove pulse type instance variable
nkanazawa1989 Apr 12, 2022
0738497
use descriptor
nkanazawa1989 Apr 12, 2022
92f42eb
remove redundant comment
nkanazawa1989 Apr 12, 2022
1440a23
Merge branch 'main' of github.com:Qiskit/qiskit-terra into upgrade/se…
nkanazawa1989 Apr 13, 2022
4345c3b
Update
nkanazawa1989 Apr 13, 2022
1c63ddb
keep raw data for QPY encoding
nkanazawa1989 Apr 13, 2022
d121b0d
update unittest
nkanazawa1989 Apr 13, 2022
48ff3cc
update helper function name
nkanazawa1989 Apr 13, 2022
b9cab67
add reno
nkanazawa1989 Apr 13, 2022
42b3dd9
remove notebook
nkanazawa1989 Apr 13, 2022
761235d
fix lint
nkanazawa1989 Apr 13, 2022
cc12bf8
fix unittest and logic
nkanazawa1989 Apr 13, 2022
6c21edd
add more docs
nkanazawa1989 Apr 14, 2022
43e0e0b
review comments
nkanazawa1989 Apr 15, 2022
6b956de
Merge branch 'main' of github.com:Qiskit/qiskit-terra into upgrade/se…
nkanazawa1989 Apr 18, 2022
eefa5e4
lint fix
nkanazawa1989 Apr 18, 2022
252f62b
fix documentation
nkanazawa1989 Apr 18, 2022
f24e648
minor drawer fix
nkanazawa1989 Apr 19, 2022
4d191be
documentation upgrade
nkanazawa1989 Apr 21, 2022
2c8e7f8
Merge branch 'main' of github.com:Qiskit/qiskit-terra into upgrade/se…
nkanazawa1989 Apr 21, 2022
3a0f506
review comment misc
nkanazawa1989 Apr 28, 2022
d5517d5
remove abstract class methods
nkanazawa1989 May 2, 2022
be2ed0c
add error handling for amplitude
nkanazawa1989 May 5, 2022
67ac61f
treat amp as a special parameter
nkanazawa1989 May 6, 2022
78288ac
Remove expressions for amplitude validation
nkanazawa1989 May 19, 2022
cbada64
Merge branch 'main' of github.com:Qiskit/qiskit-terra into upgrade/se…
nkanazawa1989 May 20, 2022
ecb4ea7
support symengine
nkanazawa1989 Jun 2, 2022
45ae755
use real=False option
nkanazawa1989 Jun 2, 2022
a4a56db
review comment
nkanazawa1989 Jun 6, 2022
d957458
Merge branch 'upgrade/serializable-parametric-pulse-symengine' into u…
nkanazawa1989 Jun 6, 2022
6776743
- fix attribute
nkanazawa1989 Jun 6, 2022
d4e8c43
undo change to requirements-dev
nkanazawa1989 Jun 6, 2022
41bf4f5
fix __getattr__ mechanism
nkanazawa1989 Jun 6, 2022
2aa4f5a
fix type hint reference
nkanazawa1989 Jun 6, 2022
2e521a7
simplification
nkanazawa1989 Jun 6, 2022
09855d7
move amp from constructor to `parameters` dict
nkanazawa1989 Jun 7, 2022
b99eb5c
review comment
nkanazawa1989 Jun 8, 2022
ac23d6f
fix bug
nkanazawa1989 Jun 8, 2022
401c1f4
fix typo
nkanazawa1989 Jun 8, 2022
686e77b
add eval_conditions to skip waveform generation
nkanazawa1989 Jun 8, 2022
84344c2
fall back to sympy lamdify when function is not supported
nkanazawa1989 Jun 9, 2022
914018e
documentation update
nkanazawa1989 Jun 12, 2022
045d67b
replace eval_conditions with valid_amp_conditions
nkanazawa1989 Jun 12, 2022
3b2c571
update hashing and equality, redefine expressions more immutably
nkanazawa1989 Jun 12, 2022
8edaba9
add error message for missing parameter
nkanazawa1989 Jun 12, 2022
9b0adc3
cleanup
nkanazawa1989 Jun 12, 2022
cda7336
check parameter before hashing
nkanazawa1989 Jun 13, 2022
ef35cd9
move amp check to constructor
nkanazawa1989 Jun 13, 2022
15e9f7c
add envelope to hash
nkanazawa1989 Jun 14, 2022
d844253
update docs
nkanazawa1989 Jun 15, 2022
526aaed
Update qiskit/pulse/library/symbolic_pulses.py
nkanazawa1989 Jun 15, 2022
87d3538
Merge branch 'main' into upgrade/serializable-parametric-pulse
nkanazawa1989 Jun 15, 2022
be2b4e0
lint
nkanazawa1989 Jun 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 75 additions & 43 deletions qiskit/pulse/library/symbolic_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import functools
from typing import Any, Dict, Optional, Union, Callable
from typing import Any, Dict, List, Optional, Union, Callable

import numpy as np

Expand Down Expand Up @@ -90,7 +90,30 @@ def _validate_amplitude_limit(symbolic_pulse: "SymbolicPulse") -> bool:
return np.any(np.abs(symbolic_pulse.get_waveform().samples) > 1.0)


class LamdifiedExpression:
def _get_expression_args(
expr: sym.Expr,
params: Dict[str, float],
exclude: Optional[List[str]] = None,
) -> List[float]:
"""A helper function to get argument to evaluate expression.

Args:
expr: Symbolic expression to evaluate.
params: Dictionary of parameter, which is a superset of expression arguments.
exclude: Parameter to exclude from arguments.

Returns:
Arguments passed to the lambdified expression.
"""
args = []
for symbol in sorted(expr.free_symbols, key=lambda s: s.name):
if exclude and symbol.name in exclude:
continue
args.append(params[symbol.name])
return args


class LambdifiedExpression:
"""Descriptor to lambdify symbolic expression with cache.

When new symbolic expression is set for the first time,
Expand Down Expand Up @@ -302,15 +325,16 @@ def Sawtooth(duration, amp, freq, name):
__slots__ = (
"_amp",
"_pulse_type",
"_param_names",
"_param_vals",
"_params",
"envelope",
"constraints",
"eval_conditions",
)

# Lambdify caches keyed on sympy expressions. Returns the corresponding callable.
_envelope_lambdify = LamdifiedExpression("envelope")
_constraints_lambdify = LamdifiedExpression("constraints")
_envelope_lam = LambdifiedExpression("envelope")
_constraints_lam = LambdifiedExpression("constraints")
_eval_conditions_lam = LambdifiedExpression("eval_conditions")

def __init__(
self,
Expand All @@ -321,6 +345,7 @@ def __init__(
limit_amplitude: Optional[bool] = None,
envelope: Optional[sym.Expr] = None,
constraints: Optional[sym.Expr] = None,
eval_conditions: Optional[sym.Expr] = None,
):
"""Create a parametric pulse.

Expand All @@ -333,6 +358,9 @@ def __init__(
waveform to 1. The default is ``True`` and the amplitude is constrained to 1.
envelope: Pulse envelope expression.
constraints: Pulse parameter constraint expression.
eval_conditions: Extra conditions to evaluate full waveform to check the
amplitude limit. This can be provided if ``envelope`` function is not
normalized to 1.0.

Raises:
PulseError: When not all parameters are listed in the attribute :attr:`PARAM_DEF`.
Expand All @@ -356,21 +384,21 @@ def __init__(
)
self._amp = parameters.pop("amp")
self._pulse_type = pulse_type
self._param_names = tuple(parameters.keys())
self._param_vals = tuple(parameters.values())
self._params = parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I liked storing the params as tuples. It's probably premature optimization though (and maybe bad if .parameters gets used a lot). I wish Python had a compact frozen dict type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this is called multiple times. Seems like we have frozendict in python. https://pypi.org/project/frozendict/ What do you think of adding this? @mtreinish

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want truly an immutable type that's really tricky to do in python. You can create one like frozendict pretty easily without that library just something that implements __getitem__() and the other read only mapping protocol functions will work. Basically just subclass collections.abc.Mapping. You could also do it via rust via a HashMap or IndexMap rust struct if you wanted something with statically typed keys which performed better. But that only limits it to top level immutability (basically only blocking inserts and value replacements) you'll always still be able to modify a value inplace. For example, using frozendict you could do something like:

import frozendict
test = frozendict.frozendict({'a': []})
test['a'].append(2)
print(test)

Would print: frozendict.frozendict({'a': [2]})

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we cast amp to complex here if it is in the parameters and not a ParameterExpression? Right now this is only done in all the subclasses.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but treating amp as special parameter is a convention in IBM backend and this code is backend-agnostic. So one might want treat amp as a real value in phasor representation (I really prefer this).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I agree with the spirit of not treating amp, but we are still treating it specially in all the subclasses and in the parameter assignment. I think it would be more consistent just to cast it to complex here for now.

I'd like to look at what it would take to get the backend to accept float or complex for all the parameters so terra does not have to special case parameters, but that would be follow up work.

I also prefer the phasor representation. I think one can use a * exp(1j * phi) for the amp paremeter where a and phi are paraemters to get this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. This is forcibly typecasted in parameter assignment anyways. Done in ef35cd9. It would be great if we can remove this check since current implementation means we cannot write pulses with phase representation.

I think one can use a * exp(1j * phi) for the amp paremeter where a and phi are paraemters to get this?

Yes, this is a technique I often use, i.e.

amp = Parameter("amp")
phase = Parameter("phase")
duration = Parameter("duration")
sigma = Parameter("sigma")

Gaussian(duration=duration, amp=amp * exp(1j * phase), sigma=sigma)


self.envelope = envelope
self.constraints = constraints
self.eval_conditions = eval_conditions

def __getattr__(self, item):
# Get pulse parameters with attribute-like access.
param_names = object.__getattribute__(self, "_param_names")
param_vals = object.__getattribute__(self, "_param_vals")
if item == "amp":
return object.__getattribute__(self, "_amp")
if item not in param_names:

params = object.__getattribute__(self, "_params")
if item not in params:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
return param_vals[param_names.index(item)]
return params[item]

@property
def pulse_type(self) -> str:
Expand Down Expand Up @@ -402,15 +430,9 @@ def get_waveform(self) -> Waveform:
raise PulseError("Pulse envelope expression is not assigned.")

times = np.arange(0, self.duration) + 1 / 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this midpoint sampling work the way one would expect with gaussian pulse shapes being lifted to 0? The zeroes are at -1 and duration + 1. I think the backend has the zeroes at -1 and duration for samples range(duration), so the zeroes are half a sample further away for these pulses than for the backend.

Copy link
Contributor Author

@nkanazawa1989 nkanazawa1989 Apr 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. So far I have never heard critical failure in pulse simulator. As long as we have enough duration/sigma this doesn't become serious problem. I think t_zero value is insensitive to sampling strategy (because it is computed before sampling), thus offset value should be the same. We are sampling the same waveform at slightly different point.

I agree there is undiscoverable difference in frontend and backend as discussed in this #7659 (comment). I sometime see the difference of calibrated parameters in parametric pulse v.s. waveform. I think SymbolicPulse will improve this situation. In principle backend compiler doesn't need to have own pulse factory. It just need to convert the symbolic pulse into waveform and reuse compiler pass for the waveform (smart controller could do runtime waveform generation in microarchitecture, then it doesn't have waveform memory to save it). In this sense, I feel we don't need to conform to the backend.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the sampling behavior here match exactly what the old parametric pulse classes would produce? If so, it is fine to leave this. If it is already a minor change in behavior, why not try to match what the IBM backend does exactly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conforms to the behavior of parametric pulse classes. We can update this with extra work for unittest (in unittest reference data is generated with midpoint sampler).

params = self.parameters
fargs = [times, *_get_expression_args(self.envelope, self.parameters, exclude=["t"])]

func_args = [times]
for name in sorted(map(lambda s: s.name, self.envelope.free_symbols)):
if name == "t":
continue
func_args.append(params[name])

return Waveform(samples=self._amp * self._envelope_lambdify(*func_args), name=self.name)
return Waveform(samples=self._amp * self._envelope_lam(*fargs), name=self.name)

def validate_parameters(self) -> None:
"""Validate parameters.
Expand All @@ -421,46 +443,56 @@ def validate_parameters(self) -> None:
if self.is_parameterized():
return

if any(p.imag != 0 for p in self._param_vals):
if any(p.imag != 0 for p in self._params.values()):
raise PulseError("Pulse parameters must be real numbers except for 'amp'.")

if self.constraints is not None:
func_args = []
params = self.parameters
for name in sorted(map(lambda s: s.name, self.constraints.free_symbols)):
func_args.append(params[name])

if not bool(self._constraints_lambdify(*func_args)):
fargs = _get_expression_args(self.constraints, self.parameters)
if not bool(self._constraints_lam(*fargs)):
param_repr = ", ".join(f"{p}={v}" for p, v in self.parameters.items())
const_repr = str(self.constraints)
raise PulseError(
f"Assigned parameters {param_repr} violate following constraint: {const_repr}."
)

if self.limit_amplitude and (np.abs(self._amp) > 1.0 or _validate_amplitude_limit(self)):
# Check max amplitude limit by generating waveform.
# We can avoid calling _validate_amplitude_limit when |amp| > 1.0
# which obviously violates the amplitude constraint by definition.
param_repr = ", ".join(f"{p}={v}" for p, v in self.parameters.items())
raise PulseError(
f"Maximum pulse amplitude norm exceeds 1.0 with assigned parameters {param_repr}."
"This can be overruled by setting Pulse.limit_amplitude."
)
if self.limit_amplitude:
violate_amp_limit = False

if np.abs(self._amp) > 1.0:
# Strong condition. |amp| > 1.0 must exceed AWG dynamic range.
violate_amp_limit = True
elif self.eval_conditions is not None:
# Check actual waveform.
# For example, if .envelope function is not normalized,
# |amp| < 1.0 can result in the maximum amplitude > 1.0, depending on the
# selection of other parameters.
fargs = _get_expression_args(self.eval_conditions, self.parameters)
if bool(self._eval_conditions_lam(*fargs)) and _validate_amplitude_limit(self):
violate_amp_limit = True

if violate_amp_limit:
# Check max amplitude limit by generating waveform.
# We can avoid calling _validate_amplitude_limit when |amp| > 1.0
# which obviously violates the amplitude constraint by definition.
param_repr = ", ".join(f"{p}={v}" for p, v in self.parameters.items())
raise PulseError(
f"Maximum pulse amplitude norm exceeds 1.0 with assigned parameters {param_repr}."
"This can be overruled by setting Pulse.limit_amplitude."
)

def is_parameterized(self) -> bool:
"""Return True iff the instruction is parameterized."""
args = (self.duration, self._amp, *self._param_vals)
return any(isinstance(val, ParameterExpression) for val in args)
return any(isinstance(val, ParameterExpression) for val in self.parameters.values())

@property
def parameters(self) -> Dict[str, Any]:
params = {"duration": self.duration, "amp": self._amp}
params.update(dict(zip(self._param_names, self._param_vals)))
params.update(self._params)
return params

def __eq__(self, other: "SymbolicPulse") -> bool:

if self.envelope != other.envelope:
if self.envelope != getattr(other, "envelope", None):
return False

if self.parameters != other.parameters:
Expand All @@ -469,9 +501,7 @@ def __eq__(self, other: "SymbolicPulse") -> bool:
return True

def __hash__(self) -> int:
return hash(
(self._pulse_type, self.duration, self._amp, *self._param_names, *self._param_vals)
)
return hash((self._pulse_type, self.duration, self._amp, *tuple(self._params.items())))

def __repr__(self) -> str:
param_repr = ", ".join(f"{p}={v}" for p, v in self.parameters.items())
Expand Down Expand Up @@ -729,6 +759,7 @@ def __init__(
# In IBM quantum backend, im(beta) == 0 is explicitly checked.
# In Qiskit, we impose real number constraints on all parameters except for 'amp'.
consts_expr = _sigma > 0
eval_conditions_expr = sym.Abs(_beta) > _sigma

super().__init__(
pulse_type=self.__class__.__name__,
Expand All @@ -738,6 +769,7 @@ def __init__(
limit_amplitude=limit_amplitude,
envelope=envelope_expr,
constraints=consts_expr,
eval_conditions=eval_conditions_expr,
)
self.validate_parameters()

Expand Down
17 changes: 7 additions & 10 deletions qiskit/pulse/parameter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,23 +223,20 @@ def visit_ParametricPulse(self, node: ParametricPulse):
def visit_SymbolicPulse(self, node: SymbolicPulse):
"""Assign parameters to ``SymbolicPulse`` object."""
if node.is_parameterized():
new_values = []
# Assign duration
if isinstance(node.duration, ParameterExpression):
node.duration = self._assign_parameter_expression(node.duration)
if isinstance(node.amp, ParameterExpression):
assigned_amp = self._assign_parameter_expression(node.amp)
if isinstance(node._amp, ParameterExpression):
assigned_amp = self._assign_parameter_expression(node._amp)
if not isinstance(assigned_amp, ParameterExpression):
# Amplitude is complex value
assigned_amp = complex(assigned_amp)
node.amp = assigned_amp
node._amp = assigned_amp
# Assign other parameters
for op_value in node._param_vals:
if isinstance(op_value, ParameterExpression):
new_values.append(self._assign_parameter_expression(op_value))
else:
new_values.append(op_value)
node._param_vals = new_values
for name in node._params:
pval = node._params[name]
if isinstance(pval, ParameterExpression):
node._params[name] = self._assign_parameter_expression(pval)
node.validate_parameters()

return node
Expand Down
8 changes: 4 additions & 4 deletions test/python/pulse/test_pulse_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def test_drag_validation(self):
wf = Drag(duration=duration, sigma=sigma, amp=amp, beta=beta)
samples = wf.get_waveform().samples
self.assertTrue(max(np.abs(samples)) <= 1)
with self.assertRaises(PulseError):
wf = Drag(duration=duration, sigma=sigma, amp=1.2, beta=beta)
beta = sigma**2
with self.assertRaises(PulseError):
wf = Drag(duration=duration, sigma=sigma, amp=amp, beta=beta)
Expand Down Expand Up @@ -319,15 +321,13 @@ def test_envelope_cache(self):
"""Test speed up of instantiation with lambdify envelope cache."""
drag_instance1 = Drag(duration=100, amp=0.1, sigma=40, beta=3)
drag_instance2 = Drag(duration=100, amp=0.1, sigma=40, beta=3)
self.assertTrue(drag_instance1._envelope_lambdify is drag_instance2._envelope_lambdify)
self.assertTrue(drag_instance1._envelope_lam is drag_instance2._envelope_lam)

def test_constraints_cache(self):
"""Test speed up of instantiation with lambdify constraints cache."""
drag_instance1 = Drag(duration=100, amp=0.1, sigma=40, beta=3)
drag_instance2 = Drag(duration=100, amp=0.1, sigma=40, beta=3)
self.assertTrue(
drag_instance1._constraints_lambdify is drag_instance2._constraints_lambdify
)
self.assertTrue(drag_instance1._constraints_lam is drag_instance2._constraints_lam)

def test_deepcopy(self):
"""Test deep copying instance."""
Expand Down