Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(hugr-py)!: make serialization (module/methods) private #1477

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class BaseValue(ABC, ConfiguredBaseModel):
def deserialize(self) -> val.Value: ...


class ExtensionValue(BaseValue):
class CustomValue(BaseValue):
"""An extension constant value, that can check it is of a given [CustomType]."""

v: Literal["Extension"] = Field(default="Extension", title="ValueTag")
Expand All @@ -127,11 +127,11 @@ class FunctionValue(BaseValue):
hugr: Any

def deserialize(self) -> val.Value:
from hugr._serialization.serial_hugr import SerialHugr
from hugr.hugr import Hugr
from hugr.serialization.serial_hugr import SerialHugr

# pydantic stores the serialized dictionary because of the "Any" annotation
return val.Function(Hugr.from_serial(SerialHugr(**self.hugr)))
return val.Function(Hugr._from_serial(SerialHugr(**self.hugr)))


class TupleValue(BaseValue):
Expand Down Expand Up @@ -172,9 +172,7 @@ def deserialize(self) -> val.Value:
class Value(RootModel):
"""A constant Value."""

root: ExtensionValue | FunctionValue | TupleValue | SumValue = Field(
discriminator="v"
)
root: CustomValue | FunctionValue | TupleValue | SumValue = Field(discriminator="v")

model_config = ConfigDict(json_schema_extra={"required": ["v"]})

Expand Down Expand Up @@ -501,7 +499,7 @@ def deserialize(self) -> ops.CFG:
ControlFlowOp = Conditional | TailLoop | CFG


class Extension(DataflowOp):
class ExtensionOp(DataflowOp):
"""A user-defined operation that can be downcasted by the extensions that define
it.
"""
Expand Down Expand Up @@ -649,7 +647,7 @@ class OpType(RootModel):
| CallIndirect
| LoadConstant
| LoadFunction
| Extension
| ExtensionOp
| Noop
| MakeTuple
| UnpackTuple
Expand Down
46 changes: 23 additions & 23 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from semver import Version

import hugr.serialization.extension as ext_s
import hugr._serialization.extension as ext_s
from hugr import ops, tys, val
from hugr.utils import ser_it

Expand Down Expand Up @@ -43,11 +43,11 @@

bound: tys.TypeBound

def to_serial(self) -> ext_s.ExplicitBound:
def _to_serial(self) -> ext_s.ExplicitBound:
return ext_s.ExplicitBound(bound=self.bound)

def to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self.to_serial())
def _to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self._to_serial())

Check warning on line 50 in hugr-py/src/hugr/ext.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ext.py#L50

Added line #L50 was not covered by tests


@dataclass
Expand All @@ -63,11 +63,11 @@

indices: list[int]

def to_serial(self) -> ext_s.FromParamsBound:
def _to_serial(self) -> ext_s.FromParamsBound:
return ext_s.FromParamsBound(indices=self.indices)

def to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self.to_serial())
def _to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self._to_serial())

Check warning on line 70 in hugr-py/src/hugr/ext.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ext.py#L70

Added line #L70 was not covered by tests


@dataclass
Expand Down Expand Up @@ -128,13 +128,13 @@
#: The type bound of the type.
bound: ExplicitBound | FromParamsBound

def to_serial(self) -> ext_s.TypeDef:
def _to_serial(self) -> ext_s.TypeDef:
return ext_s.TypeDef(
extension=self.get_extension().name,
name=self.name,
description=self.description,
params=ser_it(self.params),
bound=ext_s.TypeDefBound(root=self.bound.to_serial()),
bound=ext_s.TypeDefBound(root=self.bound._to_serial()),
)

def instantiate(self, args: Sequence[tys.TypeArg]) -> tys.ExtType:
Expand All @@ -155,7 +155,7 @@
#: HUGR defining operation lowering.
hugr: Hugr

def to_serial(self) -> ext_s.FixedHugr:
def _to_serial(self) -> ext_s.FixedHugr:
return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr)


Expand Down Expand Up @@ -200,17 +200,17 @@
#: Lowerings of the operation.
lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False)

def to_serial(self) -> ext_s.OpDef:
def _to_serial(self) -> ext_s.OpDef:
return ext_s.OpDef(
extension=self.get_extension().name,
name=self.name,
description=self.description,
misc=self.misc,
signature=self.signature.poly_func.to_serial()
signature=self.signature.poly_func._to_serial()
if self.signature.poly_func
else None,
binary=self.signature.binary,
lower_funcs=[f.to_serial() for f in self.lower_funcs],
lower_funcs=[f._to_serial() for f in self.lower_funcs],
)


