Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py)!: user facing Extension class #1413

Merged
merged 26 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
36a20e7
feat: basic user-facing dataclasses for extensions
ss2165 Aug 8, 2024
5ba757b
make extension optional in definition classes
ss2165 Aug 8, 2024
755a1a9
refactor(hugr-py): rename extension to ext
ss2165 Aug 9, 2024
59ef88d
separate resolved and unresolved types
ss2165 Aug 9, 2024
5afd4cc
point to extension in opdef and extval
ss2165 Aug 9, 2024
0b9b2a6
use extension objects in std extension defs
ss2165 Aug 9, 2024
73f9242
undo serialized extension rename
ss2165 Aug 9, 2024
844f86c
feat: resolve custom ops and types to extensions
ss2165 Aug 12, 2024
f888285
resolve extensions in a hugr
ss2165 Aug 12, 2024
186cc3d
replace `AsCustomOp` with `AsExtOp`
ss2165 Aug 12, 2024
d73faaa
update schema to include binary field in OpDef
ss2165 Aug 12, 2024
173897f
break up AsExtOp - mainly return opdef
ss2165 Aug 12, 2024
5c4d3f7
inline some signatures
ss2165 Aug 12, 2024
0423104
decorator for easily registering operations
ss2165 Aug 12, 2024
b44e870
fix merge
ss2165 Aug 12, 2024
ea2d225
add missing resolve to sum
ss2165 Aug 12, 2024
af4f7b8
add docstrings to ext.py
ss2165 Aug 13, 2024
da2b1c0
refactor: common up parent extension field
ss2165 Aug 13, 2024
599078e
remove incorrect todo
ss2165 Aug 13, 2024
8ed9d6a
avoid resolve errors when extension not found
ss2165 Aug 13, 2024
5704eae
Apply suggestions from code review
ss2165 Aug 13, 2024
bdeb400
clarify int extension names
ss2165 Aug 13, 2024
9ce6eb6
test type resolution
ss2165 Aug 13, 2024
00dea37
add test for from params type def
ss2165 Aug 14, 2024
712968b
assert key matches object name
ss2165 Aug 14, 2024
1873cc2
clarify/update TODOs
ss2165 Aug 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
472 changes: 472 additions & 0 deletions hugr-py/src/hugr/ext.py

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ToNode,
_SubPort,
)
from hugr.ops import Call, Const, DataflowOp, Module, Op
from hugr.ops import Call, Const, Custom, DataflowOp, Module, Op
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.tys import Kind, Type, ValueKind
Expand All @@ -34,6 +34,7 @@
from .exceptions import ParentBeforeChild

if TYPE_CHECKING:
from hugr import ext

Check warning on line 37 in hugr-py/src/hugr/hugr.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/hugr.py#L37

Added line #L37 was not covered by tests
from hugr.val import Value


Expand Down Expand Up @@ -598,6 +599,16 @@

return offset

def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:
"""Resolve extension types and operations in the HUGR by matching them to
extensions in the registry.
"""
for node in self:
op = self[node].op
if isinstance(op, Custom):
self[node].op = op.resolve(registry)
return self

@classmethod
def from_serial(cls, serial: SerialHugr) -> Hugr:
"""Load a HUGR from a serialized form."""
Expand Down
164 changes: 133 additions & 31 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from hugr import ext

Check warning on line 19 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L19

Added line #L19 was not covered by tests
from hugr.serialization.ops import BaseOp


Expand Down Expand Up @@ -201,79 +202,94 @@


@runtime_checkable
class AsCustomOp(DataflowOp, Protocol):
class AsExtOp(DataflowOp, Protocol):
"""Abstract interface that types can implement
to behave as a custom dataflow operation.
to behave as an extension dataflow operation.
"""

@dataclass(frozen=True)
class InvalidCustomOp(Exception):
"""Custom operation does not match the expected type."""
class InvalidExtOp(Exception):
"""Extension operation does not match the expected type."""

msg: str

@cached_property
def custom_op(self) -> Custom:
""":class:`Custom` operation that this type represents.
def ext_op(self) -> ExtOp:
""":class:`ExtOp` operation that this type represents.

Computed once using :meth:`to_custom` and cached - should be deterministic.
Computed once using :meth:`op_def` :meth:`type_args` and :meth:`type_args`.
Each of those methods should be deterministic.
"""
return self.to_custom()
return ExtOp(self.op_def(), self.cached_signature(), self.type_args())

def to_custom(self) -> Custom:
"""Convert this type to a :class:`Custom` operation.
def op_def(self) -> ext.OpDef:
"""The :class:`tys.OpDef` for this operation.


Used by :attr:`custom_op`, so must be deterministic.
Used by :attr:`ext_op`, so must be deterministic.
"""
... # pragma: no cover

def type_args(self) -> list[tys.TypeArg]:
"""Type arguments of the operation.

Used by :attr:`op_def`, so must be deterministic.
"""
return []

def cached_signature(self) -> tys.FunctionType | None:
"""Cached signature of the operation, if there is one.


Used by :attr:`op_def`, so must be deterministic.
"""
return None

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
"""Load from a :class:`Custom` operation.
def from_ext(cls, ext_op: ExtOp) -> Self | None:
"""Load from a :class:`ExtOp` operation.


