Skip to content

Commit

Permalink
add tests for tys strings
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Aug 29, 2024
1 parent 2408784 commit 4e484c7
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 28 deletions.
7 changes: 7 additions & 0 deletions hugr-py/src/hugr/_serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ def join(*bs: TypeBound) -> TypeBound:
res = b
return res

def __str__(self) -> str:
match self:
case TypeBound.Copyable:
return "Copyable"

Check warning on line 401 in hugr-py/src/hugr/_serialization/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_serialization/tys.py#L401

Added line #L401 was not covered by tests
case TypeBound.Any:
return "Any"


class Opaque(BaseType):
"""An opaque Type that can be downcasted by the extensions that define it."""
Expand Down
67 changes: 47 additions & 20 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable

import hugr._serialization.tys as stys
from hugr.utils import ser_it
from hugr.utils import comma_sep_repr, comma_sep_str, ser_it

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Sequence

from hugr import ext

Expand Down Expand Up @@ -98,6 +98,9 @@ class TypeTypeParam(TypeParam):
def _to_serial(self) -> stys.TypeTypeParam:
return stys.TypeTypeParam(b=self.bound)

def __str__(self) -> str:
return str(self.bound)


@dataclass(frozen=True)
class BoundedNatParam(TypeParam):
Expand All @@ -108,6 +111,11 @@ class BoundedNatParam(TypeParam):
def _to_serial(self) -> stys.BoundedNatParam:
return stys.BoundedNatParam(bound=self.upper_bound)

def __str__(self) -> str:
if self.upper_bound is None:
return "Nat"
return f"Nat({self.upper_bound})"


@dataclass(frozen=True)
class StringParam(TypeParam):
Expand All @@ -116,6 +124,9 @@ class StringParam(TypeParam):
def _to_serial(self) -> stys.StringParam:
return stys.StringParam()

def __str__(self) -> str:
return "String"


@dataclass(frozen=True)
class ListParam(TypeParam):
Expand All @@ -126,6 +137,9 @@ class ListParam(TypeParam):
def _to_serial(self) -> stys.ListParam:
return stys.ListParam(param=self.param._to_serial_root())

def __str__(self) -> str:
return f"[{self.param}]"


@dataclass(frozen=True)
class TupleParam(TypeParam):
Expand All @@ -136,6 +150,9 @@ class TupleParam(TypeParam):
def _to_serial(self) -> stys.TupleParam:
return stys.TupleParam(params=ser_it(self.params))

def __str__(self) -> str:
return f"({comma_sep_str(self.params)})"


@dataclass(frozen=True)
class ExtensionsParam(TypeParam):
Expand All @@ -144,6 +161,9 @@ class ExtensionsParam(TypeParam):
def _to_serial(self) -> stys.ExtensionsParam:
return stys.ExtensionsParam()

def __str__(self) -> str:
return "Extensions"


# ------------------------------------------
# --------------- TypeArg ------------------
Expand All @@ -163,7 +183,7 @@ def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg:
return TypeTypeArg(self.ty.resolve(registry))

def __str__(self) -> str:
return str(self.ty)
return f"Type({self.ty!s})"


@dataclass(frozen=True)
Expand All @@ -189,7 +209,7 @@ def _to_serial(self) -> stys.StringArg:
return stys.StringArg(arg=self.value)

def __str__(self) -> str:
return f'{"self.value"}'
return f'"{self.value}"'


@dataclass(frozen=True)
Expand All @@ -205,7 +225,7 @@ def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg:
return SequenceArg([arg.resolve(registry) for arg in self.elems])

def __str__(self) -> str:
return f"({', '.join(str(arg) for arg in self.elems)})"
return f"({comma_sep_str(self.elems)})"


@dataclass(frozen=True)
Expand All @@ -218,7 +238,7 @@ def _to_serial(self) -> stys.ExtensionsArg:
return stys.ExtensionsArg(es=self.extensions)

def __str__(self) -> str:
return str(self.extensions)
return f"Extensions({comma_sep_str(self.extensions)})"


@dataclass(frozen=True)
Expand All @@ -232,7 +252,7 @@ def _to_serial(self) -> stys.VariableArg:
return stys.VariableArg(idx=self.idx, cached_decl=self.param._to_serial_root())

def __str__(self) -> str:
return f"VariableArg({self.idx})"
return f"${self.idx}"


# ----------------------------------------------
Expand All @@ -254,7 +274,7 @@ def type_bound(self) -> TypeBound:
return self.ty.type_bound()

def __repr__(self) -> str:
return f"Array({self.ty}, {self.size})"
return f"Array<{self.ty}, {self.size}>"


@dataclass()
Expand Down Expand Up @@ -338,7 +358,7 @@ def __init__(self, *tys: Type):
self.variant_rows = [list(tys), []]

def __repr__(self) -> str:
return f"Option({', '.join(map(repr, self.variant_rows[0]))})"
return f"Option({comma_sep_repr(self.variant_rows[0])})"


