Skip to content

Commit

Permalink
ensure root model is contained in model_rebuild
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 1, 2024
1 parent 92a655e commit 339b522
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 30 deletions.
7 changes: 1 addition & 6 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
SumType,
TypeBound,
ConfiguredBaseModel,
_model_rebuild as tys_model_rebuild,
classes as tys_classes,
)

Expand Down Expand Up @@ -553,8 +552,4 @@ class OpDef(BaseOp, populate_by_name=True):
classes = inspect.getmembers(
sys.modules[__name__],
lambda member: inspect.isclass(member) and member.__module__ == __name__,
)


def model_rebuild(config: ConfigDict = ConfigDict(), **kwargs):
return tys_model_rebuild(classes + tys_classes, config, **kwargs)
) + tys_classes
10 changes: 8 additions & 2 deletions hugr-py/src/hugr/serialization/serial_hugr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict

from .ops import NodeID, OpType
from .ops import NodeID, OpType, classes
from .tys import model_rebuild
import hugr

Port = tuple[NodeID, int | None] # (node, offset)
Expand Down Expand Up @@ -34,6 +35,11 @@ def get_version(cls) -> str:
"""Return the version of the schema."""
return cls(nodes=[], edges=[]).version

@classmethod
def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
model_rebuild([(cls.__name__, cls)] + classes, config=config, **kwargs)


class Config:
title = "Hugr"
json_schema_extra = {
Expand Down
9 changes: 7 additions & 2 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import ConfigDict
from typing import Literal
from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel
from .ops import Value, OpType
from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild
from .ops import Value, OpType, classes


class TestingHugr(ConfiguredBaseModel):
Expand All @@ -19,5 +20,9 @@ def get_version(cls) -> str:
"""Return the version of the schema."""
return cls().version

@classmethod
def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
model_rebuild([(cls.__name__, cls)] + classes, config=config, **kwargs)

class Config:
title = "HugrTesting"
15 changes: 8 additions & 7 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from enum import Enum
from typing import Annotated, Any, Literal, Optional, Union, Tuple
from typing import Annotated, Any, Literal, Optional, Union, Tuple, Any

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -346,23 +346,24 @@ class Signature(ConfiguredBaseModel):

# Now that all classes are defined, we need to update the ForwardRefs in all type
# annotations. We use some inspect magic to find all classes defined in this file
# and call _model_rebuild()
# and call model_rebuild()
classes = inspect.getmembers(
sys.modules[__name__],
lambda member: inspect.isclass(member) and member.__module__ == __name__,
)


def _model_rebuild(
classes: list[Tuple[str, type]] = classes,
def model_rebuild(
classes: list[tuple[str,Any]],
config: ConfigDict = ConfigDict(),
**kwargs,
):
new_config = ConfigDict(default_model_config, **config)
for _, c in classes:
new_config = default_model_config.copy()
new_config.update(config)
for c in {k: v for (k,v) in classes}.values():
if issubclass(c, ConfiguredBaseModel):
c.set_model_config(new_config)
c.model_rebuild(**kwargs)


_model_rebuild()
model_rebuild(classes)
24 changes: 13 additions & 11 deletions scripts/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,26 @@

import json
import sys
from typing import Type
from typing import Type, Optional
from pathlib import Path

from pydantic import ConfigDict

from hugr.serialization.ops import model_rebuild
from hugr.serialization import SerialHugr
from hugr.serialization.testing_hugr import TestingHugr

from hugr.serialization import tys


def write_schema(
out_dir: Path, name_prefix: str, schema: Type[SerialHugr] | Type[TestingHugr]
out_dir: Path, name_prefix: str, schema: Type[SerialHugr] | Type[TestingHugr], config: Optional[ConfigDict] = None, **kwargs
):

version = schema.get_version()
filename = f"{name_prefix}_{version}.json"
path = out_dir / filename
print(f"Rebuilding model with config: {config}")
schema._pydantic_rebuild(config or ConfigDict(), force=True, **kwargs)
print(f"Writing schema to {path}")
with path.open("w") as f:
json.dump(schema.model_json_schema(), f, indent=4)
Expand All @@ -39,11 +43,9 @@ def write_schema(
print(__doc__)
sys.exit(1)

model_rebuild(config=ConfigDict(strict=True, extra="forbid"), force=True)
write_schema(out_dir, "testing_hugr_schema_strict", TestingHugr)
model_rebuild(config=ConfigDict(strict=False, extra="allow"), force=True)
write_schema(out_dir, "testing_hugr_schema", TestingHugr)
model_rebuild(config=ConfigDict(strict=True, extra="forbid"), force=True)
write_schema(out_dir, "hugr_schema_strict", SerialHugr)
model_rebuild(config=ConfigDict(strict=False, extra="allow"), force=True)
write_schema(out_dir, "hugr_schema", SerialHugr)
strict_config = ConfigDict(strict=True, extra="forbid")
lax_config = ConfigDict(strict=False, extra="allow")
write_schema(out_dir, "testing_hugr_schema_strict", TestingHugr, config=strict_config)
write_schema(out_dir, "testing_hugr_schema", TestingHugr, config=lax_config)
write_schema(out_dir, "hugr_schema_strict", SerialHugr, config=strict_config)
write_schema(out_dir, "hugr_schema", SerialHugr, config=lax_config)
3 changes: 2 additions & 1 deletion specification/schema/testing_hugr_schema_strict_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,7 @@
"type": "object"
}
},
"additionalProperties": false,
"description": "A serializable representation of a Hugr Type, SumType, PolyFuncType,\nValue, OpType. Intended for testing only.",
"properties": {
"version": {
Expand Down Expand Up @@ -1913,6 +1914,6 @@
"default": null
}
},
"title": "HugrTesting",
"title": "TestingHugr",
"type": "object"
}
3 changes: 2 additions & 1 deletion specification/schema/testing_hugr_schema_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,7 @@
"type": "object"
}
},
"additionalProperties": true,
"description": "A serializable representation of a Hugr Type, SumType, PolyFuncType,\nValue, OpType. Intended for testing only.",
"properties": {
"version": {
Expand Down Expand Up @@ -1913,6 +1914,6 @@
"default": null
}
},
"title": "HugrTesting",
"title": "TestingHugr",
"type": "object"
}

0 comments on commit 339b522

Please sign in to comment.