Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add json magic methods to the QCVV library to facilitate saving and reloading using cirq #1138

Merged
merged 9 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 39 additions & 59 deletions docs/source/apps/supermarq/qcvv/qcvv_css.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion supermarq-benchmarks/supermarq/qcvv/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A toolkit of QCVV routines."""

from .base_experiment import QCVVExperiment, QCVVResults, Sample
from .irb import IRB, IRBResults
from .irb import IRB, IRBResults, RBResults
from .xeb import XEB, XEBResults

__all__ = [
Expand All @@ -12,4 +12,5 @@
"IRBResults",
"XEB",
"XEBResults",
"RBResults",
]
135 changes: 133 additions & 2 deletions supermarq-benchmarks/supermarq/qcvv/base_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,46 @@

import functools
import numbers
import pathlib
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import cirq
import cirq_superstaq as css
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import supermarq

if TYPE_CHECKING:
from typing_extensions import Self


def qcvv_resolver(cirq_type: str) -> type[Any] | None:
"""Resolves string's referencing classes in the QCVV library. Used by `cirq.read_json()`
to deserialize.

Args:
cirq_type: The type being resolved

Returns:
The corresponding type object (if found) else None

Raises:
ValueError: If the provided type is not resolvable
"""
prefix = "supermarq.qcvv."
if cirq_type.startswith(prefix):
name = cirq_type[len(prefix) :]
if name in supermarq.qcvv.__all__:
return getattr(supermarq.qcvv, name, None)
return None


@dataclass
class Sample:
Expand Down Expand Up @@ -60,6 +87,48 @@ def __hash__(self) -> int:
)
)

def _json_dict_(self) -> dict[str, Any]:
"""Converts the sample to a json-able dictionary that can be used to recreate the
sample object.

Returns:
Json-able dictionary of the sample data.
"""
return {
"circuit": self.circuit,
"data": self.data,
"circuit_realization": self.circuit_realization,
"sample_uuid": str(self.uuid),
}

@classmethod
def _from_json_dict_(
cls,
circuit: cirq.Circuit,
circuit_realization: int,
data: dict[str, Any],
sample_uuid: str,
**_: Any,
) -> Self:
"""Creates a sample from a dictionary of the data.

Args:
dictionary: Dict containing the sample data.

