Skip to content

Commit

Permalink
add docstrings to ext.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Aug 13, 2024
1 parent ea2d225 commit af4f7b8
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 12 deletions.
172 changes: 161 additions & 11 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@


@dataclass
class ExplicitBound: # noqa: D101
class ExplicitBound:
"""An explicit type bound on an :class:`OpDef`.
Examples:
>>> ExplicitBound(tys.TypeBound.Copyable)
ExplicitBound(bound=<TypeBound.Copyable: 'C'>)
"""

bound: tys.TypeBound

def to_serial(self) -> ext_s.ExplicitBound:
Expand All @@ -43,7 +51,16 @@ def to_serial_root(self) -> ext_s.TypeDefBound:


@dataclass
class FromParamsBound: # noqa: D101
class FromParamsBound:
"""Calculate the type bound of an :class:`OpDef` from the join of its parameters at
the given indices.
Examples:
>>> FromParamsBound(indices=[0, 1])
FromParamsBound(indices=[0, 1])
"""

indices: list[int]

def to_serial(self) -> ext_s.FromParamsBound:
Expand All @@ -54,10 +71,28 @@ def to_serial_root(self) -> ext_s.TypeDefBound:


@dataclass
class TypeDef: # noqa: D101
class TypeDef:
"""Type definition in an :class:`Extension`.
Examples:
>>> td = TypeDef(
... name="MyType",
... description="A type definition.",
... params=[tys.TypeTypeParam(tys.Bool)],
... bound=ExplicitBound(tys.TypeBound.Copyable),
... )
>>> td.name
'MyType'
"""

#: The name of the type.
name: str
#: A description of the type.
description: str
#: The type parameters of the type if polymorphic.
params: list[tys.TypeParam]
#: The type bound of the type.
bound: ExplicitBound | FromParamsBound
_extension: Extension | None = field(
default=None, init=False, repr=False, compare=False
Expand All @@ -74,21 +109,34 @@ def to_serial(self) -> ext_s.TypeDef:
)

def instantiate(self, args: Sequence[tys.TypeArg]) -> tys.ExtType:
"""Instantiate a concrete type from this type definition.
Args:
args: Type arguments corresponding to the type parameters of the definition.
"""
return tys.ExtType(self, list(args))


@dataclass
class FixedHugr: # noqa: D101
class FixedHugr:
"""A HUGR used to define lowerings of operations in an :class:`OpDef`."""

#: Extensions used in the HUGR.
extensions: tys.ExtensionSet
#: HUGR defining operation lowering.
hugr: Hugr

def to_serial(self) -> ext_s.FixedHugr:
return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr)

Check warning on line 130 in hugr-py/src/hugr/ext.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ext.py#L130

Added line #L130 was not covered by tests


@dataclass
class OpDefSig: # noqa: D101
class OpDefSig:
"""Type signature of an :class:`OpDef`."""

#: The polymorphic function type of the operation (type scheme).
poly_func: tys.PolyFuncType | None
#: If no static type scheme known, flag indidcates a computation of the signature.
binary: bool

