Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade unittest equality method #1132

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions releasenotes/notes/add-test-equality-checker-dbe5762d2b6a967f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
developer:
- |
Added the :meth:`QiskitExperimentsTestCase.assertEqualExtended` method for generic equality checks
of Qiskit Experiments class instances in unittests. This is a drop-in replacement of
calling the assertTrue with :meth:`QiskitExperimentsTestCase.json_equiv`.
Note that some Qiskit Experiments classes may not officially implement equality check logic,
although objects may be compared during unittests. Extended equality check is used
for such situations.
- |
The following unittest test case methods will be deprecated:

* :meth:`QiskitExperimentsTestCase.json_equiv`
* :meth:`QiskitExperimentsTestCase.ufloat_equiv`
* :meth:`QiskitExperimentsTestCase.analysis_result_equiv`
* :meth:`QiskitExperimentsTestCase.curve_fit_data_equiv`
* :meth:`QiskitExperimentsTestCase.experiment_data_equiv`

One can now use the :func:`~test.extended_equality.is_equivalent` function instead.
This function internally dispatches the logic for equality check.
- |
The default behavior of :meth:`QiskitExperimentsTestCase.assertRoundTripSerializable` and
:meth:`QiskitExperimentsTestCase.assertRoundTripPickle` when `check_func` is not
provided was upgraded. These methods now compare the decoded instance with
:func:`~test.extended_equality.is_equivalent`, rather than
delegating to the native `assertEqual` unittest method.
One writing a unittest for serialization no longer need to explicitly set checker function.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ qiskit-aer>=0.11.0
pandas>=1.1.5
cvxpy>=1.1.15
pylatexenc
multimethod
scikit-learn
sphinx-copybutton
# Pin versions below because of build errors
Expand Down
244 changes: 98 additions & 146 deletions test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,22 @@
Qiskit Experiments test case class
"""

import dataclasses
import json
import pickle
import warnings
from typing import Any, Callable, Optional

import numpy as np
import uncertainties
from lmfit import Model
from qiskit.test import QiskitTestCase
from qiskit_experiments.data_processing import DataAction, DataProcessor
from qiskit_experiments.framework.experiment_data import ExperimentStatus
from qiskit.utils.deprecation import deprecate_func

from qiskit_experiments.framework import (
ExperimentDecoder,
ExperimentEncoder,
ExperimentData,
BaseExperiment,
BaseAnalysis,
)
from qiskit_experiments.visualization import BaseDrawer
from qiskit_experiments.curve_analysis.curve_data import CurveFitResult
from qiskit_experiments.framework.experiment_data import ExperimentStatus
from .extended_equality import is_equivalent


class QiskitExperimentsTestCase(QiskitTestCase):
Expand Down Expand Up @@ -76,15 +71,52 @@ def assertExperimentDone(
msg="All threads are executed but status is not DONE. " + experiment_data.errors(),
)

def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] = None):
def assertEqualExtended(
self,
first: Any,
second: Any,
*,
msg: Optional[str] = None,
strict_type: bool = False,
):
"""Extended equality assertion which covers Qiskit Experiments classes.

.. note::
Some Qiskit Experiments class may intentionally avoid implementing
the equality dunder method, or may be used in some unusual situations.
These are mainly caused by to JSON round trip situation, and some custom classes
doesn't guarantee object equality after round trip.
This assertion function forcibly compares input two objects with
the custom equality checker, which is implemented for unittest purpose.

Args:
first: First object to compare.
second: Second object to compare.
msg: Optional. Custom error message issued when first and second object are not equal.
strict_type: Set True to enforce type check before comparison.
"""
default_msg = f"{first} != {second}"

self.assertTrue(
is_equivalent(first, second, strict_type=strict_type),
msg=msg or default_msg,
)

def assertRoundTripSerializable(
self,
obj: Any,
*,
check_func: Optional[Callable] = None,
strict_type: bool = False,
):
"""Assert that an object is round trip serializable.