@dataclass(eq=False)
Expand Down Expand Up @@ -379,7 +399,7 @@ def type_bound(self) -> TypeBound:
return self.bound

def __repr__(self) -> str:
return f"Variable({self.idx})"
return f"${self.idx}"


@dataclass(frozen=True)
Expand All @@ -396,7 +416,7 @@ def type_bound(self) -> TypeBound:
return self.bound

def __repr__(self) -> str:
return f"RowVariable({self.idx})"
return f"${self.idx}"


@dataclass(frozen=True)
Expand Down Expand Up @@ -427,7 +447,7 @@ def type_bound(self) -> TypeBound:
return self.bound

def __repr__(self) -> str:
return f"Alias({self.name})"
return self.name


@dataclass(frozen=True)
Expand Down Expand Up @@ -495,8 +515,7 @@ def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType:
)

def __str__(self) -> str:
# [Qubit] -> [Bool]
return f"{self.input} -> {self.output})"
return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}"


@dataclass(frozen=True)
Expand Down Expand Up @@ -525,8 +544,7 @@ def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType:
)

def __str__(self) -> str:
# ∀[a]. [list<a>] -> [Bool]
return f"∀{self.params}. {self.body!s})"
return f"∀ {comma_sep_str(self.params)}. {self.body!s}"


@dataclass
Expand All @@ -551,17 +569,26 @@ def type_bound(self) -> TypeBound:
return TypeBound.join(*bounds)

def _to_serial(self) -> stys.Opaque:
return self._to_opaque()._to_serial()

def _to_opaque(self) -> Opaque:
assert self.type_def._extension is not None, "Extension must be initialised."

return stys.Opaque(
return Opaque(
extension=self.type_def._extension.name,
id=self.type_def.name,
args=[arg._to_serial_root() for arg in self.args],
args=self.args,
bound=self.type_bound(),
)

def __str__(self) -> str:
return f"{self.type_def.name}<{', '.join(str(arg) for arg in self.args)}>"
return _type_str(self.type_def.name, self.args)


def _type_str(name: str, args: Sequence[TypeArg]) -> str:
if len(args) == 0:
return name
return f"{name}<{comma_sep_str(args)}>"


@dataclass
Expand Down Expand Up @@ -599,7 +626,7 @@ def resolve(self, registry: ext.ExtensionRegistry) -> Type:
return ExtType(type_def, self.args)

def __str__(self) -> str:
return f"{self.id}<{', '.join(str(arg) for arg in self.args)}>"
return _type_str(self.id, self.args)


@dataclass
Expand Down
13 changes: 13 additions & 0 deletions hugr-py/src/hugr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,16 @@ def ser_it(it: Iterable[SerCollection[S]]) -> list[S]:
def deser_it(it: Iterable[DeserCollection[S]]) -> list[S]:
"""Deserialize an iterable of deserializable objects."""
return [v.deserialize() for v in it]


T = TypeVar("T")


def comma_sep_str(items: Iterable[T]) -> str:
"""Join items with commas and str."""
return ", ".join(map(str, items))


def comma_sep_repr(items: Iterable[T]) -> str:
"""Join items with commas and repr."""
return ", ".join(map(repr, items))
12 changes: 5 additions & 7 deletions hugr-py/src/hugr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import hugr._serialization.ops as sops
import hugr._serialization.tys as stys
from hugr import tys
from hugr.utils import ser_it
from hugr.utils import comma_sep_repr, comma_sep_str, ser_it

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -153,7 +153,7 @@ def _to_serial(self) -> sops.TupleValue: # type: ignore[override]
)

def __repr__(self) -> str:
return f"Tuple({', '.join(map(repr, self.vals))})"
return f"Tuple({comma_sep_repr(self.vals)})"


@dataclass(eq=False)
Expand All @@ -176,7 +176,7 @@ def __init__(self, *vals: Value):
)

def __repr__(self) -> str:
return f"Some({', '.join(map(repr, self.vals))})"
return f"Some({comma_sep_repr(self.vals)})"


@dataclass(eq=False)
Expand Down Expand Up @@ -229,8 +229,7 @@ def __repr__(self) -> str:
return f"Left(vals={self.vals}, right_typ={list(right_typ)})"

def __str__(self) -> str:
vals_str = ", ".join(map(str, self.vals))
return f"Left({vals_str})"
return f"Left({comma_sep_str(self.vals)})"


@dataclass(eq=False)
Expand Down Expand Up @@ -262,8 +261,7 @@ def __repr__(self) -> str:
return f"Right(left_typ={list(left_typ)}, vals={self.vals})"

def __str__(self) -> str:
vals_str = ", ".join(map(str, self.vals))
return f"Right({vals_str})"
return f"Right({comma_sep_str(self.vals)})"


@dataclass
Expand Down
Loading

0 comments on commit 4e484c7

Please sign in to comment.