By default assumes the type of `cls` is a singleton,
and compares the result of :meth:`to_custom` with the given `custom`.
and compares the result of :meth:`to_ext` with the given `ext_op`.

If successful, returns the singleton, else None.

Non-singleton types should override this method.

Raises:
InvalidCustomOp: If the given `custom` does not match the expected one for a
InvalidCustomOp: If the given `ext_op` does not match the expected one for a
given extension/operation name.
"""
default = cls()
if default.custom_op == custom:
if default.ext_op == ext_op:
return default
return None

def __eq__(self, other: object) -> bool:
if not isinstance(other, AsCustomOp):
if not isinstance(other, AsExtOp):
return NotImplemented
slf, other = self.custom_op, other.custom_op
slf, other = self.ext_op, other.ext_op
return (
slf.extension == other.extension
and slf.name == other.name
and slf.signature == other.signature
slf._op_def == other._op_def
and slf.outer_signature() == other.outer_signature()
and slf.args == other.args
)

def outer_signature(self) -> tys.FunctionType:
return self.custom_op.signature
return self.ext_op.outer_signature()

def to_serial(self, parent: Node) -> sops.CustomOp:
return self.custom_op.to_serial(parent)
return self.ext_op.to_serial(parent)

@property
def num_out(self) -> int:
return len(self.custom_op.signature.output)
return len(self.outer_signature().output)


@dataclass(frozen=True, eq=False)
class Custom(AsCustomOp):
"""A non-core dataflow operation defined in an extension."""
class Custom(DataflowOp):
"""Serializable version of non-core dataflow operation defined in an extension."""

name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
Expand All @@ -291,17 +307,103 @@
args=ser_it(self.args),
)

def to_custom(self) -> Custom:
return self
def outer_signature(self) -> tys.FunctionType:
return self.signature

Check warning on line 311 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L311

Added line #L311 was not covered by tests

@classmethod
def from_custom(cls, custom: Custom) -> Custom:
return custom
@property
def num_out(self) -> int:
return len(self.outer_signature().output)

Check warning on line 315 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L315

Added line #L315 was not covered by tests

def check_id(self, extension: tys.ExtensionId, name: str) -> bool:
"""Check if the operation matches the given extension and operation name."""
return self.extension == extension and self.name == name

def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp | Custom:
"""Resolve the custom operation to an :class:`ExtOp`.

If extension or operation is not found, returns itself.
"""
from hugr.ext import ExtensionRegistry, Extension # noqa: I001 # no circular import

try:
op_def = registry.get_extension(self.extension).get_op(self.name)
except (
Extension.OperationNotFound,
ExtensionRegistry.ExtensionNotFound,
):
return self

signature = self.signature.resolve(registry)
args = [arg.resolve(registry) for arg in self.args]
# TODO check signature matches op_def reported signature
# if/once op_def can compute signature from type scheme + args
return ExtOp(op_def, signature, args)


@dataclass(frozen=True, eq=False)
class ExtOp(AsExtOp):
"""A non-core dataflow operation defined in an extension."""

_op_def: ext.OpDef
signature: tys.FunctionType | None = None
args: list[tys.TypeArg] = field(default_factory=list)

def to_custom_op(self) -> Custom:
ext = self._op_def._extension
if self.signature is None:
poly_func = self._op_def.signature.poly_func
if poly_func is None or len(poly_func.params) > 0:
msg = "For polymorphic ops signature must be cached."
raise ValueError(msg)

Check warning on line 357 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L356-L357

Added lines #L356 - L357 were not covered by tests
sig = poly_func.body
else:
sig = self.signature

return Custom(
name=self._op_def.name,
signature=sig,
extension=ext.name if ext else "",
args=self.args,
)

def to_serial(self, parent: Node) -> sops.CustomOp:
return self.to_custom_op().to_serial(parent)

def op_def(self) -> ext.OpDef:
return self._op_def

def type_args(self) -> list[tys.TypeArg]:
return self.args

def cached_signature(self) -> tys.FunctionType | None:
return self.signature

@classmethod
def from_ext(cls, custom: ExtOp) -> ExtOp:
return custom

def outer_signature(self) -> tys.FunctionType:
if self.signature is not None:
return self.signature
poly_func = self._op_def.signature.poly_func
if poly_func is None:
msg = "Polymorphic signature must be cached."
raise ValueError(msg)

Check warning on line 391 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L390-L391

Added lines #L390 - L391 were not covered by tests
return poly_func.body


class RegisteredOp(AsExtOp):
"""Base class for operations that are registered with an extension using
:meth:`Extension.register_op <hugr.ext.Extension.register_op>`.
"""

#: Known operation definition.
const_op_def: ext.OpDef # must be initialised by register_op

def op_def(self) -> ext.OpDef:
# override for AsExtOp.op_def
return self.const_op_def


@dataclass()
class MakeTuple(DataflowOp, _PartialOp):
Expand Down
Loading
Loading