diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 6d1c93573..835e9a939 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -32,7 +32,15 @@ @dataclass -class ExplicitBound: # noqa: D101 +class ExplicitBound: + """An explicit type bound on an :class:`OpDef`. + + + Examples: + >>> ExplicitBound(tys.TypeBound.Copyable) + ExplicitBound(bound=) + """ + bound: tys.TypeBound def to_serial(self) -> ext_s.ExplicitBound: @@ -43,7 +51,16 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class FromParamsBound: # noqa: D101 +class FromParamsBound: + """Calculate the type bound of an :class:`OpDef` from the join of its parameters at + the given indices. + + + Examples: + >>> FromParamsBound(indices=[0, 1]) + FromParamsBound(indices=[0, 1]) + """ + indices: list[int] def to_serial(self) -> ext_s.FromParamsBound: @@ -54,10 +71,28 @@ def to_serial_root(self) -> ext_s.TypeDefBound: @dataclass -class TypeDef: # noqa: D101 +class TypeDef: + """Type definition in an :class:`Extension`. + + + Examples: + >>> td = TypeDef( + ... name="MyType", + ... description="A type definition.", + ... params=[tys.TypeTypeParam(tys.Bool)], + ... bound=ExplicitBound(tys.TypeBound.Copyable), + ... ) + >>> td.name + 'MyType' + """ + + #: The name of the type. name: str + #: A description of the type. description: str + #: The type parameters of the type if polymorphic. params: list[tys.TypeParam] + #: The type bound of the type. bound: ExplicitBound | FromParamsBound _extension: Extension | None = field( default=None, init=False, repr=False, compare=False @@ -74,12 +109,21 @@ def to_serial(self) -> ext_s.TypeDef: ) def instantiate(self, args: Sequence[tys.TypeArg]) -> tys.ExtType: + """Instantiate a concrete type from this type definition. + + Args: + args: Type arguments corresponding to the type parameters of the definition. + """ return tys.ExtType(self, list(args)) @dataclass -class FixedHugr: # noqa: D101 +class FixedHugr: + """A HUGR used to define lowerings of operations in an :class:`OpDef`.""" + + #: Extensions used in the HUGR. extensions: tys.ExtensionSet + #: HUGR defining operation lowering. hugr: Hugr def to_serial(self) -> ext_s.FixedHugr: @@ -87,8 +131,12 @@ def to_serial(self) -> ext_s.FixedHugr: @dataclass -class OpDefSig: # noqa: D101 +class OpDefSig: + """Type signature of an :class:`OpDef`.""" + + #: The polymorphic function type of the operation (type scheme). poly_func: tys.PolyFuncType | None + #: If no static type scheme known, flag indidcates a computation of the signature. binary: bool def __init__( @@ -109,11 +157,18 @@ def __init__( @dataclass -class OpDef: # noqa: D101 +class OpDef: + """Operation definition in an :class:`Extension`.""" + + #: The name of the operation. name: str + #: The type signature of the operation. signature: OpDefSig + #: A description of the operation. description: str = "" + #: Miscellaneous information about the operation. misc: dict[str, Any] = field(default_factory=dict) + #: Lowerings of the operation. lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False) _extension: Extension | None = field( default=None, init=False, repr=False, compare=False @@ -135,9 +190,13 @@ def to_serial(self) -> ext_s.OpDef: @dataclass -class ExtensionValue: # noqa: D101 +class ExtensionValue: + """A value defined in an :class:`Extension`.""" + + #: The name of the value. name: str - typed_value: val.Value + #: Value payload. + val: val.Value _extension: Extension | None = field( default=None, init=False, repr=False, compare=False ) @@ -147,7 +206,7 @@ def to_serial(self) -> ext_s.ExtensionValue: return ext_s.ExtensionValue( extension=self._extension.name, name=self.name, - typed_value=self.typed_value.to_serial_root(), + typed_value=self.val.to_serial_root(), ) @@ -155,12 +214,20 @@ def to_serial(self) -> ext_s.ExtensionValue: @dataclass -class Extension: # noqa: D101 +class Extension: + """HUGR extension declaration.""" + + #: The name of the extension. name: ExtensionId + #: The version of the extension. version: Version + #: Extensions required by this extension, identified by name. extension_reqs: set[ExtensionId] = field(default_factory=set) + #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) + #: Values defined in the extension. values: dict[str, ExtensionValue] = field(default_factory=dict) + #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @dataclass @@ -180,16 +247,40 @@ def to_serial(self) -> ext_s.Extension: ) def add_op_def(self, op_def: OpDef) -> OpDef: + """Add an operation definition to the extension. + + Args: + op_def: The operation definition to add. + + Returns: + The added operation definition, now associated with the extension. + """ op_def._extension = self self.operations[op_def.name] = op_def return self.operations[op_def.name] def add_type_def(self, type_def: TypeDef) -> TypeDef: + """Add a type definition to the extension. + + Args: + type_def: The type definition to add. + + Returns: + The added type definition, now associated with the extension. + """ type_def._extension = self self.types[type_def.name] = type_def return self.types[type_def.name] def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: + """Add a value to the extension. + + Args: + extension_value: The value to add. + + Returns: + The added value, now associated with the extension. + """ extension_value._extension = self self.values[extension_value.name] = extension_value return self.values[extension_value.name] @@ -199,6 +290,17 @@ class OperationNotFound(NotFound): """Operation not found in extension.""" def get_op(self, name: str) -> OpDef: + """Retrieve an operation definition by name. + + Args: + name: The name of the operation. + + Returns: + The operation definition. + + Raises: + OperationNotFound: If the operation is not found in the extension. + """ try: return self.operations[name] except KeyError as e: @@ -209,6 +311,17 @@ class TypeNotFound(NotFound): """Type not found in extension.""" def get_type(self, name: str) -> TypeDef: + """Retrieve a type definition by name. + + Args: + name: The name of the type. + + Returns: + The type definition. + + Raises: + TypeNotFound: If the type is not found in the extension. + """ try: return self.types[name] except KeyError as e: @@ -266,17 +379,35 @@ def _inner(cls: type[T]) -> type[T]: @dataclass class ExtensionRegistry: + """Registry of extensions.""" + + #: Extensions in the registry, indexed by name. extensions: dict[ExtensionId, Extension] = field(default_factory=dict) @dataclass class ExtensionNotFound(Exception): + """Extension not found in registry.""" + extension_id: ExtensionId @dataclass class ExtensionExists(Exception): + """Extension already exists in registry.""" + extension_id: ExtensionId def add_extension(self, extension: Extension) -> Extension: + """Add an extension to the registry. + + Args: + extension: The extension to add. + + Returns: + The added extension. + + Raises: + ExtensionExists: If an extension with the same name already exists. + """ if extension.name in self.extensions: raise self.ExtensionExists(extension.name) # TODO version updates @@ -284,6 +415,17 @@ def add_extension(self, extension: Extension) -> Extension: return self.extensions[extension.name] def get_extension(self, name: ExtensionId) -> Extension: + """Retrieve an extension by name. + + Args: + name: The name of the extension. + + Returns: + Extension in the registry. + + Raises: + ExtensionNotFound: If the extension is not found in the registry. + """ try: return self.extensions[name] except KeyError as e: @@ -291,8 +433,16 @@ def get_extension(self, name: ExtensionId) -> Extension: @dataclass -class Package: # noqa: D101 +class Package: + """A package of HUGR modules and extensions. + + + The HUGRs may refer to the included extensions or those not included. + """ + + #: HUGR modules in the package. modules: list[Hugr] + #: Extensions included in the package. extensions: list[Extension] = field(default_factory=list) def to_serial(self) -> ext_s.Package: diff --git a/hugr-py/src/hugr/serialization/extension.py b/hugr-py/src/hugr/serialization/extension.py index ff90dfa80..25b0943d6 100644 --- a/hugr-py/src/hugr/serialization/extension.py +++ b/hugr-py/src/hugr/serialization/extension.py @@ -70,7 +70,7 @@ def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: return extension.add_extension_value( ext.ExtensionValue( name=self.name, - typed_value=self.typed_value.deserialize(), + val=self.typed_value.deserialize(), ) )