From 36a20e715744dfb13a19cf1293f3f13cd249970e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 8 Aug 2024 16:16:32 +0100 Subject: [PATCH 01/26] feat: basic user-facing dataclasses for extensions --- hugr-py/src/hugr/extension.py | 151 ++++++++++++++++++++ hugr-py/src/hugr/serialization/extension.py | 63 +++++++- 2 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 hugr-py/src/hugr/extension.py diff --git a/hugr-py/src/hugr/extension.py b/hugr-py/src/hugr/extension.py new file mode 100644 index 000000000..22f49573d --- /dev/null +++ b/hugr-py/src/hugr/extension.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import hugr.serialization.extension as ext_s +from hugr import tys, val +from hugr.utils import ser_it + +if TYPE_CHECKING: + from semver import Version + + from hugr.hugr import Hugr + from hugr.tys import ExtensionId + + +@dataclass +class ExplicitBound: + bound: tys.TypeBound + + def to_serial(self) -> ext_s.ExplicitBound: + return ext_s.ExplicitBound(bound=self.bound) + + +@dataclass +class FromParamsBound: + indices: list[int] + + def to_serial(self) -> ext_s.FromParamsBound: + return ext_s.FromParamsBound(indices=self.indices) + + +@dataclass +class TypeDef: + extension: ExtensionId + name: str + description: str + params: list[tys.TypeParam] + bound: ExplicitBound | FromParamsBound + + def to_serial(self) -> ext_s.TypeDef: + return ext_s.TypeDef( + extension=self.extension, + name=self.name, + description=self.description, + params=ser_it(self.params), + bound=ext_s.TypeDefBound(root=self.bound.to_serial()), + ) + + +@dataclass +class FixedHugr: + extensions: tys.ExtensionSet + hugr: Hugr + + def to_serial(self) -> ext_s.FixedHugr: + return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr) + + +class OpDefSig: + poly_func: tys.PolyFuncType | None + binary: bool + + def __init__( + self, + poly_func: tys.PolyFuncType | tys.FunctionType | None, + binary: bool = False, + ) -> None: + if poly_func is None and not binary: + msg = ( + "Signature must be provided if binary" + " signature computation is not expected." + ) + raise ValueError(msg) + if isinstance(poly_func, tys.FunctionType): + poly_func = tys.PolyFuncType([], poly_func) + self.poly_func = poly_func + self.binary = binary + + +@dataclass +class OpDef: + name: str + signature: OpDefSig + extension: ExtensionId | None = None + description: str = "" + misc: dict[str, Any] = field(default_factory=dict) + lower_funcs: list[FixedHugr] = field(default_factory=list) + + def to_serial(self) -> ext_s.OpDef: + assert self.extension is not None, "Extension must be initialised." + return ext_s.OpDef( + extension=self.extension, + name=self.name, + description=self.description, + misc=self.misc, + signature=self.signature.poly_func.to_serial() + if self.signature.poly_func + else None, + binary=self.signature.binary, + lower_funcs=[f.to_serial() for f in self.lower_funcs], + ) + + +@dataclass +class ExtensionValue: + extension: ExtensionId + name: str + typed_value: val.Value + + def to_serial(self) -> ext_s.ExtensionValue: + return ext_s.ExtensionValue( + extension=self.extension, + name=self.name, + typed_value=self.typed_value.to_serial_root(), + ) + + +@dataclass +class Extension: + name: ExtensionId + version: Version + extension_reqs: set[ExtensionId] = field(default_factory=set) + types: dict[str, TypeDef] = field(default_factory=dict) + values: dict[str, ExtensionValue] = field(default_factory=dict) + operations: dict[str, OpDef] = field(default_factory=dict) + + def to_serial(self) -> ext_s.Extension: + return ext_s.Extension( + name=self.name, + version=self.version, # type: ignore[arg-type] + extension_reqs=self.extension_reqs, + types={k: v.to_serial() for k, v in self.types.items()}, + values={k: v.to_serial() for k, v in self.values.items()}, + operations={k: v.to_serial() for k, v in self.operations.items()}, + ) + + def add_op_def(self, op_def: OpDef) -> None: + self.operations[op_def.name] = op_def + + +@dataclass +class Package: + modules: list[Hugr] + extensions: list[Extension] = field(default_factory=list) + + def to_serial(self) -> ext_s.Package: + return ext_s.Package( + modules=[m.to_serial() for m in self.modules], + extensions=[e.to_serial() for e in self.extensions], + ) diff --git a/hugr-py/src/hugr/serialization/extension.py b/hugr-py/src/hugr/serialization/extension.py index fd533e781..4b78ebe24 100644 --- a/hugr-py/src/hugr/serialization/extension.py +++ b/hugr-py/src/hugr/serialization/extension.py @@ -1,7 +1,11 @@ -from typing import Annotated, Any, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Any, Literal import pydantic as pd -from pydantic_extra_types.semantic_version import SemanticVersion +from pydantic_extra_types.semantic_version import SemanticVersion # noqa: TCH002 + +from hugr.utils import deser_it from .ops import Value from .serial_hugr import SerialHugr, serialization_version @@ -14,16 +18,26 @@ TypeParam, ) +if TYPE_CHECKING: + from .ops import Value + from .serial_hugr import SerialHugr + class ExplicitBound(ConfiguredBaseModel): b: Literal["Explicit"] = "Explicit" bound: TypeBound + def deserialize(self) -> ext.ExplicitBound: + return ext.ExplicitBound(bound=self.bound) + class FromParamsBound(ConfiguredBaseModel): b: Literal["FromParams"] = "FromParams" indices: list[int] + def deserialize(self) -> ext.FromParamsBound: + return ext.FromParamsBound(indices=self.indices) + class TypeDefBound(pd.RootModel): root: Annotated[ExplicitBound | FromParamsBound, pd.Field(discriminator="b")] @@ -36,12 +50,28 @@ class TypeDef(ConfiguredBaseModel): params: list[TypeParam] bound: TypeDefBound + def deserialize(self) -> ext.TypeDef: + return ext.TypeDef( + extension=self.extension, + name=self.name, + description=self.description, + params=deser_it(self.params), + bound=self.bound.root.deserialize(), + ) + class ExtensionValue(ConfiguredBaseModel): extension: ExtensionId name: str typed_value: Value + def deserialize(self) -> ext.ExtensionValue: + return ext.ExtensionValue( + extension=self.extension, + name=self.name, + typed_value=self.typed_value.deserialize(), + ) + # -------------------------------------- # --------------- OpDef ---------------- @@ -52,6 +82,9 @@ class FixedHugr(ConfiguredBaseModel): extensions: ExtensionSet hugr: Any + def deserialize(self) -> ext.FixedHugr: + return ext.FixedHugr(extensions=self.extensions, hugr=self.hugr) + class OpDef(ConfiguredBaseModel, populate_by_name=True): """Serializable definition for dynamically loaded operations.""" @@ -61,8 +94,21 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): description: str # Human readable description of the operation. misc: dict[str, Any] | None = None signature: PolyFuncType | None = None + binary: bool = False lower_funcs: list[FixedHugr] + def deserialize(self) -> ext.OpDef: + return ext.OpDef( + extension=self.extension, + name=self.name, + description=self.description, + misc=self.misc or {}, + signature=ext.OpDefSig( + self.signature.deserialize() if self.signature else None, self.binary + ), + lower_funcs=[f.deserialize() for f in self.lower_funcs], + ) + class Extension(ConfiguredBaseModel): version: SemanticVersion @@ -76,6 +122,16 @@ class Extension(ConfiguredBaseModel): def get_version(cls) -> str: return serialization_version() + def deserialize(self) -> ext.Extension: + return ext.Extension( + version=self.version, # type: ignore[arg-type] + name=self.name, + extension_reqs=self.extension_reqs, + types={k: v.deserialize() for k, v in self.types.items()}, + values={k: v.deserialize() for k, v in self.values.items()}, + operations={k: v.deserialize() for k, v in self.operations.items()}, + ) + class Package(ConfiguredBaseModel): modules: list[SerialHugr] @@ -84,3 +140,6 @@ class Package(ConfiguredBaseModel): @classmethod def get_version(cls) -> str: return serialization_version() + + +import hugr.extension as ext # noqa: E402 From 5ba757bf35ae200d419240eea60dd9c95d0876f3 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 8 Aug 2024 17:58:57 +0100 Subject: [PATCH 02/26] make extension optional in definition classes --- hugr-py/src/hugr/extension.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/hugr-py/src/hugr/extension.py b/hugr-py/src/hugr/extension.py index 22f49573d..9d0a5bbbf 100644 --- a/hugr-py/src/hugr/extension.py +++ b/hugr-py/src/hugr/extension.py @@ -21,6 +21,9 @@ class ExplicitBound: def to_serial(self) -> ext_s.ExplicitBound: return ext_s.ExplicitBound(bound=self.bound) + def to_serial_root(self) -> ext_s.TypeDefBound: + return ext_s.TypeDefBound(root=self.to_serial()) + @dataclass class FromParamsBound: @@ -29,16 +32,20 @@ class FromParamsBound: def to_serial(self) -> ext_s.FromParamsBound: return ext_s.FromParamsBound(indices=self.indices) + def to_serial_root(self) -> ext_s.TypeDefBound: + return ext_s.TypeDefBound(root=self.to_serial()) + @dataclass class TypeDef: - extension: ExtensionId name: str description: str params: list[tys.TypeParam] bound: ExplicitBound | FromParamsBound + extension: ExtensionId | None = None def to_serial(self) -> ext_s.TypeDef: + assert self.extension is not None, "Extension must be initialised." return ext_s.TypeDef( extension=self.extension, name=self.name, @@ -104,11 +111,12 @@ def to_serial(self) -> ext_s.OpDef: @dataclass class ExtensionValue: - extension: ExtensionId name: str typed_value: val.Value + extension: ExtensionId | None = None def to_serial(self) -> ext_s.ExtensionValue: + assert self.extension is not None, "Extension must be initialised." return ext_s.ExtensionValue( extension=self.extension, name=self.name, @@ -138,6 +146,12 @@ def to_serial(self) -> ext_s.Extension: def add_op_def(self, op_def: OpDef) -> None: self.operations[op_def.name] = op_def + def add_type_def(self, type_def: TypeDef) -> None: + self.types[type_def.name] = type_def + + def add_extension_value(self, extension_value: ExtensionValue) -> None: + self.values[extension_value.name] = extension_value + @dataclass class Package: From 755a1a9382dc6e0049c0a6d204bc1c065926e574 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 9 Aug 2024 11:27:56 +0100 Subject: [PATCH 03/26] refactor(hugr-py): rename extension to ext --- hugr-py/src/hugr/{extension.py => ext.py} | 2 +- hugr-py/src/hugr/serialization/{extension.py => ext.py} | 2 +- hugr-py/src/hugr/serialization/testing_hugr.py | 2 +- hugr-py/tests/serialization/test_extension.py | 2 +- scripts/generate_schema.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename hugr-py/src/hugr/{extension.py => ext.py} (99%) rename hugr-py/src/hugr/serialization/{extension.py => ext.py} (98%) diff --git a/hugr-py/src/hugr/extension.py b/hugr-py/src/hugr/ext.py similarity index 99% rename from hugr-py/src/hugr/extension.py rename to hugr-py/src/hugr/ext.py index 9d0a5bbbf..05a212dfe 100644 --- a/hugr-py/src/hugr/extension.py +++ b/hugr-py/src/hugr/ext.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any -import hugr.serialization.extension as ext_s +import hugr.serialization.ext as ext_s from hugr import tys, val from hugr.utils import ser_it diff --git a/hugr-py/src/hugr/serialization/extension.py b/hugr-py/src/hugr/serialization/ext.py similarity index 98% rename from hugr-py/src/hugr/serialization/extension.py rename to hugr-py/src/hugr/serialization/ext.py index 4b78ebe24..0b24b714c 100644 --- a/hugr-py/src/hugr/serialization/extension.py +++ b/hugr-py/src/hugr/serialization/ext.py @@ -142,4 +142,4 @@ def get_version(cls) -> str: return serialization_version() -import hugr.extension as ext # noqa: E402 +from hugr import ext # noqa: E402 diff --git a/hugr-py/src/hugr/serialization/testing_hugr.py b/hugr-py/src/hugr/serialization/testing_hugr.py index fa57ece7a..f90de29f7 100644 --- a/hugr-py/src/hugr/serialization/testing_hugr.py +++ b/hugr-py/src/hugr/serialization/testing_hugr.py @@ -1,6 +1,6 @@ from pydantic import ConfigDict -from .extension import OpDef +from .ext import OpDef from .ops import OpType, Value from .ops import classes as ops_classes from .serial_hugr import VersionField diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index 6d4425b7d..85200aee7 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -1,6 +1,6 @@ from semver import Version -from hugr.serialization.extension import ( +from hugr.serialization.ext import ( ExplicitBound, Extension, OpDef, diff --git a/scripts/generate_schema.py b/scripts/generate_schema.py index 908bdfa1a..cf593512a 100644 --- a/scripts/generate_schema.py +++ b/scripts/generate_schema.py @@ -16,7 +16,7 @@ from pydantic import ConfigDict from pydantic.json_schema import models_json_schema -from hugr.serialization.extension import Extension, Package +from hugr.serialization.ext import Extension, Package from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.testing_hugr import TestingHugr From 59ef88dd01b119fcb0d2833403c8605026d4771e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 9 Aug 2024 13:58:31 +0100 Subject: [PATCH 04/26] separate resolved and unresolved types --- hugr-py/src/hugr/ext.py | 35 +++++++++++++++---------- hugr-py/src/hugr/serialization/ext.py | 24 ++++++++++------- hugr-py/src/hugr/std/float.py | 20 ++++++++++----- hugr-py/src/hugr/std/int.py | 24 +++++++++++------ hugr-py/src/hugr/tys.py | 37 ++++++++++++++++++++++++++- hugr-py/tests/test_custom.py | 1 + 6 files changed, 103 insertions(+), 38 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 05a212dfe..bca324ce0 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -1,3 +1,5 @@ +"""HUGR extensions and packages.""" + from __future__ import annotations from dataclasses import dataclass, field @@ -8,6 +10,8 @@ from hugr.utils import ser_it if TYPE_CHECKING: + from collections.abc import Sequence + from semver import Version from hugr.hugr import Hugr @@ -15,7 +19,7 @@ @dataclass -class ExplicitBound: +class ExplicitBound: # noqa: D101 bound: tys.TypeBound def to_serial(self) -> ext_s.ExplicitBound: @@ -26,7 +30,7 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class FromParamsBound: +class FromParamsBound: # noqa: D101 indices: list[int] def to_serial(self) -> ext_s.FromParamsBound: @@ -37,26 +41,29 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class TypeDef: +class TypeDef: # noqa: D101 name: str description: str params: list[tys.TypeParam] bound: ExplicitBound | FromParamsBound - extension: ExtensionId | None = None + _extension: Extension | None = field(default=None, init=False) def to_serial(self) -> ext_s.TypeDef: - assert self.extension is not None, "Extension must be initialised." + assert self._extension is not None, "Extension must be initialised." return ext_s.TypeDef( - extension=self.extension, + extension=self._extension.name, name=self.name, description=self.description, params=ser_it(self.params), bound=ext_s.TypeDefBound(root=self.bound.to_serial()), ) + def instantiate(self, args: Sequence[tys.TypeArg]) -> tys.ExtType: + return tys.ExtType(self, list(args)) + @dataclass -class FixedHugr: +class FixedHugr: # noqa: D101 extensions: tys.ExtensionSet hugr: Hugr @@ -64,7 +71,7 @@ def to_serial(self) -> ext_s.FixedHugr: return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr) -class OpDefSig: +class OpDefSig: # noqa: D101 poly_func: tys.PolyFuncType | None binary: bool @@ -86,7 +93,7 @@ def __init__( @dataclass -class OpDef: +class OpDef: # noqa: D101 name: str signature: OpDefSig extension: ExtensionId | None = None @@ -110,7 +117,7 @@ def to_serial(self) -> ext_s.OpDef: @dataclass -class ExtensionValue: +class ExtensionValue: # noqa: D101 name: str typed_value: val.Value extension: ExtensionId | None = None @@ -125,7 +132,7 @@ def to_serial(self) -> ext_s.ExtensionValue: @dataclass -class Extension: +class Extension: # noqa: D101 name: ExtensionId version: Version extension_reqs: set[ExtensionId] = field(default_factory=set) @@ -146,15 +153,17 @@ def to_serial(self) -> ext_s.Extension: def add_op_def(self, op_def: OpDef) -> None: self.operations[op_def.name] = op_def - def add_type_def(self, type_def: TypeDef) -> None: + def add_type_def(self, type_def: TypeDef) -> TypeDef: + type_def._extension = self self.types[type_def.name] = type_def + return self.types[type_def.name] def add_extension_value(self, extension_value: ExtensionValue) -> None: self.values[extension_value.name] = extension_value @dataclass -class Package: +class Package: # noqa: D101 modules: list[Hugr] extensions: list[Extension] = field(default_factory=list) diff --git a/hugr-py/src/hugr/serialization/ext.py b/hugr-py/src/hugr/serialization/ext.py index 0b24b714c..193050b17 100644 --- a/hugr-py/src/hugr/serialization/ext.py +++ b/hugr-py/src/hugr/serialization/ext.py @@ -50,13 +50,14 @@ class TypeDef(ConfiguredBaseModel): params: list[TypeParam] bound: TypeDefBound - def deserialize(self) -> ext.TypeDef: - return ext.TypeDef( - extension=self.extension, - name=self.name, - description=self.description, - params=deser_it(self.params), - bound=self.bound.root.deserialize(), + def deserialize(self, extension: ext.Extension) -> ext.TypeDef: + return extension.add_type_def( + ext.TypeDef( + name=self.name, + description=self.description, + params=deser_it(self.params), + bound=self.bound.root.deserialize(), + ) ) @@ -123,15 +124,20 @@ def get_version(cls) -> str: return serialization_version() def deserialize(self) -> ext.Extension: - return ext.Extension( + e = ext.Extension( version=self.version, # type: ignore[arg-type] name=self.name, extension_reqs=self.extension_reqs, - types={k: v.deserialize() for k, v in self.types.items()}, + # types={k: v.deserialize() for k, v in self.types.items()}, values={k: v.deserialize() for k, v in self.values.items()}, operations={k: v.deserialize() for k, v in self.operations.items()}, ) + for v in self.types.values(): + e.add_type_def(v.deserialize(e)) + + return e + class Package(ConfiguredBaseModel): modules: list[SerialHugr] diff --git a/hugr-py/src/hugr/std/float.py b/hugr-py/src/hugr/std/float.py index bf0cfff18..601b72f93 100644 --- a/hugr-py/src/hugr/std/float.py +++ b/hugr-py/src/hugr/std/float.py @@ -4,17 +4,23 @@ from dataclasses import dataclass -from hugr import tys, val +from semver import Version +from hugr import ext, tys, val + +Extension = ext.Extension("arithmetic.float.types", Version(0, 1, 0)) #: HUGR 64-bit IEEE 754-2019 floating point type. -FLOAT_EXT_ID = "arithmetic.float.types" -FLOAT_T = tys.Opaque( - extension=FLOAT_EXT_ID, - id="float64", - args=[], - bound=tys.TypeBound.Copyable, +FLOAT_T_DEF = Extension.add_type_def( + ext.TypeDef( + name="float64", + description="64-bit IEEE 754-2019 floating point number", + params=[], + bound=ext.ExplicitBound(tys.TypeBound.Copyable), + ) ) +FLOAT_T = FLOAT_T_DEF.instantiate([]) + @dataclass class FloatVal(val.ExtensionValue): diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 722fd3c24..d3123791d 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -5,16 +5,27 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, ClassVar +from semver import Version from typing_extensions import Self -from hugr import tys, val +from hugr import ext, tys, val from hugr.ops import AsCustomOp, Custom, DataflowOp if TYPE_CHECKING: from hugr.ops import Command, ComWire +EXTENSION = ext.Extension("arithmetic.int.types", Version(0, 1, 0)) +INT_T_DEF = EXTENSION.add_type_def( + ext.TypeDef( + name="int", + description="Variable width integer.", + bound=ext.ExplicitBound(tys.TypeBound.Copyable), + params=[tys.BoundedNatParam(7)], + ) +) + -def int_t(width: int) -> tys.Opaque: +def int_t(width: int) -> tys.ExtType: """Create an integer type with a fixed log bit width. @@ -25,14 +36,11 @@ def int_t(width: int) -> tys.Opaque: The integer type. Examples: - >>> int_t(5).id # 32 bit integer + >>> int_t(5).type_def.name # 32 bit integer 'int' """ - return tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=width)], - bound=tys.TypeBound.Copyable, + return INT_T_DEF.instantiate( + [tys.BoundedNatArg(n=width)], ) diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index ee07bf626..222e1fcfc 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -3,11 +3,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Protocol, runtime_checkable +from typing import TYPE_CHECKING, Protocol, runtime_checkable import hugr.serialization.tys as stys from hugr.utils import ser_it +if TYPE_CHECKING: + from hugr.ext import TypeDef + ExtensionId = stys.ExtensionId ExtensionSet = stys.ExtensionSet TypeBound = stys.TypeBound @@ -403,6 +406,38 @@ def to_serial(self) -> stys.PolyFuncType: ) +@dataclass +class ExtType(Type): + """Extension type, defined by a type definition and type arguments.""" + + type_def: TypeDef + args: list[TypeArg] = field(default_factory=list) + + def type_bound(self) -> TypeBound: + from hugr.ext import ExplicitBound, FromParamsBound + + match self.type_def.bound: + case ExplicitBound(exp_bound): + return exp_bound + case FromParamsBound(indices): + bounds: list[TypeBound] = [] + for idx in indices: + arg = self.args[idx] + if isinstance(arg, TypeTypeArg): + bounds.append(arg.ty.type_bound()) + return TypeBound.join(*bounds) + + def to_serial(self) -> stys.Opaque: + assert self.type_def._extension is not None, "Extension must be initialised." + + return stys.Opaque( + extension=self.type_def._extension.name, + id=self.type_def.name, + args=[arg.to_serial_root() for arg in self.args], + bound=self.type_bound(), + ) + + @dataclass class Opaque(Type): """Opaque type, identified by `id` and with optional type arguments and bound.""" diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 069c630c9..f20d65786 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -46,6 +46,7 @@ def test_stringly_typed(): validate(dfg.hugr) +@pytest.mark.xfail(reason="Extension resolution not implemented yet.") @pytest.mark.parametrize( "as_custom", [Not, DivMod, H, CX, Measure, Rz, StringlyOp("hello")], From 5afd4cc7370c7919112030c07765db7d915d5f99 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 9 Aug 2024 14:04:07 +0100 Subject: [PATCH 05/26] point to extension in opdef and extval --- hugr-py/src/hugr/ext.py | 20 +++++++----- hugr-py/src/hugr/serialization/ext.py | 46 +++++++++++++++------------ 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index bca324ce0..14232ba2e 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -96,15 +96,15 @@ def __init__( class OpDef: # noqa: D101 name: str signature: OpDefSig - extension: ExtensionId | None = None description: str = "" misc: dict[str, Any] = field(default_factory=dict) lower_funcs: list[FixedHugr] = field(default_factory=list) + _extension: Extension | None = field(default=None, init=False) def to_serial(self) -> ext_s.OpDef: - assert self.extension is not None, "Extension must be initialised." + assert self._extension is not None, "Extension must be initialised." return ext_s.OpDef( - extension=self.extension, + extension=self._extension.name, name=self.name, description=self.description, misc=self.misc, @@ -120,12 +120,12 @@ def to_serial(self) -> ext_s.OpDef: class ExtensionValue: # noqa: D101 name: str typed_value: val.Value - extension: ExtensionId | None = None + _extension: Extension | None = field(default=None, init=False) def to_serial(self) -> ext_s.ExtensionValue: - assert self.extension is not None, "Extension must be initialised." + assert self._extension is not None, "Extension must be initialised." return ext_s.ExtensionValue( - extension=self.extension, + extension=self._extension.name, name=self.name, typed_value=self.typed_value.to_serial_root(), ) @@ -150,16 +150,20 @@ def to_serial(self) -> ext_s.Extension: operations={k: v.to_serial() for k, v in self.operations.items()}, ) - def add_op_def(self, op_def: OpDef) -> None: + def add_op_def(self, op_def: OpDef) -> OpDef: + op_def._extension = self self.operations[op_def.name] = op_def + return self.operations[op_def.name] def add_type_def(self, type_def: TypeDef) -> TypeDef: type_def._extension = self self.types[type_def.name] = type_def return self.types[type_def.name] - def add_extension_value(self, extension_value: ExtensionValue) -> None: + def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: + extension_value._extension = self self.values[extension_value.name] = extension_value + return self.values[extension_value.name] @dataclass diff --git a/hugr-py/src/hugr/serialization/ext.py b/hugr-py/src/hugr/serialization/ext.py index 193050b17..ff90dfa80 100644 --- a/hugr-py/src/hugr/serialization/ext.py +++ b/hugr-py/src/hugr/serialization/ext.py @@ -66,11 +66,12 @@ class ExtensionValue(ConfiguredBaseModel): name: str typed_value: Value - def deserialize(self) -> ext.ExtensionValue: - return ext.ExtensionValue( - extension=self.extension, - name=self.name, - typed_value=self.typed_value.deserialize(), + def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: + return extension.add_extension_value( + ext.ExtensionValue( + name=self.name, + typed_value=self.typed_value.deserialize(), + ) ) @@ -98,16 +99,18 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): binary: bool = False lower_funcs: list[FixedHugr] - def deserialize(self) -> ext.OpDef: - return ext.OpDef( - extension=self.extension, - name=self.name, - description=self.description, - misc=self.misc or {}, - signature=ext.OpDefSig( - self.signature.deserialize() if self.signature else None, self.binary - ), - lower_funcs=[f.deserialize() for f in self.lower_funcs], + def deserialize(self, extension: ext.Extension) -> ext.OpDef: + return extension.add_op_def( + ext.OpDef( + name=self.name, + description=self.description, + misc=self.misc or {}, + signature=ext.OpDefSig( + self.signature.deserialize() if self.signature else None, + self.binary, + ), + lower_funcs=[f.deserialize() for f in self.lower_funcs], + ) ) @@ -128,13 +131,16 @@ def deserialize(self) -> ext.Extension: version=self.version, # type: ignore[arg-type] name=self.name, extension_reqs=self.extension_reqs, - # types={k: v.deserialize() for k, v in self.types.items()}, - values={k: v.deserialize() for k, v in self.values.items()}, - operations={k: v.deserialize() for k, v in self.operations.items()}, ) - for v in self.types.values(): - e.add_type_def(v.deserialize(e)) + for t in self.types.values(): + e.add_type_def(t.deserialize(e)) + + for o in self.operations.values(): + e.add_op_def(o.deserialize(e)) + + for v in self.values.values(): + e.add_extension_value(v.deserialize(e)) return e From 0b9b2a6e82e4c6552e1a98d6b611dd9a89253588 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 9 Aug 2024 16:36:16 +0100 Subject: [PATCH 06/26] use extension objects in std extension defs --- hugr-py/src/hugr/ext.py | 31 +++++++++++++--- hugr-py/src/hugr/ops.py | 32 +++++++++++++++- hugr-py/src/hugr/std/float.py | 4 +- hugr-py/src/hugr/std/int.py | 47 ++++++++++++++++-------- hugr-py/src/hugr/std/logic.py | 22 +++++++---- hugr-py/src/hugr/val.py | 3 ++ hugr-py/tests/conftest.py | 69 ++++++++++++++++++++++------------- hugr-py/tests/test_custom.py | 11 +++--- 8 files changed, 154 insertions(+), 65 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 14232ba2e..f0b9e24d5 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -5,15 +5,28 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any +from semver import Version + import hugr.serialization.ext as ext_s from hugr import tys, val from hugr.utils import ser_it +__all__ = [ + "ExplicitBound", + "FromParamsBound", + "TypeDef", + "FixedHugr", + "OpDefSig", + "OpDef", + "ExtensionValue", + "Extension", + "Package", + "Version", +] + if TYPE_CHECKING: from collections.abc import Sequence - from semver import Version - from hugr.hugr import Hugr from hugr.tys import ExtensionId @@ -46,7 +59,9 @@ class TypeDef: # noqa: D101 description: str params: list[tys.TypeParam] bound: ExplicitBound | FromParamsBound - _extension: Extension | None = field(default=None, init=False) + _extension: Extension | None = field( + default=None, init=False, repr=False, compare=False + ) def to_serial(self) -> ext_s.TypeDef: assert self._extension is not None, "Extension must be initialised." @@ -98,8 +113,10 @@ class OpDef: # noqa: D101 signature: OpDefSig description: str = "" misc: dict[str, Any] = field(default_factory=dict) - lower_funcs: list[FixedHugr] = field(default_factory=list) - _extension: Extension | None = field(default=None, init=False) + lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False) + _extension: Extension | None = field( + default=None, init=False, repr=False, compare=False + ) def to_serial(self) -> ext_s.OpDef: assert self._extension is not None, "Extension must be initialised." @@ -120,7 +137,9 @@ def to_serial(self) -> ext_s.OpDef: class ExtensionValue: # noqa: D101 name: str typed_value: val.Value - _extension: Extension | None = field(default=None, init=False) + _extension: Extension | None = field( + default=None, init=False, repr=False, compare=False + ) def to_serial(self) -> ext_s.ExtensionValue: assert self._extension is not None, "Extension must be initialised." diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 61621c5f7..a01f0c516 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -9,7 +9,7 @@ from typing_extensions import Self import hugr.serialization.ops as sops -from hugr import tys, val +from hugr import ext, tys, val from hugr.node_port import Direction, InPort, Node, OutPort, PortOffset, Wire from hugr.utils import ser_it @@ -303,6 +303,36 @@ def check_id(self, extension: tys.ExtensionId, name: str) -> bool: return self.extension == extension and self.name == name +@dataclass(frozen=True) +class ExtOp(AsCustomOp): + """A non-core dataflow operation defined in an extension.""" + + op_def: ext.OpDef + signature: tys.FunctionType | None = None + args: list[tys.TypeArg] = field(default_factory=list) + + def to_custom(self) -> Custom: + ext = self.op_def._extension + if self.signature is None: + poly_func = self.op_def.signature.poly_func + if poly_func is None or len(poly_func.params) > 0: + msg = "For polymorphic ops signature must be cached." + raise ValueError(msg) + sig = poly_func.body + else: + sig = self.signature + + return Custom( + name=self.op_def.name, + signature=sig, + extension=ext.name if ext else "", + args=self.args, + ) + + def to_serial(self, parent: Node) -> sops.CustomOp: + return self.to_custom().to_serial(parent) + + @dataclass() class MakeTuple(DataflowOp, _PartialOp): """Operation to create a tuple from a sequence of wires.""" diff --git a/hugr-py/src/hugr/std/float.py b/hugr-py/src/hugr/std/float.py index 601b72f93..76c62b78a 100644 --- a/hugr-py/src/hugr/std/float.py +++ b/hugr-py/src/hugr/std/float.py @@ -4,11 +4,9 @@ from dataclasses import dataclass -from semver import Version - from hugr import ext, tys, val -Extension = ext.Extension("arithmetic.float.types", Version(0, 1, 0)) +Extension = ext.Extension("arithmetic.float.types", ext.Version(0, 1, 0)) #: HUGR 64-bit IEEE 754-2019 floating point type. FLOAT_T_DEF = Extension.add_type_def( ext.TypeDef( diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index d3123791d..a1438df9c 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -3,24 +3,24 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING -from semver import Version from typing_extensions import Self from hugr import ext, tys, val -from hugr.ops import AsCustomOp, Custom, DataflowOp +from hugr.ops import AsCustomOp, Custom, DataflowOp, ExtOp if TYPE_CHECKING: from hugr.ops import Command, ComWire -EXTENSION = ext.Extension("arithmetic.int.types", Version(0, 1, 0)) -INT_T_DEF = EXTENSION.add_type_def( +TYPES_EXTENSION = ext.Extension("arithmetic.int.types", ext.Version(0, 1, 0)) +_INT_PARAM = tys.BoundedNatParam(7) +INT_T_DEF = TYPES_EXTENSION.add_type_def( ext.TypeDef( name="int", description="Variable width integer.", bound=ext.ExplicitBound(tys.TypeBound.Copyable), - params=[tys.BoundedNatParam(7)], + params=[_INT_PARAM], ) ) @@ -44,6 +44,12 @@ def int_t(width: int) -> tys.ExtType: ) +def _int_tv(index: int) -> tys.ExtType: + return INT_T_DEF.instantiate( + [tys.VariableArg(idx=index, param=_INT_PARAM)], + ) + + #: HUGR 32-bit integer type. INT_T = int_t(5) @@ -59,30 +65,39 @@ def to_value(self) -> val.Extension: return val.Extension("int", int_t(self.width), self.v) -OPS_EXTENSION: tys.ExtensionId = "arithmetic.int" +OPS_EXTENSION = ext.Extension("arithmetic.int", ext.Version(0, 1, 0)) + +_DivMod = OPS_EXTENSION.add_op_def( + ext.OpDef( + name="idivmod_u", + description="Unsigned integer division and modulo.", + signature=ext.OpDefSig( + tys.FunctionType([_int_tv(0), _int_tv(1)], [_int_tv(0), _int_tv(1)]) + ), + ) +) @dataclass(frozen=True) class _DivModDef(AsCustomOp): """DivMod operation, has two inputs and two outputs.""" - name: ClassVar[str] = "idivmod_u" arg1: int = 5 arg2: int = 5 + op_def: ext.OpDef = field(default_factory=lambda: _DivMod, init=False) def to_custom(self) -> Custom: - return Custom( - "idivmod_u", - tys.FunctionType( - input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2 - ), - extension=OPS_EXTENSION, - args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)], + row: list[tys.Type] = [int_t(self.arg1), int_t(self.arg2)] + ext_op = ExtOp( + self.op_def, + tys.FunctionType.endo(row), + [tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)], ) + return ext_op.to_custom() @classmethod def from_custom(cls, custom: Custom) -> Self | None: - if not custom.check_id(OPS_EXTENSION, "idivmod_u"): + if not custom.check_id(OPS_EXTENSION.name, _DivMod.name): return None match custom.args: case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]: diff --git a/hugr-py/src/hugr/std/logic.py b/hugr-py/src/hugr/std/logic.py index 1291a61c5..cea3bd80d 100644 --- a/hugr-py/src/hugr/std/logic.py +++ b/hugr-py/src/hugr/std/logic.py @@ -5,26 +5,32 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from hugr import tys -from hugr.ops import AsCustomOp, Command, Custom, DataflowOp +from hugr import ext, tys +from hugr.ops import AsCustomOp, Command, Custom, DataflowOp, ExtOp if TYPE_CHECKING: from hugr.ops import ComWire -EXTENSION_ID: tys.ExtensionId = "logic" +EXTENSION = ext.Extension("logic", ext.Version(0, 1, 0)) +_NotDef = EXTENSION.add_op_def( + ext.OpDef( + name="Not", + description="Logical NOT operation.", + signature=ext.OpDefSig(tys.FunctionType.endo([tys.Bool])), + ) +) -@dataclass(frozen=True) -class _NotDef(AsCustomOp): - """Not operation.""" +@dataclass(frozen=True) +class _NotOp(AsCustomOp): def to_custom(self) -> Custom: - return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID) + return ExtOp(_NotDef).to_custom() def __call__(self, a: ComWire) -> Command: return DataflowOp.__call__(self, a) #: Not operation -Not = _NotDef() +Not = _NotOp() diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 7f0a8f249..886893d00 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -199,3 +199,6 @@ def type_(self) -> tys.Type: def to_serial(self) -> sops.ExtensionValue: return self.to_value().to_serial() + + +# TODO extension value that points to an extension. diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index d54302a42..5c3b40dfc 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -10,23 +10,58 @@ from typing_extensions import Self -import hugr.tys as tys +from hugr import ext, tys from hugr.hugr import Hugr -from hugr.ops import AsCustomOp, Command, Custom, DataflowOp +from hugr.ops import AsCustomOp, Command, Custom, DataflowOp, ExtOp from hugr.serialization.serial_hugr import SerialHugr from hugr.std.float import FLOAT_T if TYPE_CHECKING: from hugr.ops import ComWire - -QUANTUM_EXTENSION_ID: tys.ExtensionId = "quantum.tket2" +EXTENSION = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0)) +_SINGLE_QUBIT = ext.OpDefSig(tys.FunctionType.endo([tys.Qubit])) +_TWO_QUBIT = ext.OpDefSig(tys.FunctionType.endo([tys.Qubit] * 2)) +_MEAS_SIG = ext.OpDefSig(tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])) +_RZ_SIG = ext.OpDefSig(tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])) + +EXTENSION.add_op_def( + ext.OpDef( + name="H", + description="Hadamard gate", + signature=_SINGLE_QUBIT, + ) +) + +EXTENSION.add_op_def( + ext.OpDef( + name="CX", + description="CNOT gate", + signature=_TWO_QUBIT, + ) +) + +EXTENSION.add_op_def( + ext.OpDef( + name="Measure", + description="Measurement operation", + signature=_MEAS_SIG, + ) +) + +EXTENSION.add_op_def( + ext.OpDef( + name="Rz", + description="Rotation around the z-axis", + signature=_RZ_SIG, + ) +) E = TypeVar("E", bound=Enum) def _load_enum(enum_cls: type[E], custom: Custom) -> E | None: - if custom.extension == QUANTUM_EXTENSION_ID and custom.name in enum_cls.__members__: + if custom.extension == EXTENSION.name and custom.name in enum_cls.__members__: return enum_cls(custom.name) return None @@ -43,11 +78,7 @@ def __call__(self, q: ComWire) -> Command: return DataflowOp.__call__(self, q) def to_custom(self) -> Custom: - return Custom( - self._enum.value, - tys.FunctionType.endo([tys.Qubit]), - extension=QUANTUM_EXTENSION_ID, - ) + return ExtOp(EXTENSION.operations[self._enum.value]).to_custom() @classmethod def from_custom(cls, custom: Custom) -> Self | None: @@ -65,11 +96,7 @@ class _Enum(Enum): _enum: _Enum def to_custom(self) -> Custom: - return Custom( - self._enum.value, - tys.FunctionType.endo([tys.Qubit] * 2), - extension=QUANTUM_EXTENSION_ID, - ) + return ExtOp(EXTENSION.operations[self._enum.value]).to_custom() @classmethod def from_custom(cls, custom: Custom) -> Self | None: @@ -85,11 +112,7 @@ def __call__(self, q0: ComWire, q1: ComWire) -> Command: @dataclass(frozen=True) class MeasureDef(AsCustomOp): def to_custom(self) -> Custom: - return Custom( - "Measure", - tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]), - extension=QUANTUM_EXTENSION_ID, - ) + return ExtOp(EXTENSION.operations["Measure"]).to_custom() def __call__(self, q: ComWire) -> Command: return super().__call__(q) @@ -101,11 +124,7 @@ def __call__(self, q: ComWire) -> Command: @dataclass(frozen=True) class RzDef(AsCustomOp): def to_custom(self) -> Custom: - return Custom( - "Rz", - tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]), - extension=QUANTUM_EXTENSION_ID, - ) + return ExtOp(EXTENSION.operations["Rz"]).to_custom() def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: return super().__call__(q, fl_wire) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index f20d65786..a27aba171 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -4,10 +4,9 @@ from hugr import tys from hugr.dfg import Dfg -from hugr.node_port import Node from hugr.ops import AsCustomOp, Custom from hugr.std.int import DivMod -from hugr.std.logic import EXTENSION_ID, Not +from hugr.std.logic import EXTENSION, Not from .conftest import CX, H, Measure, Rz, validate @@ -46,7 +45,6 @@ def test_stringly_typed(): validate(dfg.hugr) -@pytest.mark.xfail(reason="Extension resolution not implemented yet.") @pytest.mark.parametrize( "as_custom", [Not, DivMod, H, CX, Measure, Rz, StringlyOp("hello")], @@ -58,7 +56,8 @@ def test_custom(as_custom: AsCustomOp): assert Custom.from_custom(custom) == custom assert type(as_custom).from_custom(custom) == as_custom - assert as_custom.to_serial(Node(0)).deserialize() == custom + # TODO extension resolution needed for this equality + # assert as_custom.to_serial(Node(0)).deserialize() == custom assert custom == as_custom assert as_custom == custom @@ -66,13 +65,13 @@ def test_custom(as_custom: AsCustomOp): def test_custom_bad_eq(): assert Not != DivMod - bad_custom_sig = Custom("Not", extension=EXTENSION_ID) # empty signature + bad_custom_sig = Custom("Not", extension=EXTENSION.name) # empty signature assert Not != bad_custom_sig bad_custom_args = Custom( "Not", - extension=EXTENSION_ID, + extension=EXTENSION.name, signature=tys.FunctionType.endo([tys.Bool]), args=[tys.Bool.type_arg()], ) From 73f924280ed1de10d86bfb4ca40d5717f9b1b09b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 9 Aug 2024 17:06:06 +0100 Subject: [PATCH 07/26] undo serialized extension rename --- hugr-py/src/hugr/ext.py | 2 +- hugr-py/src/hugr/serialization/{ext.py => extension.py} | 0 hugr-py/src/hugr/serialization/testing_hugr.py | 2 +- hugr-py/tests/serialization/test_extension.py | 2 +- scripts/generate_schema.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename hugr-py/src/hugr/serialization/{ext.py => extension.py} (100%) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index f0b9e24d5..cc966318e 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -7,7 +7,7 @@ from semver import Version -import hugr.serialization.ext as ext_s +import hugr.serialization.extension as ext_s from hugr import tys, val from hugr.utils import ser_it diff --git a/hugr-py/src/hugr/serialization/ext.py b/hugr-py/src/hugr/serialization/extension.py similarity index 100% rename from hugr-py/src/hugr/serialization/ext.py rename to hugr-py/src/hugr/serialization/extension.py diff --git a/hugr-py/src/hugr/serialization/testing_hugr.py b/hugr-py/src/hugr/serialization/testing_hugr.py index f90de29f7..fa57ece7a 100644 --- a/hugr-py/src/hugr/serialization/testing_hugr.py +++ b/hugr-py/src/hugr/serialization/testing_hugr.py @@ -1,6 +1,6 @@ from pydantic import ConfigDict -from .ext import OpDef +from .extension import OpDef from .ops import OpType, Value from .ops import classes as ops_classes from .serial_hugr import VersionField diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index 85200aee7..6d4425b7d 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -1,6 +1,6 @@ from semver import Version -from hugr.serialization.ext import ( +from hugr.serialization.extension import ( ExplicitBound, Extension, OpDef, diff --git a/scripts/generate_schema.py b/scripts/generate_schema.py index cf593512a..908bdfa1a 100644 --- a/scripts/generate_schema.py +++ b/scripts/generate_schema.py @@ -16,7 +16,7 @@ from pydantic import ConfigDict from pydantic.json_schema import models_json_schema -from hugr.serialization.ext import Extension, Package +from hugr.serialization.extension import Extension, Package from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.testing_hugr import TestingHugr From 844f86ce7284fbccea1179b1a5ac8adf65a1f54d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 14:15:40 +0100 Subject: [PATCH 08/26] feat: resolve custom ops and types to extensions --- hugr-py/src/hugr/ext.py | 62 ++++++++++++++++++++++++++++++++ hugr-py/src/hugr/ops.py | 9 +++++ hugr-py/src/hugr/std/float.py | 4 +-- hugr-py/src/hugr/tys.py | 40 +++++++++++++++++++-- hugr-py/tests/test_custom.py | 66 +++++++++++++++++++++++++++-------- 5 files changed, 163 insertions(+), 18 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index cc966318e..4be1341bd 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -159,6 +159,12 @@ class Extension: # noqa: D101 values: dict[str, ExtensionValue] = field(default_factory=dict) operations: dict[str, OpDef] = field(default_factory=dict) + @dataclass + class NotFound(Exception): + """An object was not found in the extension.""" + + name: str + def to_serial(self) -> ext_s.Extension: return ext_s.Extension( name=self.name, @@ -184,6 +190,62 @@ def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue self.values[extension_value.name] = extension_value return self.values[extension_value.name] + @dataclass + class OperationNotFound(NotFound): + """Operation not found in extension.""" + + def get_op(self, name: str) -> OpDef: + try: + return self.operations[name] + except KeyError as e: + raise self.OperationNotFound(name) from e + + @dataclass + class TypeNotFound(NotFound): + """Type not found in extension.""" + + def get_type(self, name: str) -> TypeDef: + try: + return self.types[name] + except KeyError as e: + raise self.TypeNotFound(name) from e + + @dataclass + class ValueNotFound(NotFound): + """Value not found in extension.""" + + def get_value(self, name: str) -> ExtensionValue: + try: + return self.values[name] + except KeyError as e: + raise self.ValueNotFound(name) from e + + +@dataclass +class ExtensionRegistry: + extensions: dict[ExtensionId, Extension] = field(default_factory=dict) + + @dataclass + class ExtensionNotFound(Exception): + extension_id: ExtensionId + + @dataclass + class ExtensionExists(Exception): + extension_id: ExtensionId + + def add_extension(self, extension: Extension) -> Extension: + if extension.name in self.extensions: + raise self.ExtensionExists(extension.name) + # TODO version updates + self.extensions[extension.name] = extension + return self.extensions[extension.name] + + def get_extension(self, name: ExtensionId) -> Extension: + try: + return self.extensions[name] + except KeyError as e: + raise self.ExtensionNotFound(name) from e + @dataclass class Package: # noqa: D101 diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index a01f0c516..3d0c1817d 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -302,6 +302,15 @@ def check_id(self, extension: tys.ExtensionId, name: str) -> bool: """Check if the operation matches the given extension and operation name.""" return self.extension == extension and self.name == name + def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp: + """Resolve the custom operation to an :class:`ExtOp`.""" + op_def = registry.get_extension(self.extension).get_op(self.name) + + signature = self.signature.resolve(registry) + args = [arg.resolve(registry) for arg in self.args] + # TODO check signature matches op_def reported signature + return ExtOp(op_def, signature, args) + @dataclass(frozen=True) class ExtOp(AsCustomOp): diff --git a/hugr-py/src/hugr/std/float.py b/hugr-py/src/hugr/std/float.py index 76c62b78a..327a91e08 100644 --- a/hugr-py/src/hugr/std/float.py +++ b/hugr-py/src/hugr/std/float.py @@ -6,9 +6,9 @@ from hugr import ext, tys, val -Extension = ext.Extension("arithmetic.float.types", ext.Version(0, 1, 0)) +EXTENSION = ext.Extension("arithmetic.float.types", ext.Version(0, 1, 0)) #: HUGR 64-bit IEEE 754-2019 floating point type. -FLOAT_T_DEF = Extension.add_type_def( +FLOAT_T_DEF = EXTENSION.add_type_def( ext.TypeDef( name="float64", description="64-bit IEEE 754-2019 floating point number", diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 222e1fcfc..58b3fe337 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -9,7 +9,8 @@ from hugr.utils import ser_it if TYPE_CHECKING: - from hugr.ext import TypeDef + from hugr import ext + ExtensionId = stys.ExtensionId ExtensionSet = stys.ExtensionSet @@ -37,6 +38,10 @@ def to_serial(self) -> stys.BaseTypeArg: def to_serial_root(self) -> stys.TypeArg: return stys.TypeArg(root=self.to_serial()) # type: ignore[arg-type] + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + """Resolve types in the argument using the given registry.""" + return self + @runtime_checkable class Type(Protocol): @@ -69,6 +74,10 @@ def type_arg(self) -> TypeTypeArg: """ return TypeTypeArg(self) + def resolve(self, registry: ext.ExtensionRegistry) -> Type: + """Resolve types in the type using the given registry.""" + return self + #: Row of types. TypeRow = list[Type] @@ -148,6 +157,9 @@ class TypeTypeArg(TypeArg): def to_serial(self) -> stys.TypeTypeArg: return stys.TypeTypeArg(ty=self.ty.to_serial_root()) + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return TypeTypeArg(self.ty.resolve(registry)) + @dataclass(frozen=True) class BoundedNatArg(TypeArg): @@ -178,6 +190,9 @@ class SequenceArg(TypeArg): def to_serial(self) -> stys.SequenceArg: return stys.SequenceArg(elems=ser_it(self.elems)) + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return SequenceArg([arg.resolve(registry) for arg in self.elems]) + @dataclass(frozen=True) class ExtensionsArg(TypeArg): @@ -387,6 +402,14 @@ def flip(self) -> FunctionType: def __repr__(self) -> str: return f"FunctionType({self.input}, {self.output})" + def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType: + """Resolve types in the function type using the given registry.""" + return FunctionType( + input=[ty.resolve(registry) for ty in self.input], + output=[ty.resolve(registry) for ty in self.output], + extension_reqs=self.extension_reqs, + ) + @dataclass(frozen=True) class PolyFuncType(Type): @@ -405,12 +428,19 @@ def to_serial(self) -> stys.PolyFuncType: params=[p.to_serial_root() for p in self.params], body=self.body.to_serial() ) + def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType: + """Resolve types in the polymorphic function type using the given registry.""" + return PolyFuncType( + params=self.params, + body=self.body.resolve(registry), + ) + @dataclass class ExtType(Type): """Extension type, defined by a type definition and type arguments.""" - type_def: TypeDef + type_def: ext.TypeDef args: list[TypeArg] = field(default_factory=list) def type_bound(self) -> TypeBound: @@ -458,6 +488,12 @@ def to_serial(self) -> stys.Opaque: def type_bound(self) -> TypeBound: return self.bound + def resolve(self, registry: ext.ExtensionRegistry) -> ExtType: + """Resolve the opaque type to an :class:`ExtType` using the given registry.""" + type_def = registry.get_extension(self.extension).get_type(self.id) + + return ExtType(type_def, self.args) + @dataclass class _QubitDef(Type): diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index a27aba171..da460fd31 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -2,13 +2,27 @@ import pytest -from hugr import tys +from hugr import ext, ops, tys from hugr.dfg import Dfg +from hugr.node_port import Node from hugr.ops import AsCustomOp, Custom -from hugr.std.int import DivMod -from hugr.std.logic import EXTENSION, Not +from hugr.std.float import EXTENSION as FLOAT_EXT +from hugr.std.int import OPS_EXTENSION, TYPES_EXTENSION, DivMod +from hugr.std.logic import EXTENSION as LOGIC_EXT +from hugr.std.logic import Not from .conftest import CX, H, Measure, Rz, validate +from .conftest import EXTENSION as QUANTUM_EXT + +STRINGLY_EXT = ext.Extension("my_extension", ext.Version(0, 0, 0)) +STRINGLY_EXT.add_op_def( + ext.OpDef( + "StringlyOp", + signature=ext.OpDefSig( + tys.PolyFuncType([tys.StringParam()], tys.FunctionType.endo([])) + ), + ) +) @dataclass @@ -16,12 +30,11 @@ class StringlyOp(AsCustomOp): tag: str def to_custom(self) -> Custom: - return Custom( - "StringlyOp", - extension="my_extension", - signature=tys.FunctionType.endo([]), - args=[tys.StringArg(self.tag)], - ) + return ops.ExtOp( + STRINGLY_EXT.get_op("StringlyOp"), + tys.FunctionType.endo([]), + [tys.StringArg(self.tag)], + ).to_custom() @classmethod def from_custom(cls, custom: Custom) -> "StringlyOp": @@ -45,19 +58,44 @@ def test_stringly_typed(): validate(dfg.hugr) +def test_registry(): + reg = ext.ExtensionRegistry() + reg.add_extension(LOGIC_EXT) + assert reg.get_extension(LOGIC_EXT.name).name == LOGIC_EXT.name + assert len(reg.extensions) == 1 + with pytest.raises(ext.ExtensionRegistry.ExtensionExists): + reg.add_extension(LOGIC_EXT) + + with pytest.raises(ext.ExtensionRegistry.ExtensionNotFound): + reg.get_extension("not_found") + + +@pytest.fixture() +def registry() -> ext.ExtensionRegistry: + reg = ext.ExtensionRegistry() + reg.add_extension(LOGIC_EXT) + reg.add_extension(QUANTUM_EXT) + reg.add_extension(STRINGLY_EXT) + reg.add_extension(TYPES_EXTENSION) + reg.add_extension(OPS_EXTENSION) + reg.add_extension(FLOAT_EXT) + + return reg + + @pytest.mark.parametrize( "as_custom", [Not, DivMod, H, CX, Measure, Rz, StringlyOp("hello")], ) -def test_custom(as_custom: AsCustomOp): +def test_custom(as_custom: AsCustomOp, registry: ext.ExtensionRegistry): custom = as_custom.to_custom() assert custom.to_custom() == custom assert Custom.from_custom(custom) == custom assert type(as_custom).from_custom(custom) == as_custom - # TODO extension resolution needed for this equality - # assert as_custom.to_serial(Node(0)).deserialize() == custom + # ExtOp compared to Custom via `to_custom` + assert as_custom.to_serial(Node(0)).deserialize().resolve(registry) == custom assert custom == as_custom assert as_custom == custom @@ -65,13 +103,13 @@ def test_custom(as_custom: AsCustomOp): def test_custom_bad_eq(): assert Not != DivMod - bad_custom_sig = Custom("Not", extension=EXTENSION.name) # empty signature + bad_custom_sig = Custom("Not", extension=LOGIC_EXT.name) # empty signature assert Not != bad_custom_sig bad_custom_args = Custom( "Not", - extension=EXTENSION.name, + extension=LOGIC_EXT.name, signature=tys.FunctionType.endo([tys.Bool]), args=[tys.Bool.type_arg()], ) From f888285a955642c74da248179aef8af9c4f82031 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 14:34:10 +0100 Subject: [PATCH 09/26] resolve extensions in a hugr --- hugr-py/src/hugr/ext.py | 1 + hugr-py/src/hugr/hugr.py | 13 ++++++++++++- hugr-py/tests/test_custom.py | 13 ++++++++++++- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 4be1341bd..84303e08c 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -86,6 +86,7 @@ def to_serial(self) -> ext_s.FixedHugr: return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr) +@dataclass class OpDefSig: # noqa: D101 poly_func: tys.PolyFuncType | None binary: bool diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 5918fed5b..3a2cb7a56 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -24,7 +24,7 @@ ToNode, _SubPort, ) -from hugr.ops import Call, Const, DataflowOp, Module, Op +from hugr.ops import Call, Const, Custom, DataflowOp, Module, Op from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.tys import Kind, Type, ValueKind @@ -34,6 +34,7 @@ from .exceptions import ParentBeforeChild if TYPE_CHECKING: + from hugr import ext from hugr.val import Value @@ -598,6 +599,16 @@ def _constrain_offset(self, p: P) -> PortOffset: return offset + def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr: + """Resolve extension types and operations in the HUGR by matching them to + extensions in the registry. + """ + for node in self: + op = self[node].op + if isinstance(op, Custom): + self[node].op = op.resolve(registry) + return self + @classmethod def from_serial(cls, serial: SerialHugr) -> Hugr: """Load a HUGR from a serialized form.""" diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index da460fd31..187d74ce7 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -4,8 +4,9 @@ from hugr import ext, ops, tys from hugr.dfg import Dfg +from hugr.hugr import Hugr from hugr.node_port import Node -from hugr.ops import AsCustomOp, Custom +from hugr.ops import AsCustomOp, Custom, ExtOp from hugr.std.float import EXTENSION as FLOAT_EXT from hugr.std.int import OPS_EXTENSION, TYPES_EXTENSION, DivMod from hugr.std.logic import EXTENSION as LOGIC_EXT @@ -57,6 +58,16 @@ def test_stringly_typed(): assert dfg.hugr[n].op == StringlyOp("world") validate(dfg.hugr) + new_h = Hugr.from_serial(dfg.hugr.to_serial()) + + assert isinstance(new_h[n].op, Custom) + + registry = ext.ExtensionRegistry() + registry.add_extension(STRINGLY_EXT) + new_h.resolve_extensions(registry) + + assert isinstance(new_h[n].op, ExtOp) + def test_registry(): reg = ext.ExtensionRegistry() From 186cc3d4c6c13f1bc426d807124a00e480bb9a63 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 15:14:26 +0100 Subject: [PATCH 10/26] replace `AsCustomOp` with `AsExtOp` --- hugr-py/src/hugr/ops.py | 85 ++++++++++++++++++++--------------- hugr-py/src/hugr/std/int.py | 14 +++--- hugr-py/src/hugr/std/logic.py | 8 ++-- hugr-py/tests/conftest.py | 38 ++++++++-------- hugr-py/tests/test_custom.py | 39 ++++++++-------- 5 files changed, 101 insertions(+), 83 deletions(-) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 3d0c1817d..036ac85a6 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -201,79 +201,78 @@ def _set_in_types(self, types: tys.TypeRow) -> None: @runtime_checkable -class AsCustomOp(DataflowOp, Protocol): +class AsExtOp(DataflowOp, Protocol): """Abstract interface that types can implement - to behave as a custom dataflow operation. + to behave as an extension dataflow operation. """ @dataclass(frozen=True) - class InvalidCustomOp(Exception): - """Custom operation does not match the expected type.""" + class InvalidExtOp(Exception): + """Extension operation does not match the expected type.""" msg: str @cached_property - def custom_op(self) -> Custom: - """:class:`Custom` operation that this type represents. + def ext_op(self) -> ExtOp: + """:class:`ExtOp` operation that this type represents. - Computed once using :meth:`to_custom` and cached - should be deterministic. + Computed once using :meth:`to_ext` and cached - should be deterministic. """ - return self.to_custom() + return self.to_ext() - def to_custom(self) -> Custom: - """Convert this type to a :class:`Custom` operation. + def to_ext(self) -> ExtOp: + """Convert this type to a :class:`ExtOp` operation. - Used by :attr:`custom_op`, so must be deterministic. + Used by :attr:`ext_op`, so must be deterministic. """ ... # pragma: no cover @classmethod - def from_custom(cls, custom: Custom) -> Self | None: - """Load from a :class:`Custom` operation. + def from_ext(cls, ext_op: ExtOp) -> Self | None: + """Load from a :class:`ExtOp` operation. By default assumes the type of `cls` is a singleton, - and compares the result of :meth:`to_custom` with the given `custom`. + and compares the result of :meth:`to_ext` with the given `ext_op`. If successful, returns the singleton, else None. Non-singleton types should override this method. Raises: - InvalidCustomOp: If the given `custom` does not match the expected one for a + InvalidCustomOp: If the given `ext_op` does not match the expected one for a given extension/operation name. """ default = cls() - if default.custom_op == custom: + if default.ext_op == ext_op: return default return None def __eq__(self, other: object) -> bool: - if not isinstance(other, AsCustomOp): + if not isinstance(other, AsExtOp): return NotImplemented - slf, other = self.custom_op, other.custom_op + slf, other = self.ext_op, other.ext_op return ( - slf.extension == other.extension - and slf.name == other.name - and slf.signature == other.signature + slf.op_def == other.op_def + and slf.outer_signature() == other.outer_signature() and slf.args == other.args ) def outer_signature(self) -> tys.FunctionType: - return self.custom_op.signature + return self.ext_op.outer_signature() def to_serial(self, parent: Node) -> sops.CustomOp: - return self.custom_op.to_serial(parent) + return self.ext_op.to_serial(parent) @property def num_out(self) -> int: - return len(self.custom_op.signature.output) + return len(self.outer_signature().output) @dataclass(frozen=True, eq=False) -class Custom(AsCustomOp): - """A non-core dataflow operation defined in an extension.""" +class Custom(DataflowOp): + """Serialisable version of non-core dataflow operation defined in an extension.""" name: str signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @@ -291,12 +290,12 @@ def to_serial(self, parent: Node) -> sops.CustomOp: args=ser_it(self.args), ) - def to_custom(self) -> Custom: - return self + def outer_signature(self) -> tys.FunctionType: + return self.signature - @classmethod - def from_custom(cls, custom: Custom) -> Custom: - return custom + @property + def num_out(self) -> int: + return len(self.outer_signature().output) def check_id(self, extension: tys.ExtensionId, name: str) -> bool: """Check if the operation matches the given extension and operation name.""" @@ -312,15 +311,15 @@ def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp: return ExtOp(op_def, signature, args) -@dataclass(frozen=True) -class ExtOp(AsCustomOp): +@dataclass(frozen=True, eq=False) +class ExtOp(AsExtOp): """A non-core dataflow operation defined in an extension.""" op_def: ext.OpDef signature: tys.FunctionType | None = None args: list[tys.TypeArg] = field(default_factory=list) - def to_custom(self) -> Custom: + def to_custom_op(self) -> Custom: ext = self.op_def._extension if self.signature is None: poly_func = self.op_def.signature.poly_func @@ -339,7 +338,23 @@ def to_custom(self) -> Custom: ) def to_serial(self, parent: Node) -> sops.CustomOp: - return self.to_custom().to_serial(parent) + return self.to_custom_op().to_serial(parent) + + def to_ext(self) -> ExtOp: + return self + + @classmethod + def from_ext(cls, custom: ExtOp) -> ExtOp: + return custom + + def outer_signature(self) -> tys.FunctionType: + if self.signature is not None: + return self.signature + poly_func = self.op_def.signature.poly_func + if poly_func is None: + msg = "Polymorphic signature must be cached." + raise ValueError(msg) + return poly_func.body @dataclass() diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index a1438df9c..2404bdd10 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -8,7 +8,7 @@ from typing_extensions import Self from hugr import ext, tys, val -from hugr.ops import AsCustomOp, Custom, DataflowOp, ExtOp +from hugr.ops import AsExtOp, DataflowOp, ExtOp if TYPE_CHECKING: from hugr.ops import Command, ComWire @@ -79,32 +79,32 @@ def to_value(self) -> val.Extension: @dataclass(frozen=True) -class _DivModDef(AsCustomOp): +class _DivModDef(AsExtOp): """DivMod operation, has two inputs and two outputs.""" arg1: int = 5 arg2: int = 5 op_def: ext.OpDef = field(default_factory=lambda: _DivMod, init=False) - def to_custom(self) -> Custom: + def to_ext(self) -> ExtOp: row: list[tys.Type] = [int_t(self.arg1), int_t(self.arg2)] ext_op = ExtOp( self.op_def, tys.FunctionType.endo(row), [tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)], ) - return ext_op.to_custom() + return ext_op @classmethod - def from_custom(cls, custom: Custom) -> Self | None: - if not custom.check_id(OPS_EXTENSION.name, _DivMod.name): + def from_ext(cls, custom: ExtOp) -> Self | None: + if custom.op_def != _DivMod: return None match custom.args: case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]: return cls(arg1=a1, arg2=a2) case _: msg = f"Invalid args: {custom.args}" - raise AsCustomOp.InvalidCustomOp(msg) + raise AsExtOp.InvalidExtOp(msg) def __call__(self, a: ComWire, b: ComWire) -> Command: return DataflowOp.__call__(self, a, b) diff --git a/hugr-py/src/hugr/std/logic.py b/hugr-py/src/hugr/std/logic.py index cea3bd80d..2884b7713 100644 --- a/hugr-py/src/hugr/std/logic.py +++ b/hugr-py/src/hugr/std/logic.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from hugr import ext, tys -from hugr.ops import AsCustomOp, Command, Custom, DataflowOp, ExtOp +from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp if TYPE_CHECKING: from hugr.ops import ComWire @@ -24,9 +24,9 @@ @dataclass(frozen=True) -class _NotOp(AsCustomOp): - def to_custom(self) -> Custom: - return ExtOp(_NotDef).to_custom() +class _NotOp(AsExtOp): + def to_ext(self) -> ExtOp: + return ExtOp(_NotDef) def __call__(self, a: ComWire) -> Command: return DataflowOp.__call__(self, a) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 5c3b40dfc..0386cb7ee 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -12,7 +12,7 @@ from hugr import ext, tys from hugr.hugr import Hugr -from hugr.ops import AsCustomOp, Command, Custom, DataflowOp, ExtOp +from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp from hugr.serialization.serial_hugr import SerialHugr from hugr.std.float import FLOAT_T @@ -60,14 +60,16 @@ E = TypeVar("E", bound=Enum) -def _load_enum(enum_cls: type[E], custom: Custom) -> E | None: - if custom.extension == EXTENSION.name and custom.name in enum_cls.__members__: - return enum_cls(custom.name) +def _load_enum(enum_cls: type[E], custom: ExtOp) -> E | None: + ext = custom.op_def._extension + assert ext is not None + if ext.name == EXTENSION.name and custom.op_def.name in enum_cls.__members__: + return enum_cls(custom.op_def.name) return None @dataclass(frozen=True) -class OneQbGate(AsCustomOp): +class OneQbGate(AsExtOp): # Have to nest enum to avoid meta class conflict class _Enum(Enum): H = "H" @@ -77,11 +79,11 @@ class _Enum(Enum): def __call__(self, q: ComWire) -> Command: return DataflowOp.__call__(self, q) - def to_custom(self) -> Custom: - return ExtOp(EXTENSION.operations[self._enum.value]).to_custom() + def to_ext(self) -> ExtOp: + return ExtOp(EXTENSION.operations[self._enum.value]) @classmethod - def from_custom(cls, custom: Custom) -> Self | None: + def from_ext(cls, custom: ExtOp) -> Self | None: return cls(e) if (e := _load_enum(cls._Enum, custom)) else None @@ -89,17 +91,17 @@ def from_custom(cls, custom: Custom) -> Self | None: @dataclass(frozen=True) -class TwoQbGate(AsCustomOp): +class TwoQbGate(AsExtOp): class _Enum(Enum): CX = "CX" _enum: _Enum - def to_custom(self) -> Custom: - return ExtOp(EXTENSION.operations[self._enum.value]).to_custom() + def to_ext(self) -> ExtOp: + return ExtOp(EXTENSION.operations[self._enum.value]) @classmethod - def from_custom(cls, custom: Custom) -> Self | None: + def from_ext(cls, custom: ExtOp) -> Self | None: return cls(e) if (e := _load_enum(cls._Enum, custom)) else None def __call__(self, q0: ComWire, q1: ComWire) -> Command: @@ -110,9 +112,9 @@ def __call__(self, q0: ComWire, q1: ComWire) -> Command: @dataclass(frozen=True) -class MeasureDef(AsCustomOp): - def to_custom(self) -> Custom: - return ExtOp(EXTENSION.operations["Measure"]).to_custom() +class MeasureDef(AsExtOp): + def to_ext(self) -> ExtOp: + return ExtOp(EXTENSION.operations["Measure"]) def __call__(self, q: ComWire) -> Command: return super().__call__(q) @@ -122,9 +124,9 @@ def __call__(self, q: ComWire) -> Command: @dataclass(frozen=True) -class RzDef(AsCustomOp): - def to_custom(self) -> Custom: - return ExtOp(EXTENSION.operations["Rz"]).to_custom() +class RzDef(AsExtOp): + def to_ext(self) -> ExtOp: + return ExtOp(EXTENSION.operations["Rz"]) def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: return super().__call__(q, fl_wire) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 187d74ce7..9dfd774e7 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -6,7 +6,7 @@ from hugr.dfg import Dfg from hugr.hugr import Hugr from hugr.node_port import Node -from hugr.ops import AsCustomOp, Custom, ExtOp +from hugr.ops import AsExtOp, Custom, ExtOp from hugr.std.float import EXTENSION as FLOAT_EXT from hugr.std.int import OPS_EXTENSION, TYPES_EXTENSION, DivMod from hugr.std.logic import EXTENSION as LOGIC_EXT @@ -16,7 +16,7 @@ from .conftest import EXTENSION as QUANTUM_EXT STRINGLY_EXT = ext.Extension("my_extension", ext.Version(0, 0, 0)) -STRINGLY_EXT.add_op_def( +_STRINGLY_DEF = STRINGLY_EXT.add_op_def( ext.OpDef( "StringlyOp", signature=ext.OpDefSig( @@ -27,28 +27,27 @@ @dataclass -class StringlyOp(AsCustomOp): +class StringlyOp(AsExtOp): tag: str - def to_custom(self) -> Custom: + def to_ext(self) -> ops.ExtOp: return ops.ExtOp( STRINGLY_EXT.get_op("StringlyOp"), tys.FunctionType.endo([]), [tys.StringArg(self.tag)], - ).to_custom() + ) @classmethod - def from_custom(cls, custom: Custom) -> "StringlyOp": + def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp": match custom: - case Custom( - name="StringlyOp", - extension="my_extension", + case ops.ExtOp( + op_def=_STRINGLY_DEF, args=[tys.StringArg(tag)], ): return cls(tag=tag) case _: msg = f"Invalid custom op: {custom}" - raise AsCustomOp.InvalidCustomOp(msg) + raise AsExtOp.InvalidExtOp(msg) def test_stringly_typed(): @@ -95,20 +94,22 @@ def registry() -> ext.ExtensionRegistry: @pytest.mark.parametrize( - "as_custom", + "as_ext", [Not, DivMod, H, CX, Measure, Rz, StringlyOp("hello")], ) -def test_custom(as_custom: AsCustomOp, registry: ext.ExtensionRegistry): - custom = as_custom.to_custom() +def test_custom(as_ext: AsExtOp, registry: ext.ExtensionRegistry): + ext_op = as_ext.to_ext() - assert custom.to_custom() == custom - assert Custom.from_custom(custom) == custom + assert ext_op.to_ext() == ext_op + assert ExtOp.from_ext(ext_op) == ext_op - assert type(as_custom).from_custom(custom) == as_custom + assert type(as_ext).from_ext(ext_op) == as_ext + custom = as_ext.to_serial(Node(0)).deserialize() + assert isinstance(custom, Custom) # ExtOp compared to Custom via `to_custom` - assert as_custom.to_serial(Node(0)).deserialize().resolve(registry) == custom - assert custom == as_custom - assert as_custom == custom + assert custom.resolve(registry) == ext_op + assert ext_op == as_ext + assert as_ext == ext_op def test_custom_bad_eq(): From d73faaa0d61c665e7d623a05bde86d453e80a636 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 15:26:01 +0100 Subject: [PATCH 11/26] update schema to include binary field in OpDef --- poetry.lock | 234 ++++++++++-------- specification/schema/hugr_schema_live.json | 5 + .../schema/hugr_schema_strict_live.json | 5 + .../schema/testing_hugr_schema_live.json | 5 + .../testing_hugr_schema_strict_live.json | 5 + 5 files changed, 148 insertions(+), 106 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0609a6784..fe505ef8a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -35,63 +35,83 @@ files = [ [[package]] name = "coverage" -version = "7.6.0" +version = "7.6.1" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dff044f661f59dace805eedb4a7404c573b6ff0cdba4a524141bc63d7be5c7fd"}, - {file = "coverage-7.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8659fd33ee9e6ca03950cfdcdf271d645cf681609153f218826dd9805ab585c"}, - {file = "coverage-7.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7792f0ab20df8071d669d929c75c97fecfa6bcab82c10ee4adb91c7a54055463"}, - {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4b3cd1ca7cd73d229487fa5caca9e4bc1f0bca96526b922d61053ea751fe791"}, - {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7e128f85c0b419907d1f38e616c4f1e9f1d1b37a7949f44df9a73d5da5cd53c"}, - {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a94925102c89247530ae1dab7dc02c690942566f22e189cbd53579b0693c0783"}, - {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dcd070b5b585b50e6617e8972f3fbbee786afca71b1936ac06257f7e178f00f6"}, - {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d50a252b23b9b4dfeefc1f663c568a221092cbaded20a05a11665d0dbec9b8fb"}, - {file = "coverage-7.6.0-cp310-cp310-win32.whl", hash = "sha256:0e7b27d04131c46e6894f23a4ae186a6a2207209a05df5b6ad4caee6d54a222c"}, - {file = "coverage-7.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:54dece71673b3187c86226c3ca793c5f891f9fc3d8aa183f2e3653da18566169"}, - {file = "coverage-7.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7b525ab52ce18c57ae232ba6f7010297a87ced82a2383b1afd238849c1ff933"}, - {file = "coverage-7.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bea27c4269234e06f621f3fac3925f56ff34bc14521484b8f66a580aacc2e7d"}, - {file = "coverage-7.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed8d1d1821ba5fc88d4a4f45387b65de52382fa3ef1f0115a4f7a20cdfab0e94"}, - {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c322ef2bbe15057bc4bf132b525b7e3f7206f071799eb8aa6ad1940bcf5fb1"}, - {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03cafe82c1b32b770a29fd6de923625ccac3185a54a5e66606da26d105f37dac"}, - {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0d1b923fc4a40c5832be4f35a5dab0e5ff89cddf83bb4174499e02ea089daf57"}, - {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4b03741e70fb811d1a9a1d75355cf391f274ed85847f4b78e35459899f57af4d"}, - {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a73d18625f6a8a1cbb11eadc1d03929f9510f4131879288e3f7922097a429f63"}, - {file = "coverage-7.6.0-cp311-cp311-win32.whl", hash = "sha256:65fa405b837060db569a61ec368b74688f429b32fa47a8929a7a2f9b47183713"}, - {file = "coverage-7.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:6379688fb4cfa921ae349c76eb1a9ab26b65f32b03d46bb0eed841fd4cb6afb1"}, - {file = "coverage-7.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f7db0b6ae1f96ae41afe626095149ecd1b212b424626175a6633c2999eaad45b"}, - {file = "coverage-7.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bbdf9a72403110a3bdae77948b8011f644571311c2fb35ee15f0f10a8fc082e8"}, - {file = "coverage-7.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc44bf0315268e253bf563f3560e6c004efe38f76db03a1558274a6e04bf5d5"}, - {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da8549d17489cd52f85a9829d0e1d91059359b3c54a26f28bec2c5d369524807"}, - {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0086cd4fc71b7d485ac93ca4239c8f75732c2ae3ba83f6be1c9be59d9e2c6382"}, - {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fad32ee9b27350687035cb5fdf9145bc9cf0a094a9577d43e909948ebcfa27b"}, - {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:044a0985a4f25b335882b0966625270a8d9db3d3409ddc49a4eb00b0ef5e8cee"}, - {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:76d5f82213aa78098b9b964ea89de4617e70e0d43e97900c2778a50856dac605"}, - {file = "coverage-7.6.0-cp312-cp312-win32.whl", hash = "sha256:3c59105f8d58ce500f348c5b56163a4113a440dad6daa2294b5052a10db866da"}, - {file = "coverage-7.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:ca5d79cfdae420a1d52bf177de4bc2289c321d6c961ae321503b2ca59c17ae67"}, - {file = "coverage-7.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d39bd10f0ae453554798b125d2f39884290c480f56e8a02ba7a6ed552005243b"}, - {file = "coverage-7.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:beb08e8508e53a568811016e59f3234d29c2583f6b6e28572f0954a6b4f7e03d"}, - {file = "coverage-7.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2e16f4cd2bc4d88ba30ca2d3bbf2f21f00f382cf4e1ce3b1ddc96c634bc48ca"}, - {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6616d1c9bf1e3faea78711ee42a8b972367d82ceae233ec0ac61cc7fec09fa6b"}, - {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad4567d6c334c46046d1c4c20024de2a1c3abc626817ae21ae3da600f5779b44"}, - {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d17c6a415d68cfe1091d3296ba5749d3d8696e42c37fca5d4860c5bf7b729f03"}, - {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9146579352d7b5f6412735d0f203bbd8d00113a680b66565e205bc605ef81bc6"}, - {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:cdab02a0a941af190df8782aafc591ef3ad08824f97850b015c8c6a8b3877b0b"}, - {file = "coverage-7.6.0-cp38-cp38-win32.whl", hash = "sha256:df423f351b162a702c053d5dddc0fc0ef9a9e27ea3f449781ace5f906b664428"}, - {file = "coverage-7.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:f2501d60d7497fd55e391f423f965bbe9e650e9ffc3c627d5f0ac516026000b8"}, - {file = "coverage-7.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7221f9ac9dad9492cecab6f676b3eaf9185141539d5c9689d13fd6b0d7de840c"}, - {file = "coverage-7.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddaaa91bfc4477d2871442bbf30a125e8fe6b05da8a0015507bfbf4718228ab2"}, - {file = "coverage-7.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cbe651f3904e28f3a55d6f371203049034b4ddbce65a54527a3f189ca3b390"}, - {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:831b476d79408ab6ccfadaaf199906c833f02fdb32c9ab907b1d4aa0713cfa3b"}, - {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c3d091059ad0b9c59d1034de74a7f36dcfa7f6d3bde782c49deb42438f2450"}, - {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4d5fae0a22dc86259dee66f2cc6c1d3e490c4a1214d7daa2a93d07491c5c04b6"}, - {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:07ed352205574aad067482e53dd606926afebcb5590653121063fbf4e2175166"}, - {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:49c76cdfa13015c4560702574bad67f0e15ca5a2872c6a125f6327ead2b731dd"}, - {file = "coverage-7.6.0-cp39-cp39-win32.whl", hash = "sha256:482855914928c8175735a2a59c8dc5806cf7d8f032e4820d52e845d1f731dca2"}, - {file = "coverage-7.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:543ef9179bc55edfd895154a51792b01c017c87af0ebaae092720152e19e42ca"}, - {file = "coverage-7.6.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:6fe885135c8a479d3e37a7aae61cbd3a0fb2deccb4dda3c25f92a49189f766d6"}, - {file = "coverage-7.6.0.tar.gz", hash = "sha256:289cc803fa1dc901f84701ac10c9ee873619320f2f9aff38794db4a4a0268d51"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, + {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, + {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, + {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, + {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, + {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, + {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, + {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, + {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, + {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, + {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, + {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, + {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, ] [package.dependencies] @@ -500,62 +520,64 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale [[package]] name = "pyyaml" -version = "6.0.1" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 9d131adf6..52ca83ea5 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -1168,6 +1168,11 @@ ], "default": null }, + "binary": { + "default": false, + "title": "Binary", + "type": "boolean" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 957493f26..d7b7cc74e 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -1168,6 +1168,11 @@ ], "default": null }, + "binary": { + "default": false, + "title": "Binary", + "type": "boolean" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index c340dc1bd..4b70d5c86 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -1168,6 +1168,11 @@ ], "default": null }, + "binary": { + "default": false, + "title": "Binary", + "type": "boolean" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 6e1e6f903..df0d3a0fe 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -1168,6 +1168,11 @@ ], "default": null }, + "binary": { + "default": false, + "title": "Binary", + "type": "boolean" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" From 173897f9524942a87fe63e1f29e7707078abd353 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 16:17:37 +0100 Subject: [PATCH 12/26] break up AsExtOp - mainly return opdef --- hugr-py/src/hugr/ops.py | 46 ++++++++++++++++++++++++++--------- hugr-py/src/hugr/std/int.py | 20 +++++++-------- hugr-py/src/hugr/std/logic.py | 6 ++--- hugr-py/tests/conftest.py | 22 ++++++++--------- hugr-py/tests/test_custom.py | 19 ++++++++------- 5 files changed, 68 insertions(+), 45 deletions(-) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 036ac85a6..366e81762 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -216,18 +216,34 @@ class InvalidExtOp(Exception): def ext_op(self) -> ExtOp: """:class:`ExtOp` operation that this type represents. - Computed once using :meth:`to_ext` and cached - should be deterministic. + Computed once using :meth:`op_def` :meth:`type_args` and :meth:`type_args`. + Each of those methods should be deterministic. """ - return self.to_ext() + return ExtOp(self.op_def(), self.cached_signature(), self.type_args()) - def to_ext(self) -> ExtOp: - """Convert this type to a :class:`ExtOp` operation. + def op_def(self) -> ext.OpDef: + """The :class:`tys.OpDef` for this operation. Used by :attr:`ext_op`, so must be deterministic. """ ... # pragma: no cover + def type_args(self) -> list[tys.TypeArg]: + """Type arguments of the operation. + + Used by :attr:`op_def`, so must be deterministic. + """ + return [] + + def cached_signature(self) -> tys.FunctionType | None: + """Cached signature of the operation, if there is one. + + + Used by :attr:`op_def`, so must be deterministic. + """ + return None + @classmethod def from_ext(cls, ext_op: ExtOp) -> Self | None: """Load from a :class:`ExtOp` operation. @@ -254,7 +270,7 @@ def __eq__(self, other: object) -> bool: return NotImplemented slf, other = self.ext_op, other.ext_op return ( - slf.op_def == other.op_def + slf._op_def == other._op_def and slf.outer_signature() == other.outer_signature() and slf.args == other.args ) @@ -315,14 +331,14 @@ def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp: class ExtOp(AsExtOp): """A non-core dataflow operation defined in an extension.""" - op_def: ext.OpDef + _op_def: ext.OpDef signature: tys.FunctionType | None = None args: list[tys.TypeArg] = field(default_factory=list) def to_custom_op(self) -> Custom: - ext = self.op_def._extension + ext = self._op_def._extension if self.signature is None: - poly_func = self.op_def.signature.poly_func + poly_func = self._op_def.signature.poly_func if poly_func is None or len(poly_func.params) > 0: msg = "For polymorphic ops signature must be cached." raise ValueError(msg) @@ -331,7 +347,7 @@ def to_custom_op(self) -> Custom: sig = self.signature return Custom( - name=self.op_def.name, + name=self._op_def.name, signature=sig, extension=ext.name if ext else "", args=self.args, @@ -340,8 +356,14 @@ def to_custom_op(self) -> Custom: def to_serial(self, parent: Node) -> sops.CustomOp: return self.to_custom_op().to_serial(parent) - def to_ext(self) -> ExtOp: - return self + def op_def(self) -> ext.OpDef: + return self._op_def + + def type_args(self) -> list[tys.TypeArg]: + return self.args + + def cached_signature(self) -> tys.FunctionType | None: + return self.signature @classmethod def from_ext(cls, custom: ExtOp) -> ExtOp: @@ -350,7 +372,7 @@ def from_ext(cls, custom: ExtOp) -> ExtOp: def outer_signature(self) -> tys.FunctionType: if self.signature is not None: return self.signature - poly_func = self.op_def.signature.poly_func + poly_func = self._op_def.signature.poly_func if poly_func is None: msg = "Polymorphic signature must be cached." raise ValueError(msg) diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 2404bdd10..19b4b3132 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING from typing_extensions import Self @@ -84,20 +84,20 @@ class _DivModDef(AsExtOp): arg1: int = 5 arg2: int = 5 - op_def: ext.OpDef = field(default_factory=lambda: _DivMod, init=False) - def to_ext(self) -> ExtOp: + def op_def(self) -> ext.OpDef: + return _DivMod + + def type_args(self) -> list[tys.TypeArg]: + return [tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)] + + def cached_signature(self) -> tys.FunctionType | None: row: list[tys.Type] = [int_t(self.arg1), int_t(self.arg2)] - ext_op = ExtOp( - self.op_def, - tys.FunctionType.endo(row), - [tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)], - ) - return ext_op + return tys.FunctionType.endo(row) @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: - if custom.op_def != _DivMod: + if custom.op_def() != _DivMod: return None match custom.args: case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]: diff --git a/hugr-py/src/hugr/std/logic.py b/hugr-py/src/hugr/std/logic.py index 2884b7713..baef0d761 100644 --- a/hugr-py/src/hugr/std/logic.py +++ b/hugr-py/src/hugr/std/logic.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from hugr import ext, tys -from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp +from hugr.ops import AsExtOp, Command, DataflowOp if TYPE_CHECKING: from hugr.ops import ComWire @@ -25,8 +25,8 @@ @dataclass(frozen=True) class _NotOp(AsExtOp): - def to_ext(self) -> ExtOp: - return ExtOp(_NotDef) + def op_def(self) -> ext.OpDef: + return _NotDef def __call__(self, a: ComWire) -> Command: return DataflowOp.__call__(self, a) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 0386cb7ee..1a1b9b0a5 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -61,10 +61,10 @@ def _load_enum(enum_cls: type[E], custom: ExtOp) -> E | None: - ext = custom.op_def._extension + ext = custom._op_def._extension assert ext is not None - if ext.name == EXTENSION.name and custom.op_def.name in enum_cls.__members__: - return enum_cls(custom.op_def.name) + if ext.name == EXTENSION.name and custom._op_def.name in enum_cls.__members__: + return enum_cls(custom._op_def.name) return None @@ -79,8 +79,8 @@ class _Enum(Enum): def __call__(self, q: ComWire) -> Command: return DataflowOp.__call__(self, q) - def to_ext(self) -> ExtOp: - return ExtOp(EXTENSION.operations[self._enum.value]) + def op_def(self) -> ext.OpDef: + return EXTENSION.operations[self._enum.value] @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: @@ -97,8 +97,8 @@ class _Enum(Enum): _enum: _Enum - def to_ext(self) -> ExtOp: - return ExtOp(EXTENSION.operations[self._enum.value]) + def op_def(self) -> ext.OpDef: + return EXTENSION.operations[self._enum.value] @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: @@ -113,8 +113,8 @@ def __call__(self, q0: ComWire, q1: ComWire) -> Command: @dataclass(frozen=True) class MeasureDef(AsExtOp): - def to_ext(self) -> ExtOp: - return ExtOp(EXTENSION.operations["Measure"]) + def op_def(self) -> ext.OpDef: + return EXTENSION.operations["Measure"] def __call__(self, q: ComWire) -> Command: return super().__call__(q) @@ -125,8 +125,8 @@ def __call__(self, q: ComWire) -> Command: @dataclass(frozen=True) class RzDef(AsExtOp): - def to_ext(self) -> ExtOp: - return ExtOp(EXTENSION.operations["Rz"]) + def op_def(self) -> ext.OpDef: + return EXTENSION.operations["Rz"] def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: return super().__call__(q, fl_wire) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 9dfd774e7..10d621b72 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -30,18 +30,20 @@ class StringlyOp(AsExtOp): tag: str - def to_ext(self) -> ops.ExtOp: - return ops.ExtOp( - STRINGLY_EXT.get_op("StringlyOp"), - tys.FunctionType.endo([]), - [tys.StringArg(self.tag)], - ) + def op_def(self) -> ext.OpDef: + return STRINGLY_EXT.get_op("StringlyOp") + + def type_args(self) -> list[tys.TypeArg]: + return [tys.StringArg(self.tag)] + + def cached_signature(self) -> tys.FunctionType | None: + return tys.FunctionType.endo([]) @classmethod def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp": match custom: case ops.ExtOp( - op_def=_STRINGLY_DEF, + _op_def=_STRINGLY_DEF, args=[tys.StringArg(tag)], ): return cls(tag=tag) @@ -98,9 +100,8 @@ def registry() -> ext.ExtensionRegistry: [Not, DivMod, H, CX, Measure, Rz, StringlyOp("hello")], ) def test_custom(as_ext: AsExtOp, registry: ext.ExtensionRegistry): - ext_op = as_ext.to_ext() + ext_op = as_ext.ext_op - assert ext_op.to_ext() == ext_op assert ExtOp.from_ext(ext_op) == ext_op assert type(as_ext).from_ext(ext_op) == as_ext From 5c4d3f795543e23b2a89e6a870580edd89c9dc8f Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 16:19:24 +0100 Subject: [PATCH 13/26] inline some signatures --- hugr-py/tests/conftest.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 1a1b9b0a5..2c3a32423 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -20,16 +20,11 @@ from hugr.ops import ComWire EXTENSION = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0)) -_SINGLE_QUBIT = ext.OpDefSig(tys.FunctionType.endo([tys.Qubit])) -_TWO_QUBIT = ext.OpDefSig(tys.FunctionType.endo([tys.Qubit] * 2)) -_MEAS_SIG = ext.OpDefSig(tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])) -_RZ_SIG = ext.OpDefSig(tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])) - EXTENSION.add_op_def( ext.OpDef( name="H", description="Hadamard gate", - signature=_SINGLE_QUBIT, + signature=ext.OpDefSig(tys.FunctionType.endo([tys.Qubit])), ) ) @@ -37,7 +32,7 @@ ext.OpDef( name="CX", description="CNOT gate", - signature=_TWO_QUBIT, + signature=ext.OpDefSig(tys.FunctionType.endo([tys.Qubit] * 2)), ) ) @@ -45,7 +40,7 @@ ext.OpDef( name="Measure", description="Measurement operation", - signature=_MEAS_SIG, + signature=ext.OpDefSig(tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])), ) ) @@ -53,7 +48,7 @@ ext.OpDef( name="Rz", description="Rotation around the z-axis", - signature=_RZ_SIG, + signature=ext.OpDefSig(tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])), ) ) From 04231044531c26d95cb5277ddf227135cd17b17a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 17:20:01 +0100 Subject: [PATCH 14/26] decorator for easily registering operations --- hugr-py/src/hugr/ext.py | 48 ++++++++++++++++++++++++++++++++--- hugr-py/src/hugr/ops.py | 16 +++++++++++- hugr-py/src/hugr/std/int.py | 26 +++++++------------ hugr-py/src/hugr/std/logic.py | 19 +++++--------- hugr-py/tests/conftest.py | 35 ++++++++----------------- 5 files changed, 87 insertions(+), 57 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 84303e08c..6d1c93573 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -3,12 +3,12 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from semver import Version import hugr.serialization.extension as ext_s -from hugr import tys, val +from hugr import ops, tys, val from hugr.utils import ser_it __all__ = [ @@ -25,7 +25,7 @@ ] if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence from hugr.hugr import Hugr from hugr.tys import ExtensionId @@ -151,6 +151,9 @@ def to_serial(self) -> ext_s.ExtensionValue: ) +T = TypeVar("T", bound=ops.RegisteredOp) + + @dataclass class Extension: # noqa: D101 name: ExtensionId @@ -221,6 +224,45 @@ def get_value(self, name: str) -> ExtensionValue: except KeyError as e: raise self.ValueNotFound(name) from e + T = TypeVar("T", bound=ops.RegisteredOp) + + def register_op( + self, + name: str | None = None, + signature: OpDefSig | tys.PolyFuncType | tys.FunctionType | None = None, + description: str | None = None, + misc: dict[str, Any] | None = None, + lower_funcs: list[FixedHugr] | None = None, + ) -> Callable[[type[T]], type[T]]: + """Register a class as corresponding to an operation definition. + + If `name` is not provided, the class name is used. + If `signature` is not provided, a binary signature is assumed. + If `description` is not provided, the class docstring is used. + + See :class:`OpDef` for other parameters. + """ + if not isinstance(signature, OpDefSig): + binary = signature is None + signature = OpDefSig(signature, binary) + + def _inner(cls: type[T]) -> type[T]: + new_description = cls.__doc__ if description is None and cls.__doc__ else "" + new_name = cls.__name__ if name is None else name + op_def = self.add_op_def( + OpDef( + new_name, + signature, + new_description, + misc or {}, + lower_funcs or [], + ) + ) + cls.const_op_def = op_def + return cls + + return _inner + @dataclass class ExtensionRegistry: diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 366e81762..c1f3d4dae 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -9,13 +9,14 @@ from typing_extensions import Self import hugr.serialization.ops as sops -from hugr import ext, tys, val +from hugr import tys, val from hugr.node_port import Direction, InPort, Node, OutPort, PortOffset, Wire from hugr.utils import ser_it if TYPE_CHECKING: from collections.abc import Sequence + from hugr import ext from hugr.serialization.ops import BaseOp @@ -379,6 +380,19 @@ def outer_signature(self) -> tys.FunctionType: return poly_func.body +class RegisteredOp(AsExtOp): + """Base class for operations that are registered with an extension using + :meth:`Extension.register_op `. + """ + + #: Known operation definition. + const_op_def: ext.OpDef # must be initialised by register_op + + def op_def(self) -> ext.OpDef: + # override for AsExtOp.op_def + return self.const_op_def + + @dataclass() class MakeTuple(DataflowOp, _PartialOp): """Operation to create a tuple from a sequence of wires.""" diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 19b4b3132..77f4342c8 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -8,7 +8,7 @@ from typing_extensions import Self from hugr import ext, tys, val -from hugr.ops import AsExtOp, DataflowOp, ExtOp +from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp if TYPE_CHECKING: from hugr.ops import Command, ComWire @@ -67,27 +67,19 @@ def to_value(self) -> val.Extension: OPS_EXTENSION = ext.Extension("arithmetic.int", ext.Version(0, 1, 0)) -_DivMod = OPS_EXTENSION.add_op_def( - ext.OpDef( - name="idivmod_u", - description="Unsigned integer division and modulo.", - signature=ext.OpDefSig( - tys.FunctionType([_int_tv(0), _int_tv(1)], [_int_tv(0), _int_tv(1)]) - ), - ) -) - +@OPS_EXTENSION.register_op( + signature=ext.OpDefSig( + tys.FunctionType([_int_tv(0), _int_tv(1)], [_int_tv(0), _int_tv(1)]) + ), +) @dataclass(frozen=True) -class _DivModDef(AsExtOp): +class idivmod_u(RegisteredOp): """DivMod operation, has two inputs and two outputs.""" arg1: int = 5 arg2: int = 5 - def op_def(self) -> ext.OpDef: - return _DivMod - def type_args(self) -> list[tys.TypeArg]: return [tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)] @@ -97,7 +89,7 @@ def cached_signature(self) -> tys.FunctionType | None: @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: - if custom.op_def() != _DivMod: + if custom.op_def() != cls.const_op_def: return None match custom.args: case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]: @@ -111,4 +103,4 @@ def __call__(self, a: ComWire, b: ComWire) -> Command: #: DivMod operation. -DivMod = _DivModDef() +DivMod = idivmod_u() diff --git a/hugr-py/src/hugr/std/logic.py b/hugr-py/src/hugr/std/logic.py index baef0d761..d104a64b8 100644 --- a/hugr-py/src/hugr/std/logic.py +++ b/hugr-py/src/hugr/std/logic.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from hugr import ext, tys -from hugr.ops import AsExtOp, Command, DataflowOp +from hugr.ops import Command, DataflowOp, RegisteredOp if TYPE_CHECKING: from hugr.ops import ComWire @@ -14,19 +14,14 @@ EXTENSION = ext.Extension("logic", ext.Version(0, 1, 0)) -_NotDef = EXTENSION.add_op_def( - ext.OpDef( - name="Not", - description="Logical NOT operation.", - signature=ext.OpDefSig(tys.FunctionType.endo([tys.Bool])), - ) -) - +@EXTENSION.register_op( + name="Not", + signature=ext.OpDefSig(tys.FunctionType.endo([tys.Bool])), +) @dataclass(frozen=True) -class _NotOp(AsExtOp): - def op_def(self) -> ext.OpDef: - return _NotDef +class _NotOp(RegisteredOp): + """Logical NOT operation.""" def __call__(self, a: ComWire) -> Command: return DataflowOp.__call__(self, a) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 2c3a32423..d7e469a18 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -12,7 +12,7 @@ from hugr import ext, tys from hugr.hugr import Hugr -from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp +from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp, RegisteredOp from hugr.serialization.serial_hugr import SerialHugr from hugr.std.float import FLOAT_T @@ -36,21 +36,6 @@ ) ) -EXTENSION.add_op_def( - ext.OpDef( - name="Measure", - description="Measurement operation", - signature=ext.OpDefSig(tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])), - ) -) - -EXTENSION.add_op_def( - ext.OpDef( - name="Rz", - description="Rotation around the z-axis", - signature=ext.OpDefSig(tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])), - ) -) E = TypeVar("E", bound=Enum) @@ -106,11 +91,12 @@ def __call__(self, q0: ComWire, q1: ComWire) -> Command: CX = TwoQbGate(TwoQbGate._Enum.CX) +@EXTENSION.register_op( + "Measure", + signature=tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]), +) @dataclass(frozen=True) -class MeasureDef(AsExtOp): - def op_def(self) -> ext.OpDef: - return EXTENSION.operations["Measure"] - +class MeasureDef(RegisteredOp): def __call__(self, q: ComWire) -> Command: return super().__call__(q) @@ -118,11 +104,12 @@ def __call__(self, q: ComWire) -> Command: Measure = MeasureDef() +@EXTENSION.register_op( + "Rz", + signature=tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]), +) @dataclass(frozen=True) -class RzDef(AsExtOp): - def op_def(self) -> ext.OpDef: - return EXTENSION.operations["Rz"] - +class RzDef(RegisteredOp): def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: return super().__call__(q, fl_wire) From b44e87037f850508497a20f5e0ee469b75af8cf5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 17:22:04 +0100 Subject: [PATCH 15/26] fix merge --- hugr-py/src/hugr/std/int.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 77f4342c8..b2c32407e 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING from typing_extensions import Self From ea2d2252fdbeee786aa9f765cb4b3a46a0c3f10a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 12 Aug 2024 17:32:19 +0100 Subject: [PATCH 16/26] add missing resolve to sum --- hugr-py/src/hugr/tys.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 58b3fe337..a17020801 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -258,6 +258,10 @@ def __repr__(self) -> str: def type_bound(self) -> TypeBound: return TypeBound.join(*(t.type_bound() for r in self.variant_rows for t in r)) + def resolve(self, registry: ext.ExtensionRegistry) -> Sum: + """Resolve types in the sum type using the given registry.""" + return Sum([[ty.resolve(registry) for ty in row] for row in self.variant_rows]) + @dataclass() class UnitSum(Sum): @@ -279,6 +283,9 @@ def __repr__(self) -> str: return "Unit" return f"UnitSum({self.size})" + def resolve(self, registry: ext.ExtensionRegistry) -> UnitSum: + return self + @dataclass() class Tuple(Sum): From af4f7b821d1a75398633f324903993b6aca26fea Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 11:53:04 +0100 Subject: [PATCH 17/26] add docstrings to ext.py --- hugr-py/src/hugr/ext.py | 172 ++++++++++++++++++-- hugr-py/src/hugr/serialization/extension.py | 2 +- 2 files changed, 162 insertions(+), 12 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 6d1c93573..835e9a939 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -32,7 +32,15 @@ @dataclass -class ExplicitBound: # noqa: D101 +class ExplicitBound: + """An explicit type bound on an :class:`OpDef`. + + + Examples: + >>> ExplicitBound(tys.TypeBound.Copyable) + ExplicitBound(bound=) + """ + bound: tys.TypeBound def to_serial(self) -> ext_s.ExplicitBound: @@ -43,7 +51,16 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class FromParamsBound: # noqa: D101 +class FromParamsBound: + """Calculate the type bound of an :class:`OpDef` from the join of its parameters at + the given indices. + + + Examples: + >>> FromParamsBound(indices=[0, 1]) + FromParamsBound(indices=[0, 1]) + """ + indices: list[int] def to_serial(self) -> ext_s.FromParamsBound: @@ -54,10 +71,28 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class TypeDef: # noqa: D101 +class TypeDef: + """Type definition in an :class:`Extension`. + + + Examples: + >>> td = TypeDef( + ... name="MyType", + ... description="A type definition.", + ... params=[tys.TypeTypeParam(tys.Bool)], + ... bound=ExplicitBound(tys.TypeBound.Copyable), + ... ) + >>> td.name + 'MyType' + """ + + #: The name of the type. name: str + #: A description of the type. description: str + #: The type parameters of the type if polymorphic. params: list[tys.TypeParam] + #: The type bound of the type. bound: ExplicitBound | FromParamsBound _extension: Extension | None = field( default=None, init=False, repr=False, compare=False @@ -74,12 +109,21 @@ def to_serial(self) -> ext_s.TypeDef: ) def instantiate(self, args: Sequence[tys.TypeArg]) -> tys.ExtType: + """Instantiate a concrete type from this type definition. + + Args: + args: Type arguments corresponding to the type parameters of the definition. + """ return tys.ExtType(self, list(args)) @dataclass -class FixedHugr: # noqa: D101 +class FixedHugr: + """A HUGR used to define lowerings of operations in an :class:`OpDef`.""" + + #: Extensions used in the HUGR. extensions: tys.ExtensionSet + #: HUGR defining operation lowering. hugr: Hugr def to_serial(self) -> ext_s.FixedHugr: @@ -87,8 +131,12 @@ def to_serial(self) -> ext_s.FixedHugr: @dataclass -class OpDefSig: # noqa: D101 +class OpDefSig: + """Type signature of an :class:`OpDef`.""" + + #: The polymorphic function type of the operation (type scheme). poly_func: tys.PolyFuncType | None + #: If no static type scheme known, flag indidcates a computation of the signature. binary: bool def __init__( @@ -109,11 +157,18 @@ def __init__( @dataclass -class OpDef: # noqa: D101 +class OpDef: + """Operation definition in an :class:`Extension`.""" + + #: The name of the operation. name: str + #: The type signature of the operation. signature: OpDefSig + #: A description of the operation. description: str = "" + #: Miscellaneous information about the operation. misc: dict[str, Any] = field(default_factory=dict) + #: Lowerings of the operation. lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False) _extension: Extension | None = field( default=None, init=False, repr=False, compare=False @@ -135,9 +190,13 @@ def to_serial(self) -> ext_s.OpDef: @dataclass -class ExtensionValue: # noqa: D101 +class ExtensionValue: + """A value defined in an :class:`Extension`.""" + + #: The name of the value. name: str - typed_value: val.Value + #: Value payload. + val: val.Value _extension: Extension | None = field( default=None, init=False, repr=False, compare=False ) @@ -147,7 +206,7 @@ def to_serial(self) -> ext_s.ExtensionValue: return ext_s.ExtensionValue( extension=self._extension.name, name=self.name, - typed_value=self.typed_value.to_serial_root(), + typed_value=self.val.to_serial_root(), ) @@ -155,12 +214,20 @@ def to_serial(self) -> ext_s.ExtensionValue: @dataclass -class Extension: # noqa: D101 +class Extension: + """HUGR extension declaration.""" + + #: The name of the extension. name: ExtensionId + #: The version of the extension. version: Version + #: Extensions required by this extension, identified by name. extension_reqs: set[ExtensionId] = field(default_factory=set) + #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) + #: Values defined in the extension. values: dict[str, ExtensionValue] = field(default_factory=dict) + #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @dataclass @@ -180,16 +247,40 @@ def to_serial(self) -> ext_s.Extension: ) def add_op_def(self, op_def: OpDef) -> OpDef: + """Add an operation definition to the extension. + + Args: + op_def: The operation definition to add. + + Returns: + The added operation definition, now associated with the extension. + """ op_def._extension = self self.operations[op_def.name] = op_def return self.operations[op_def.name] def add_type_def(self, type_def: TypeDef) -> TypeDef: + """Add a type definition to the extension. + + Args: + type_def: The type definition to add. + + Returns: + The added type definition, now associated with the extension. + """ type_def._extension = self self.types[type_def.name] = type_def return self.types[type_def.name] def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: + """Add a value to the extension. + + Args: + extension_value: The value to add. + + Returns: + The added value, now associated with the extension. + """ extension_value._extension = self self.values[extension_value.name] = extension_value return self.values[extension_value.name] @@ -199,6 +290,17 @@ class OperationNotFound(NotFound): """Operation not found in extension.""" def get_op(self, name: str) -> OpDef: + """Retrieve an operation definition by name. + + Args: + name: The name of the operation. + + Returns: + The operation definition. + + Raises: + OperationNotFound: If the operation is not found in the extension. + """ try: return self.operations[name] except KeyError as e: @@ -209,6 +311,17 @@ class TypeNotFound(NotFound): """Type not found in extension.""" def get_type(self, name: str) -> TypeDef: + """Retrieve a type definition by name. + + Args: + name: The name of the type. + + Returns: + The type definition. + + Raises: + TypeNotFound: If the type is not found in the extension. + """ try: return self.types[name] except KeyError as e: @@ -266,17 +379,35 @@ def _inner(cls: type[T]) -> type[T]: @dataclass class ExtensionRegistry: + """Registry of extensions.""" + + #: Extensions in the registry, indexed by name. extensions: dict[ExtensionId, Extension] = field(default_factory=dict) @dataclass class ExtensionNotFound(Exception): + """Extension not found in registry.""" + extension_id: ExtensionId @dataclass class ExtensionExists(Exception): + """Extension already exists in registry.""" + extension_id: ExtensionId def add_extension(self, extension: Extension) -> Extension: + """Add an extension to the registry. + + Args: + extension: The extension to add. + + Returns: + The added extension. + + Raises: + ExtensionExists: If an extension with the same name already exists. + """ if extension.name in self.extensions: raise self.ExtensionExists(extension.name) # TODO version updates @@ -284,6 +415,17 @@ def add_extension(self, extension: Extension) -> Extension: return self.extensions[extension.name] def get_extension(self, name: ExtensionId) -> Extension: + """Retrieve an extension by name. + + Args: + name: The name of the extension. + + Returns: + Extension in the registry. + + Raises: + ExtensionNotFound: If the extension is not found in the registry. + """ try: return self.extensions[name] except KeyError as e: @@ -291,8 +433,16 @@ def get_extension(self, name: ExtensionId) -> Extension: @dataclass -class Package: # noqa: D101 +class Package: + """A package of HUGR modules and extensions. + + + The HUGRs may refer to the included extensions or those not included. + """ + + #: HUGR modules in the package. modules: list[Hugr] + #: Extensions included in the package. extensions: list[Extension] = field(default_factory=list) def to_serial(self) -> ext_s.Package: diff --git a/hugr-py/src/hugr/serialization/extension.py b/hugr-py/src/hugr/serialization/extension.py index ff90dfa80..25b0943d6 100644 --- a/hugr-py/src/hugr/serialization/extension.py +++ b/hugr-py/src/hugr/serialization/extension.py @@ -70,7 +70,7 @@ def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: return extension.add_extension_value( ext.ExtensionValue( name=self.name, - typed_value=self.typed_value.deserialize(), + val=self.typed_value.deserialize(), ) ) From da2b1c061e8aeea0e32d0f36be3d4e3a4d076d64 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 12:21:31 +0100 Subject: [PATCH 18/26] refactor: common up parent extension field --- hugr-py/src/hugr/ext.py | 57 ++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 835e9a939..a05005272 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -71,7 +71,40 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class TypeDef: +class NoParentExtension(Exception): + """Parent extension must be set.""" + + kind: str + + def __str__(self): + return f"{self.kind} does not belong to an extension." + + +@dataclass(init=False) +class ExtensionObject: + """An object associated with an :class:`Extension`.""" + + _extension: Extension | None = field( + default=None, init=False, repr=False, compare=False + ) + + def get_extension(self) -> Extension: + """Retrieve the extension associated with the object. + + Returns: + The extension associated with the object. + + Raises: + NoParentExtension: If the object is not associated with an extension. + """ + if self._extension is None: + msg = self.__class__.__name__ + raise NoParentExtension(msg) + return self._extension + + +@dataclass +class TypeDef(ExtensionObject): """Type definition in an :class:`Extension`. @@ -94,14 +127,10 @@ class TypeDef: params: list[tys.TypeParam] #: The type bound of the type. bound: ExplicitBound | FromParamsBound - _extension: Extension | None = field( - default=None, init=False, repr=False, compare=False - ) def to_serial(self) -> ext_s.TypeDef: - assert self._extension is not None, "Extension must be initialised." return ext_s.TypeDef( - extension=self._extension.name, + extension=self.get_extension().name, name=self.name, description=self.description, params=ser_it(self.params), @@ -157,7 +186,7 @@ def __init__( @dataclass -class OpDef: +class OpDef(ExtensionObject): """Operation definition in an :class:`Extension`.""" #: The name of the operation. @@ -170,14 +199,10 @@ class OpDef: misc: dict[str, Any] = field(default_factory=dict) #: Lowerings of the operation. lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False) - _extension: Extension | None = field( - default=None, init=False, repr=False, compare=False - ) def to_serial(self) -> ext_s.OpDef: - assert self._extension is not None, "Extension must be initialised." return ext_s.OpDef( - extension=self._extension.name, + extension=self.get_extension().name, name=self.name, description=self.description, misc=self.misc, @@ -190,21 +215,17 @@ def to_serial(self) -> ext_s.OpDef: @dataclass -class ExtensionValue: +class ExtensionValue(ExtensionObject): """A value defined in an :class:`Extension`.""" #: The name of the value. name: str #: Value payload. val: val.Value - _extension: Extension | None = field( - default=None, init=False, repr=False, compare=False - ) def to_serial(self) -> ext_s.ExtensionValue: - assert self._extension is not None, "Extension must be initialised." return ext_s.ExtensionValue( - extension=self._extension.name, + extension=self.get_extension().name, name=self.name, typed_value=self.val.to_serial_root(), ) From 599078e30d9cc970153ee6102518b0489837b320 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 12:21:50 +0100 Subject: [PATCH 19/26] remove incorrect todo --- hugr-py/src/hugr/val.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 886893d00..7f0a8f249 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -199,6 +199,3 @@ def type_(self) -> tys.Type: def to_serial(self) -> sops.ExtensionValue: return self.to_value().to_serial() - - -# TODO extension value that points to an extension. From 8ed9d6a77ab89a873e4f14fcebd85c59d10a02c7 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 12:30:28 +0100 Subject: [PATCH 20/26] avoid resolve errors when extension not found --- hugr-py/src/hugr/ops.py | 17 ++++++++++++++--- hugr-py/src/hugr/tys.py | 14 +++++++++++--- hugr-py/tests/test_custom.py | 5 +++++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index c1f3d4dae..af0ce359c 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -318,9 +318,20 @@ def check_id(self, extension: tys.ExtensionId, name: str) -> bool: """Check if the operation matches the given extension and operation name.""" return self.extension == extension and self.name == name - def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp: - """Resolve the custom operation to an :class:`ExtOp`.""" - op_def = registry.get_extension(self.extension).get_op(self.name) + def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp | Custom: + """Resolve the custom operation to an :class:`ExtOp`. + + If extension or operation is not, returns itself. + """ + from hugr.ext import ExtensionRegistry, Extension # noqa: I001 # no circular import + + try: + op_def = registry.get_extension(self.extension).get_op(self.name) + except ( + Extension.OperationNotFound, + ExtensionRegistry.ExtensionNotFound, + ): + return self signature = self.signature.resolve(registry) args = [arg.resolve(registry) for arg in self.args] diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index a17020801..63c5eda63 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -495,9 +495,17 @@ def to_serial(self) -> stys.Opaque: def type_bound(self) -> TypeBound: return self.bound - def resolve(self, registry: ext.ExtensionRegistry) -> ExtType: - """Resolve the opaque type to an :class:`ExtType` using the given registry.""" - type_def = registry.get_extension(self.extension).get_type(self.id) + def resolve(self, registry: ext.ExtensionRegistry) -> Type: + """Resolve the opaque type to an :class:`ExtType` using the given registry. + + If the extension or type is not found, return the original type. + """ + from hugr.ext import ExtensionRegistry, Extension # noqa: I001 # no circular import + + try: + type_def = registry.get_extension(self.extension).get_type(self.id) + except (ExtensionRegistry.ExtensionNotFound, Extension.TypeNotFound): + return self return ExtType(type_def, self.args) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 10d621b72..4a65bffda 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -64,6 +64,11 @@ def test_stringly_typed(): assert isinstance(new_h[n].op, Custom) registry = ext.ExtensionRegistry() + new_h.resolve_extensions(registry) + + # doesn't resolve without extension + assert isinstance(new_h[n].op, Custom) + registry.add_extension(STRINGLY_EXT) new_h.resolve_extensions(registry) From 5704eaecab0c7c2b9a266ae0afca791006477119 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 15:20:36 +0100 Subject: [PATCH 21/26] Apply suggestions from code review Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> --- hugr-py/src/hugr/ext.py | 2 +- hugr-py/src/hugr/ops.py | 4 ++-- hugr-py/src/hugr/std/int.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index a05005272..03e9e1544 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -165,7 +165,7 @@ class OpDefSig: #: The polymorphic function type of the operation (type scheme). poly_func: tys.PolyFuncType | None - #: If no static type scheme known, flag indidcates a computation of the signature. + #: If no static type scheme known, flag indicates a computation of the signature. binary: bool def __init__( diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index af0ce359c..042493a1f 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -289,7 +289,7 @@ def num_out(self) -> int: @dataclass(frozen=True, eq=False) class Custom(DataflowOp): - """Serialisable version of non-core dataflow operation defined in an extension.""" + """Serializable version of non-core dataflow operation defined in an extension.""" name: str signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @@ -321,7 +321,7 @@ def check_id(self, extension: tys.ExtensionId, name: str) -> bool: def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp | Custom: """Resolve the custom operation to an :class:`ExtOp`. - If extension or operation is not, returns itself. + If extension or operation is not found, returns itself. """ from hugr.ext import ExtensionRegistry, Extension # noqa: I001 # no circular import diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index b2c32407e..9367a154e 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -18,7 +18,7 @@ INT_T_DEF = TYPES_EXTENSION.add_type_def( ext.TypeDef( name="int", - description="Variable width integer.", + description="Variable-width integer.", bound=ext.ExplicitBound(tys.TypeBound.Copyable), params=[_INT_PARAM], ) From bdeb400d730f14526f91c9e5dff8118714290ebb Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 17:40:00 +0100 Subject: [PATCH 22/26] clarify int extension names --- hugr-py/src/hugr/std/int.py | 8 ++++---- hugr-py/tests/test_custom.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 9367a154e..fbf21958d 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -13,9 +13,9 @@ if TYPE_CHECKING: from hugr.ops import Command, ComWire -TYPES_EXTENSION = ext.Extension("arithmetic.int.types", ext.Version(0, 1, 0)) +INT_TYPES_EXTENSION = ext.Extension("arithmetic.int.types", ext.Version(0, 1, 0)) _INT_PARAM = tys.BoundedNatParam(7) -INT_T_DEF = TYPES_EXTENSION.add_type_def( +INT_T_DEF = INT_TYPES_EXTENSION.add_type_def( ext.TypeDef( name="int", description="Variable-width integer.", @@ -65,10 +65,10 @@ def to_value(self) -> val.Extension: return val.Extension("int", int_t(self.width), self.v) -OPS_EXTENSION = ext.Extension("arithmetic.int", ext.Version(0, 1, 0)) +INT_OPS_EXTENSION = ext.Extension("arithmetic.int", ext.Version(0, 1, 0)) -@OPS_EXTENSION.register_op( +@INT_OPS_EXTENSION.register_op( signature=ext.OpDefSig( tys.FunctionType([_int_tv(0), _int_tv(1)], [_int_tv(0), _int_tv(1)]) ), diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 4a65bffda..5fb6695e7 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -8,7 +8,7 @@ from hugr.node_port import Node from hugr.ops import AsExtOp, Custom, ExtOp from hugr.std.float import EXTENSION as FLOAT_EXT -from hugr.std.int import OPS_EXTENSION, TYPES_EXTENSION, DivMod +from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod from hugr.std.logic import EXTENSION as LOGIC_EXT from hugr.std.logic import Not @@ -93,8 +93,8 @@ def registry() -> ext.ExtensionRegistry: reg.add_extension(LOGIC_EXT) reg.add_extension(QUANTUM_EXT) reg.add_extension(STRINGLY_EXT) - reg.add_extension(TYPES_EXTENSION) - reg.add_extension(OPS_EXTENSION) + reg.add_extension(INT_TYPES_EXTENSION) + reg.add_extension(INT_OPS_EXTENSION) reg.add_extension(FLOAT_EXT) return reg From 9ce6eb6c5d7a18ce3c56a41b6799300d0f5107ea Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 13 Aug 2024 18:02:35 +0100 Subject: [PATCH 23/26] test type resolution --- hugr-py/tests/test_custom.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 5fb6695e7..e7ea5c2ae 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -8,7 +8,8 @@ from hugr.node_port import Node from hugr.ops import AsExtOp, Custom, ExtOp from hugr.std.float import EXTENSION as FLOAT_EXT -from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod +from hugr.std.float import FLOAT_T +from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod, int_t from hugr.std.logic import EXTENSION as LOGIC_EXT from hugr.std.logic import Not @@ -104,7 +105,7 @@ def registry() -> ext.ExtensionRegistry: "as_ext", [Not, DivMod, H, CX, Measure, Rz, StringlyOp("hello")], ) -def test_custom(as_ext: AsExtOp, registry: ext.ExtensionRegistry): +def test_custom_op(as_ext: AsExtOp, registry: ext.ExtensionRegistry): ext_op = as_ext.ext_op assert ExtOp.from_ext(ext_op) == ext_op @@ -133,3 +134,21 @@ def test_custom_bad_eq(): ) assert Not != bad_custom_args + + +@pytest.mark.parametrize( + "ext_t", + [FLOAT_T, int_t(5)], +) +def test_custom_type(ext_t: tys.ExtType, registry: ext.ExtensionRegistry): + opaque = ext_t.to_serial().deserialize() + assert isinstance(opaque, tys.Opaque) + assert opaque.resolve(registry) == ext_t + + assert opaque.resolve(ext.ExtensionRegistry()) == opaque + + f_t = tys.FunctionType.endo([ext_t]) + f_t_opaque = f_t.to_serial().deserialize() + assert isinstance(f_t_opaque.input[0], tys.Opaque) + + assert f_t_opaque.resolve(registry) == f_t From 00dea37945bb99d90b201522a5e07c9aa97313bd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 14 Aug 2024 09:32:17 +0100 Subject: [PATCH 24/26] add test for from params type def --- hugr-py/tests/test_custom.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index e7ea5c2ae..a4d22cdd2 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -136,9 +136,21 @@ def test_custom_bad_eq(): assert Not != bad_custom_args +_LIST_T = STRINGLY_EXT.add_type_def( + ext.TypeDef( + "List", + description="A list of elements.", + params=[tys.TypeTypeParam(tys.TypeBound.Any)], + bound=ext.FromParamsBound([0]), + ) +) + +_BOOL_LIST_T = _LIST_T.instantiate([tys.Bool.type_arg()]) + + @pytest.mark.parametrize( "ext_t", - [FLOAT_T, int_t(5)], + [FLOAT_T, int_t(5), _BOOL_LIST_T], ) def test_custom_type(ext_t: tys.ExtType, registry: ext.ExtensionRegistry): opaque = ext_t.to_serial().deserialize() From 712968b39c8e7ff06789e36cd64147bf0527019b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 14 Aug 2024 11:44:23 +0100 Subject: [PATCH 25/26] assert key matches object name --- hugr-py/src/hugr/serialization/extension.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/hugr-py/src/hugr/serialization/extension.py b/hugr-py/src/hugr/serialization/extension.py index 25b0943d6..7ce2ef481 100644 --- a/hugr-py/src/hugr/serialization/extension.py +++ b/hugr-py/src/hugr/serialization/extension.py @@ -133,13 +133,16 @@ def deserialize(self) -> ext.Extension: extension_reqs=self.extension_reqs, ) - for t in self.types.values(): + for k, t in self.types.items(): + assert k == t.name, "Type name must match key" e.add_type_def(t.deserialize(e)) - for o in self.operations.values(): + for k, o in self.operations.items(): + assert k == o.name, "Operation name must match key" e.add_op_def(o.deserialize(e)) - for v in self.values.values(): + for k, v in self.values.items(): + assert k == v.name, "Value name must match key" e.add_extension_value(v.deserialize(e)) return e From 1873cc2974b2ba41ef941545dad9ca96ba39c4cd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 14 Aug 2024 11:45:46 +0100 Subject: [PATCH 26/26] clarify/update TODOs --- hugr-py/src/hugr/ext.py | 1 - hugr-py/src/hugr/ops.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 03e9e1544..1fb2844d3 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -431,7 +431,6 @@ def add_extension(self, extension: Extension) -> Extension: """ if extension.name in self.extensions: raise self.ExtensionExists(extension.name) - # TODO version updates self.extensions[extension.name] = extension return self.extensions[extension.name] diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 042493a1f..3bf5eb41f 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -336,6 +336,7 @@ def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp | Custom: signature = self.signature.resolve(registry) args = [arg.resolve(registry) for arg in self.args] # TODO check signature matches op_def reported signature + # if/once op_def can compute signature from type scheme + args return ExtOp(op_def, signature, args)