Expand All @@ -223,11 +223,11 @@
#: Value payload.
val: val.Value

def to_serial(self) -> ext_s.ExtensionValue:
def _to_serial(self) -> ext_s.ExtensionValue:
return ext_s.ExtensionValue(
extension=self.get_extension().name,
name=self.name,
typed_value=self.val.to_serial_root(),
typed_value=self.val._to_serial_root(),
)


Expand Down Expand Up @@ -257,14 +257,14 @@

name: str

def to_serial(self) -> ext_s.Extension:
def _to_serial(self) -> ext_s.Extension:
return ext_s.Extension(
name=self.name,
version=self.version, # type: ignore[arg-type]
extension_reqs=self.extension_reqs,
types={k: v.to_serial() for k, v in self.types.items()},
values={k: v.to_serial() for k, v in self.values.items()},
operations={k: v.to_serial() for k, v in self.operations.items()},
types={k: v._to_serial() for k, v in self.types.items()},
values={k: v._to_serial() for k, v in self.values.items()},
operations={k: v._to_serial() for k, v in self.operations.items()},
)

def add_op_def(self, op_def: OpDef) -> OpDef:
Expand Down Expand Up @@ -465,11 +465,11 @@
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def to_serial(self) -> ext_s.Package:
def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m.to_serial() for m in self.modules],
extensions=[e.to_serial() for e in self.extensions],
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self.to_serial().model_dump_json()
return self._to_serial().model_dump_json()
18 changes: 9 additions & 9 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
overload,
)

from hugr._serialization.ops import OpType as SerialOp
from hugr._serialization.serial_hugr import SerialHugr
from hugr.node_port import (
Direction,
InPort,
Expand All @@ -26,8 +28,6 @@
_SubPort,
)
from hugr.ops import Call, Const, Custom, DataflowOp, Module, Op
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.tys import Kind, Type, ValueKind
from hugr.utils import BiMap
from hugr.val import Value
Expand All @@ -54,8 +54,8 @@
children: list[Node] = field(default_factory=list, repr=False)
metadata: dict[str, Any] = field(default_factory=dict)

def to_serial(self, node: Node) -> SerialOp:
o = self.op.to_serial(self.parent if self.parent else node)
def _to_serial(self, node: Node) -> SerialOp:
o = self.op._to_serial(self.parent if self.parent else node)

return SerialOp(root=o) # type: ignore[arg-type]

Expand Down Expand Up @@ -601,7 +601,7 @@
)
return mapping

def to_serial(self) -> SerialHugr:
def _to_serial(self) -> SerialHugr:
"""Serialize the HUGR."""
node_it = (node for node in self._nodes if node is not None)

Expand All @@ -614,7 +614,7 @@

return SerialHugr(
# non contiguous indices will be erased
nodes=[node.to_serial(Node(idx, {})) for idx, node in enumerate(node_it)],
nodes=[node._to_serial(Node(idx, {})) for idx, node in enumerate(node_it)],
edges=[_serialize_link(link) for link in self._links.items()],
metadata=[node.metadata if node.metadata else None for node in node_it],
)
Expand Down Expand Up @@ -644,7 +644,7 @@
return self

@classmethod
def from_serial(cls, serial: SerialHugr) -> Hugr:
def _from_serial(cls, serial: SerialHugr) -> Hugr:
"""Load a HUGR from a serialized form."""
assert serial.nodes, "Empty Hugr is invalid"

Expand Down Expand Up @@ -685,14 +685,14 @@

def to_json(self) -> str:
"""Serialize the HUGR to a JSON string."""
return self.to_serial().to_json()
return self._to_serial().to_json()

@classmethod
def load_json(cls, json_str: str) -> Hugr:
"""Deserialize a JSON string into a HUGR."""
json_dict = json.loads(json_str)
serial = SerialHugr.load_json(json_dict)
return cls.from_serial(serial)
return cls._from_serial(serial)

Check warning on line 695 in hugr-py/src/hugr/hugr.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/hugr.py#L695

Added line #L695 was not covered by tests

def render_dot(self, palette: str | None = None) -> gv.Digraph:
"""Render the HUGR to a graphviz Digraph.
Expand Down
Loading
Loading