Returns:
The deserialized Sample object.
"""
return cls(
circuit=circuit,
circuit_realization=circuit_realization,
data=data,
uuid=uuid.UUID(sample_uuid),
)

@classmethod
def _json_namespace_(cls) -> str:
return "supermarq.qcvv"


@dataclass
class QCVVResults(ABC):
Expand Down Expand Up @@ -240,6 +309,7 @@ def __init__(
*,
random_seed: int | np.random.Generator | None = None,
results_cls: type[ResultsT],
_samples: Sequence[Sample] | None = None,
**kwargs: Any,
) -> None:
"""Initializes a benchmarking experiment.
Expand All @@ -251,6 +321,7 @@ def __init__(
cycle_depths: A sequence of depths to sample.
random_seed: An optional seed to use for randomization.
results_cls: The results class to use for the experiment.
_samples: Optional list of samples to construct the experiment from
kwargs: Additional kwargs passed to the Superstaq service object.
"""
self.qubits = cirq.LineQubit.range(num_qubits)
Expand All @@ -269,7 +340,10 @@ def __init__(

self._results_cls: type[ResultsT] = results_cls

self.samples = self._prepare_experiment()
if not _samples:
self.samples = self._prepare_experiment()
else:
self.samples = _samples
"""Create all the samples needed for the experiment."""

def __getitem__(self, key: str | int | uuid.UUID) -> Sample:
Expand Down Expand Up @@ -489,6 +563,63 @@ def _map_records_to_samples(

return record_mapping

@abstractmethod
def _json_dict_(self) -> dict[str, Any]:
"""Converts the experiment to a json-able dictionary that can be used to recreate the
experiment object. Note that the state of the random number generator is not stored.

.. note:: Must be re-implemented in any subclasses to ensure all important data is stored.

Returns:
Json-able dictionary of the experiment data.
"""
return {
cdbf1 marked this conversation as resolved.
Show resolved Hide resolved
"cycle_depths": self.cycle_depths,
"num_circuits": self.num_circuits,
"num_qubits": self.num_qubits,
"samples": self.samples,
**self._service_kwargs,
}

@classmethod
@abstractmethod
def _from_json_dict_(cls, *args: Any, **kwargs: Any) -> Self:
"""Creates a experiment from an expanded dictionary of the data.

Returns:
The deserialized experiment object.
"""

@classmethod
def _json_namespace_(cls) -> str:
return "supermarq.qcvv"

def to_file(self, filename: str | pathlib.Path) -> None:
"""Save the experiment to a json file.

Args:
filename: Filename to save to.
"""
with open(filename, "w") as file_stream:
cirq.to_json(self, file_stream)

@classmethod
def from_file(cls, filename: str | pathlib.Path) -> Self:
"""Load the experiment from a json file.

Args:
filename: Filename to load from.

Returns:
The loaded experiment.
"""
with open(filename, "r") as file_stream:
experiment = cirq.read_json(
file_stream,
resolvers=[*css.SUPERSTAQ_RESOLVERS, *cirq.DEFAULT_RESOLVERS, qcvv_resolver],
)
return experiment

def _prepare_experiment(
self,
) -> Sequence[Sample]:
Expand Down
62 changes: 61 additions & 1 deletion supermarq-benchmarks/supermarq/qcvv/base_experiment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import uuid
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, call, patch

import cirq
Expand All @@ -29,12 +30,24 @@
import pandas as pd
import pytest

from supermarq.qcvv.base_experiment import QCVVExperiment, QCVVResults, Sample
from supermarq.qcvv.base_experiment import QCVVExperiment, QCVVResults, Sample, qcvv_resolver

if TYPE_CHECKING:
from typing_extensions import Self

mock_plot = MagicMock()
mock_print = MagicMock()


def test_qcvv_resolver() -> None:
assert qcvv_resolver("bad_name") is None
assert qcvv_resolver("supermarq.qcvv.Sample") == Sample
assert qcvv_resolver("supermarq.qcvv.QCVVExperiment") == QCVVExperiment

# Check for something that is not explicitly exported
assert qcvv_resolver("supermarq.qcvv.base_experiment.qcvv_resolver") is None


@dataclass
class ExampleResults(QCVVResults):
"""Example results class for testing"""
Expand Down Expand Up @@ -67,6 +80,7 @@ def __init__(
cycle_depths: Iterable[int],
*,
random_seed: int | None = None,
_samples: list[Sample] | None = None,
**kwargs: str | bool,
) -> None:
super().__init__(
Expand All @@ -75,6 +89,7 @@ def __init__(
cycle_depths,
random_seed=random_seed,
results_cls=ExampleResults,
_samples=_samples,
**kwargs,
)

Expand All @@ -89,6 +104,27 @@ def _build_circuits(self, num_circuits: int, cycle_depths: Iterable[int]) -> Seq
for d in cycle_depths
]

def _json_dict_(self) -> dict[str, Any]:
return super()._json_dict_()

@classmethod
def _from_json_dict_(
cls,
samples: list[Sample],
num_qubits: int,
num_circuits: int,
cycle_depths: list[int],
**kwargs: Any,
) -> Self:
experiment = cls(
num_circuits=num_circuits,
num_qubits=num_qubits,
cycle_depths=cycle_depths,
_samples=samples,
**kwargs,
)
return experiment


@pytest.fixture
def abc_experiment() -> ExampleExperiment:
Expand Down Expand Up @@ -795,3 +831,27 @@ def test_map_records_to_samples_duplicate_keys(
sample_circuits[1].uuid: {0: 4, 1: 6, 3: 2},
}
)


@patch("supermarq.qcvv.base_experiment.qcvv_resolver")
def test_dump_and_load(
mock_resolver: MagicMock,
tmp_path_factory: pytest.TempPathFactory,
abc_experiment: ExampleExperiment,
sample_circuits: list[Sample],
) -> None:
temp_resolver = {
"supermarq.qcvv.Sample": Sample,
"supermarq.qcvv.ExampleExperiment": ExampleExperiment,
}
mock_resolver.side_effect = lambda x: temp_resolver.get(x)

filename = tmp_path_factory.mktemp("tempdir") / "file.json"
abc_experiment.samples = sample_circuits
abc_experiment.to_file(filename)
exp = ExampleExperiment.from_file(filename)

assert exp.samples == abc_experiment.samples
assert exp.num_qubits == abc_experiment.num_qubits
assert exp.num_circuits == abc_experiment.num_circuits
assert exp.cycle_depths == abc_experiment.cycle_depths
48 changes: 48 additions & 0 deletions supermarq-benchmarks/supermarq/qcvv/irb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import cirq
import cirq.circuits
Expand All @@ -30,6 +31,9 @@

from supermarq.qcvv.base_experiment import QCVVExperiment, QCVVResults, Sample

if TYPE_CHECKING:
from typing_extensions import Self


####################################################################################################
# Some handy functions for 1 and 2 qubit Clifford operations
Expand Down Expand Up @@ -466,6 +470,8 @@ def __init__(
clifford_op_gateset: cirq.CompilationTargetGateset = cirq.CZTargetGateset(),
*,
random_seed: int | np.random.Generator | None = None,
_samples: list[Sample] | None = None,
**kwargs: str,
) -> None:
"""Constructs an IRB experiment.

Expand Down Expand Up @@ -510,6 +516,8 @@ def __init__(
cycle_depths=cycle_depths,
random_seed=random_seed,
results_cls=results_cls,
_samples=_samples,
**kwargs,
)

def _clifford_gate_to_circuit(
Expand Down Expand Up @@ -725,3 +733,43 @@ def _build_circuits(self, num_circuits: int, cycle_depths: Iterable[int]) -> Seq
),
)
return samples

def _json_dict_(self) -> dict[str, Any]:
"""Converts the experiment to a json-able dictionary that can be used to recreate the
experiment object. Note that the state of the random number generator is not stored.

Returns:
Json-able dictionary of the experiment data.
"""
return {
"interleaved_gate": self.interleaved_gate,
"clifford_op_gateset": self.clifford_op_gateset,
**super()._json_dict_(),
}

@classmethod
def _from_json_dict_(
cls,
samples: list[Sample],
interleaved_gate: cirq.Gate,
clifford_op_gateset: cirq.CompilationTargetGateset,
num_circuits: int,
cycle_depths: list[int],
**kwargs: Any,
) -> Self:
"""Creates a experiment from a dictionary of the data.

Args:
dictionary: Dict containing the experiment data.

Returns:
The deserialized experiment object.
"""
return cls(
num_circuits=num_circuits,
cycle_depths=cycle_depths,
clifford_op_gateset=clifford_op_gateset,
interleaved_gate=interleaved_gate,
_samples=samples,
**kwargs,
)
Loading
Loading