Args:
obj: the object to be serialized.
check_func: Optional, a custom function ``check_func(a, b) -> bool``
to check equality of the original object with the decoded
object. If None the ``__eq__`` method of the original
object will be used.
to check equality of the original object with the decoded
object. If None :meth:`.assertEqualExtended` is called.
strict_type: Set True to enforce type check before comparison.
"""
try:
encoded = json.dumps(obj, cls=ExperimentEncoder)
Expand All @@ -94,20 +126,27 @@ def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] =
decoded = json.loads(encoded, cls=ExperimentDecoder)
except TypeError:
self.fail("JSON deserialization raised unexpectedly.")
if check_func is None:
self.assertEqual(obj, decoded)
else:

if check_func is not None:
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
else:
self.assertEqualExtended(obj, decoded, strict_type=strict_type)

def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None):
def assertRoundTripPickle(
self,
obj: Any,
*,
check_func: Optional[Callable] = None,
strict_type: bool = False,
):
"""Assert that an object is round trip serializable using pickle module.

Args:
obj: the object to be serialized.
check_func: Optional, a custom function ``check_func(a, b) -> bool``
to check equality of the original object with the decoded
object. If None the ``__eq__`` method of the original
object will be used.
to check equality of the original object with the decoded
object. If None :meth:`.assertEqualExtended` is called.
strict_type: Set True to enforce type check before comparison.
"""
try:
encoded = pickle.dumps(obj)
Expand All @@ -117,150 +156,63 @@ def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None)
decoded = pickle.loads(encoded)
except TypeError:
self.fail("pickle deserialization raised unexpectedly.")
if check_func is None:
self.assertEqual(obj, decoded)
else:

if check_func is not None:
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
else:
self.assertEqualExtended(obj, decoded, strict_type=strict_type)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def json_equiv(cls, data1, data2) -> bool:
"""Check if two experiments are equivalent by comparing their configs"""
# pylint: disable = too-many-return-statements
configurable_type = (BaseExperiment, BaseAnalysis, BaseDrawer)
compare_repr = (DataAction, DataProcessor)
list_type = (list, tuple, set)
skipped = tuple()

if isinstance(data1, skipped) and isinstance(data2, skipped):
warnings.warn(f"Equivalence check for data {data1.__class__.__name__} is skipped.")
return True
elif isinstance(data1, configurable_type) and isinstance(data2, configurable_type):
return cls.json_equiv(data1.config(), data2.config())
elif dataclasses.is_dataclass(data1) and dataclasses.is_dataclass(data2):
# not using asdict. this copies all objects.
return cls.json_equiv(data1.__dict__, data2.__dict__)
elif isinstance(data1, dict) and isinstance(data2, dict):
if set(data1) != set(data2):
return False
return all(cls.json_equiv(data1[k], data2[k]) for k in data1.keys())
elif isinstance(data1, np.ndarray) or isinstance(data2, np.ndarray):
return np.allclose(data1, data2)
elif isinstance(data1, list_type) and isinstance(data2, list_type):
return all(cls.json_equiv(e1, e2) for e1, e2 in zip(data1, data2))
elif isinstance(data1, uncertainties.UFloat) and isinstance(data2, uncertainties.UFloat):
return cls.ufloat_equiv(data1, data2)
elif isinstance(data1, Model) and isinstance(data2, Model):
return cls.json_equiv(data1.dumps(), data2.dumps())
elif isinstance(data1, CurveFitResult) and isinstance(data2, CurveFitResult):
return cls.curve_fit_data_equiv(data1, data2)
elif isinstance(data1, compare_repr) and isinstance(data2, compare_repr):
# otherwise compare instance representation
return repr(data1) == repr(data2)

return data1 == data2
return is_equivalent(data1, data2)

@staticmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def ufloat_equiv(data1: uncertainties.UFloat, data2: uncertainties.UFloat) -> bool:
"""Check if two values with uncertainties are equal. No correlation is considered."""
return data1.n == data2.n and data1.s == data2.s
return is_equivalent(data1, data2)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def analysis_result_equiv(cls, result1, result2):
"""Test two analysis results are equivalent"""
# Check basic attributes skipping service which is not serializable
for att in [
"name",
"value",
"extra",
"device_components",
"result_id",
"experiment_id",
"chisq",
"quality",
"verified",
"tags",
"auto_save",
"source",
]:
if not cls.json_equiv(getattr(result1, att), getattr(result2, att)):
return False
return True
return is_equivalent(result1, result2)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def curve_fit_data_equiv(cls, data1, data2):
"""Test two curve fit result are equivalent."""
for att in [
"method",
"model_repr",
"success",
"nfev",
"message",
"dof",
"init_params",
"chisq",
"reduced_chisq",
"aic",
"bic",
"params",
"var_names",
"x_data",
"y_data",
"covar",
]:
if not cls.json_equiv(getattr(data1, att), getattr(data2, att)):
return False
return True
return is_equivalent(data1, data2)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def experiment_data_equiv(cls, data1, data2):
"""Check two experiment data containers are equivalent"""

# Check basic attributes
# Skip non-compatible backend
for att in [
"experiment_id",
"experiment_type",
"parent_id",
"tags",
"job_ids",
"figure_names",
"share_level",
"metadata",
]:
if not cls.json_equiv(getattr(data1, att), getattr(data2, att)):
return False

# Check length of data, results, child_data
# check for child data attribute so this method still works for
# DbExperimentData
if hasattr(data1, "child_data"):
child_data1 = data1.child_data()
else:
child_data1 = []
if hasattr(data2, "child_data"):
child_data2 = data2.child_data()
else:
child_data2 = []

if (
len(data1.data()) != len(data2.data())
or len(data1.analysis_results()) != len(data2.analysis_results())
or len(child_data1) != len(child_data2)
):
return False

# Check data
if not cls.json_equiv(data1.data(), data2.data()):
return False

# Check analysis results
for result1, result2 in zip(data1.analysis_results(), data2.analysis_results()):
if not cls.analysis_result_equiv(result1, result2):
return False

# Check child data
for child1, child2 in zip(child_data1, child_data2):
if not cls.experiment_data_equiv(child1, child2):
return False

return True
return is_equivalent(data1, data2)
2 changes: 1 addition & 1 deletion test/calibration/test_calibrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def test_serialization(self):
cals = Calibrations.from_backend(backend, libraries=[library])
cals.add_parameter_value(0.12345, "amp", 3, "x")

self.assertRoundTripSerializable(cals, self.json_equiv)
self.assertRoundTripSerializable(cals)

def test_equality(self):
"""Test the equal method on calibrations."""
Expand Down
2 changes: 1 addition & 1 deletion test/curve_analysis/test_baseclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class TestCurveAnalysis(CurveAnalysisTestCase):
def test_roundtrip_serialize(self):
"""A testcase for serializing analysis instance."""
analysis = CurveAnalysis(models=[ExpressionModel(expr="par0 * x + par1", name="test")])
self.assertRoundTripSerializable(analysis, check_func=self.json_equiv)
self.assertRoundTripSerializable(analysis)

def test_parameters(self):
"""A testcase for getting fit parameters with attribute."""
Expand Down
6 changes: 3 additions & 3 deletions test/data_processing/test_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,14 @@ def test_json_single_node(self):
"""Check if the data processor is serializable."""
node = MinMaxNormalize()
processor = DataProcessor("counts", [node])
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
self.assertRoundTripSerializable(processor)

def test_json_multi_node(self):
"""Check if the data processor with multiple nodes is serializable."""
node1 = MinMaxNormalize()
node2 = AverageData(axis=2)
processor = DataProcessor("counts", [node1, node2])
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
self.assertRoundTripSerializable(processor)

def test_json_trained(self):
"""Check if trained data processor is serializable and still work."""
Expand All @@ -405,7 +405,7 @@ def test_json_trained(self):
main_axes=np.array([[1, 0]]), scales=[1.0], i_means=[0.0], q_means=[0.0]
)
processor = DataProcessor("memory", data_actions=[node])
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
self.assertRoundTripSerializable(processor)

serialized = json.dumps(processor, cls=ExperimentEncoder)
loaded_processor = json.loads(serialized, cls=ExperimentDecoder)
Expand Down
Loading