def __init__(
Expand All @@ -109,11 +157,18 @@ def __init__(


@dataclass
class OpDef: # noqa: D101
class OpDef:
"""Operation definition in an :class:`Extension`."""

#: The name of the operation.
name: str
#: The type signature of the operation.
signature: OpDefSig
#: A description of the operation.
description: str = ""
#: Miscellaneous information about the operation.
misc: dict[str, Any] = field(default_factory=dict)
#: Lowerings of the operation.
lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False)
_extension: Extension | None = field(
default=None, init=False, repr=False, compare=False
Expand All @@ -135,9 +190,13 @@ def to_serial(self) -> ext_s.OpDef:


@dataclass
class ExtensionValue: # noqa: D101
class ExtensionValue:
"""A value defined in an :class:`Extension`."""

#: The name of the value.
name: str
typed_value: val.Value
#: Value payload.
val: val.Value
_extension: Extension | None = field(
default=None, init=False, repr=False, compare=False
)
Expand All @@ -147,20 +206,28 @@ def to_serial(self) -> ext_s.ExtensionValue:
return ext_s.ExtensionValue(

Check warning on line 206 in hugr-py/src/hugr/ext.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ext.py#L205-L206

Added lines #L205 - L206 were not covered by tests
extension=self._extension.name,
name=self.name,
typed_value=self.typed_value.to_serial_root(),
typed_value=self.val.to_serial_root(),
)


T = TypeVar("T", bound=ops.RegisteredOp)


@dataclass
class Extension: # noqa: D101
class Extension:
"""HUGR extension declaration."""

#: The name of the extension.
name: ExtensionId
#: The version of the extension.
version: Version
#: Extensions required by this extension, identified by name.
extension_reqs: set[ExtensionId] = field(default_factory=set)
#: Type definitions in the extension.
types: dict[str, TypeDef] = field(default_factory=dict)
#: Values defined in the extension.
values: dict[str, ExtensionValue] = field(default_factory=dict)
#: Operation definitions in the extension.
operations: dict[str, OpDef] = field(default_factory=dict)

@dataclass
Expand All @@ -180,16 +247,40 @@ def to_serial(self) -> ext_s.Extension:
)

def add_op_def(self, op_def: OpDef) -> OpDef:
"""Add an operation definition to the extension.
Args:
op_def: The operation definition to add.
Returns:
The added operation definition, now associated with the extension.
"""
op_def._extension = self
self.operations[op_def.name] = op_def
return self.operations[op_def.name]

def add_type_def(self, type_def: TypeDef) -> TypeDef:
"""Add a type definition to the extension.
Args:
type_def: The type definition to add.
Returns:
The added type definition, now associated with the extension.
"""
type_def._extension = self
self.types[type_def.name] = type_def
return self.types[type_def.name]

def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue:
"""Add a value to the extension.
Args:
extension_value: The value to add.
Returns:
The added value, now associated with the extension.
"""
extension_value._extension = self
self.values[extension_value.name] = extension_value
return self.values[extension_value.name]

Check warning on line 286 in hugr-py/src/hugr/ext.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ext.py#L284-L286

Added lines #L284 - L286 were not covered by tests
Expand All @@ -199,6 +290,17 @@ class OperationNotFound(NotFound):
"""Operation not found in extension."""

def get_op(self, name: str) -> OpDef:
"""Retrieve an operation definition by name.
Args:
name: The name of the operation.
Returns:
The operation definition.
Raises:
OperationNotFound: If the operation is not found in the extension.
"""
try:
return self.operations[name]
except KeyError as e:
Expand All @@ -209,6 +311,17 @@ class TypeNotFound(NotFound):
"""Type not found in extension."""

def get_type(self, name: str) -> TypeDef:
"""Retrieve a type definition by name.
Args:
name: The name of the type.
Returns:
The type definition.
Raises:
TypeNotFound: If the type is not found in the extension.
"""
try:
return self.types[name]
except KeyError as e:
Expand Down Expand Up @@ -266,33 +379,70 @@ def _inner(cls: type[T]) -> type[T]:

@dataclass
class ExtensionRegistry:
"""Registry of extensions."""

#: Extensions in the registry, indexed by name.
extensions: dict[ExtensionId, Extension] = field(default_factory=dict)

@dataclass
class ExtensionNotFound(Exception):
"""Extension not found in registry."""

extension_id: ExtensionId

@dataclass
class ExtensionExists(Exception):
"""Extension already exists in registry."""

extension_id: ExtensionId

def add_extension(self, extension: Extension) -> Extension:
"""Add an extension to the registry.
Args:
extension: The extension to add.
Returns:
The added extension.
Raises:
ExtensionExists: If an extension with the same name already exists.
"""
if extension.name in self.extensions:
raise self.ExtensionExists(extension.name)
# TODO version updates
self.extensions[extension.name] = extension
return self.extensions[extension.name]

def get_extension(self, name: ExtensionId) -> Extension:
"""Retrieve an extension by name.
Args:
name: The name of the extension.
Returns:
Extension in the registry.
Raises:
ExtensionNotFound: If the extension is not found in the registry.
"""
try:
return self.extensions[name]
except KeyError as e:
raise self.ExtensionNotFound(name) from e


@dataclass
class Package: # noqa: D101
class Package:
"""A package of HUGR modules and extensions.
The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def to_serial(self) -> ext_s.Package:
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/serialization/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue:
return extension.add_extension_value(

Check warning on line 70 in hugr-py/src/hugr/serialization/extension.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/extension.py#L70

Added line #L70 was not covered by tests
ext.ExtensionValue(
name=self.name,
typed_value=self.typed_value.deserialize(),
val=self.typed_value.deserialize(),
)
)

Expand Down

0 comments on commit af4f7b8

Please sign in to comment.