Skip to content

Commit

Permalink
Merge pull request #355 from reagento/json-schema-continuation
Browse files Browse the repository at this point in the history
JSON schema continuation
  • Loading branch information
zhPavel authored Feb 9, 2025
2 parents 09ae5dd + f42fcd8 commit 7ab72d6
Show file tree
Hide file tree
Showing 12 changed files with 479 additions and 33 deletions.
22 changes: 22 additions & 0 deletions src/adaptix/_internal/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import (
Collection,
Hashable,
Expand All @@ -6,6 +7,7 @@
KeysView,
Mapping,
Reversible,
Sequence,
Set,
Sized,
ValuesView,
Expand Down Expand Up @@ -270,3 +272,23 @@ def reversed_slice(self: StackT, end_offset: int) -> StackT:

def count(self, item: T_co) -> int: # type: ignore[misc]
return sum(loc == item for loc in reversed(self))


ItemT = TypeVar("ItemT", bound=Hashable)


class OrderedUniqueGrouper(Generic[K, ItemT]):
__slots__ = ("_key_to_item_list", "_key_to_item_set")

def __init__(self):
self._key_to_item_list = defaultdict(list)
self._key_to_item_set = defaultdict(set)

def add(self, key: K, item: ItemT) -> None:
if item not in self._key_to_item_set[key]:
self._key_to_item_set[key].add(item)
self._key_to_item_list[key].append(item)

def finalize(self) -> Mapping[K, Sequence[ItemT]]:
self._key_to_item_list.default_factory = None
return self._key_to_item_list
52 changes: 51 additions & 1 deletion src/adaptix/_internal/morphing/facade/func.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from collections.abc import Iterable, Mapping
from typing import Any, Optional, TypeVar, overload

from ...common import TypeHint
from .retort import Retort
from ...definitions import Direction
from ..json_schema.definitions import ResolvedJSONSchema
from ..json_schema.mangling import CompoundRefMangler, IndexRefMangler, QualnameRefMangler
from ..json_schema.ref_generator import BuiltinRefGenerator
from ..json_schema.request_cls import JSONSchemaContext
from ..json_schema.resolver import BuiltinJSONSchemaResolver, JSONSchemaResolver
from ..json_schema.schema_model import JSONSchemaDialect
from .retort import AdornedRetort, Retort

_global_retort = Retort()
T = TypeVar("T")
Expand Down Expand Up @@ -33,3 +41,45 @@ def dump(data: Any, tp: Optional[TypeHint] = None, /) -> Any:

def dump(data: Any, tp: Optional[TypeHint] = None, /) -> Any:
return _global_retort.dump(data, tp)


_global_resolver = BuiltinJSONSchemaResolver(
ref_generator=BuiltinRefGenerator(),
ref_mangler=CompoundRefMangler(QualnameRefMangler(), IndexRefMangler()),
)


DumpedJSONSchema = Mapping[str, Any]


def generate_json_schemas(
retort: AdornedRetort,
tps: Iterable[TypeHint],
*,
direction: Direction,
resolver: JSONSchemaResolver = _global_resolver,
dialect: str = JSONSchemaDialect.DRAFT_2020_12,
) -> tuple[DumpedJSONSchema, Iterable[DumpedJSONSchema]]:
ctx = JSONSchemaContext(dialect=dialect, direction=direction)
defs, schemas = resolver.resolve((), [retort.make_json_schema(tp, ctx) for tp in tps])
dumped_defs = _global_retort.dump(defs, dict[str, ResolvedJSONSchema])
dumped_schemas = _global_retort.dump(schemas, Iterable[ResolvedJSONSchema])
return dumped_defs, dumped_schemas


def generate_json_schema(
retort: AdornedRetort,
tp: TypeHint,
*,
direction: Direction,
resolver: JSONSchemaResolver = _global_resolver,
dialect: str = JSONSchemaDialect.DRAFT_2020_12,
) -> Mapping[str, Any]:
defs, [schema] = generate_json_schemas(
retort,
[tp],
direction=direction,
resolver=resolver,
dialect=dialect,
)
return {**schema, "$defs": defs}
8 changes: 8 additions & 0 deletions src/adaptix/_internal/morphing/facade/retort.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
UnionProvider,
)
from ..iterable_provider import IterableProvider
from ..json_schema.definitions import JSONSchema
from ..json_schema.providers import InlineJSONSchemaProvider, JSONSchemaRefProvider
from ..json_schema.request_cls import JSONSchemaContext, JSONSchemaRequest
from ..model.crown_definitions import ExtraSkip
from ..model.dumper_provider import ModelDumperProvider
from ..model.loader_provider import ModelLoaderProvider
Expand Down Expand Up @@ -313,6 +315,12 @@ def dump(self, data: Any, tp: Optional[TypeHint] = None, /) -> Any:
)
return self.get_dumper(tp)(data)

def make_json_schema(self, tp: TypeHint, ctx: JSONSchemaContext) -> JSONSchema:
return self._facade_provide(
JSONSchemaRequest(loc_stack=LocStack(TypeHintLoc(type=tp)), ctx=ctx),
error_message=f"Cannot produce JSONSchema for type {tp!r}",
)


class Retort(FilledRetort, AdornedRetort):
pass
16 changes: 7 additions & 9 deletions src/adaptix/_internal/morphing/json_schema/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Generic, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union

from ...provider.loc_stack_filtering import LocStack
from ...type_tools.fwd_ref import FwdRef
Expand All @@ -11,23 +11,21 @@


@dataclass(frozen=True)
class JSONSchemaRef(Generic[JSONSchemaT]):
value: str
is_final: bool
json_schema: JSONSchemaT
class RefSource(Generic[JSONSchemaT]):
value: Optional[str]
json_schema: JSONSchemaT = field(hash=False)
loc_stack: LocStack = field(repr=False)

def __hash__(self):
return hash(self.value)


Boolable = Union[T, bool]


class JSONSchema(BaseJSONSchema[JSONSchemaRef[Boolable[FwdRef["JSONSchema"]]], Boolable[FwdRef["JSONSchema"]]]):
@dataclass(repr=False)
class JSONSchema(BaseJSONSchema[RefSource[FwdRef["JSONSchema"]], Boolable[FwdRef["JSONSchema"]]]):
pass


@dataclass(repr=False)
class ResolvedJSONSchema(BaseJSONSchema[str, Boolable[FwdRef["ResolvedJSONSchema"]]]):
pass

72 changes: 72 additions & 0 deletions src/adaptix/_internal/morphing/json_schema/mangling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Container, Mapping, Sequence
from itertools import count
from typing import Optional

from ...datastructures import OrderedUniqueGrouper
from .definitions import RefSource
from .resolver import RefMangler


class IndexRefMangler(RefMangler):
def __init__(self, start: int = 1, separator: str = "-"):
self._start = start
self._separator = separator

def mangle_refs(
self,
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
result = {}
counter = count(self._start)
for source in sources:
while True:
idx = next(counter)
mangled = self._with_index(common_ref, idx)
if mangled not in occupied_refs:
result[source] = mangled
break

return result

def _with_index(self, common_ref: str, index: int) -> str:
return f"{common_ref}{self._separator}{index}"


class QualnameRefMangler(RefMangler):
def mangle_refs(
self,
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
return {source: self._generate_name(source) or common_ref for source in sources}

def _generate_name(self, source: RefSource) -> Optional[str]:
tp = source.loc_stack.last.type
return getattr(tp, "__qualname__", None)


class CompoundRefMangler(RefMangler):
def __init__(self, base: RefMangler, wrapper: RefMangler):
self._base = base
self._wrapper = wrapper

def mangle_refs(
self,
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
mangled = self._base.mangle_refs(occupied_refs, common_ref, sources)

grouper = OrderedUniqueGrouper[str, RefSource]()
for source, ref in mangled.items():
grouper.add(ref, source)

for ref, ref_sources in grouper.finalize().items():
if len(ref_sources) > 1:
mangled = {**mangled, **self._wrapper.mangle_refs(occupied_refs, ref, ref_sources)}

return mangled
23 changes: 7 additions & 16 deletions src/adaptix/_internal/morphing/json_schema/providers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ...provider.essential import Mediator
from ...provider.located_request import LocatedRequestMethodsProvider
from ...provider.methods_provider import method_handler
from .definitions import JSONSchemaRef
from .request_cls import InlineJSONSchemaRequest, JSONSchemaRefRequest
from .definitions import RefSource
from .request_cls import InlineJSONSchemaRequest, RefSourceRequest


class InlineJSONSchemaProvider(LocatedRequestMethodsProvider):
Expand All @@ -16,31 +16,22 @@ def provide_inline_json_schema(self, mediator: Mediator, request: InlineJSONSche

class JSONSchemaRefProvider(LocatedRequestMethodsProvider):
@method_handler
def provide_json_schema_ref(self, mediator: Mediator, request: JSONSchemaRefRequest) -> JSONSchemaRef:
return JSONSchemaRef(
value=self._get_reference_value(request),
is_final=False,
def provide_ref_source(self, mediator: Mediator, request: RefSourceRequest) -> RefSource:
return RefSource(
value=None,
json_schema=request.json_schema,
loc_stack=request.loc_stack,
)

def _get_reference_value(self, request: JSONSchemaRefRequest) -> str:
tp = request.loc_stack.last.type
try:
return tp.__name__
except AttributeError:
return str(tp)


class ConstantJSONSchemaRefProvider(LocatedRequestMethodsProvider):
def __init__(self, ref_value: str):
self._ref_value = ref_value

@method_handler
def provide_json_schema_ref(self, mediator: Mediator, request: JSONSchemaRefRequest) -> JSONSchemaRef:
return JSONSchemaRef(
def provide_ref_source(self, mediator: Mediator, request: RefSourceRequest) -> RefSource:
return RefSource(
value=self._ref_value,
is_final=True,
json_schema=request.json_schema,
loc_stack=request.loc_stack,
)
8 changes: 8 additions & 0 deletions src/adaptix/_internal/morphing/json_schema/ref_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ...provider.loc_stack_filtering import LocStack
from .definitions import JSONSchema
from .resolver import RefGenerator


class BuiltinRefGenerator(RefGenerator):
def generate_ref(self, json_schema: JSONSchema, loc_stack: LocStack) -> str:
return str(loc_stack.last.type)
4 changes: 2 additions & 2 deletions src/adaptix/_internal/morphing/json_schema/request_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ...definitions import Direction
from ...provider.located_request import LocatedRequest
from .definitions import JSONSchema, JSONSchemaRef
from .definitions import JSONSchema, RefSource


@dataclass(frozen=True)
Expand All @@ -22,7 +22,7 @@ class JSONSchemaRequest(LocatedRequest[JSONSchema], WithJSONSchemaContext):


@dataclass(frozen=True)
class JSONSchemaRefRequest(LocatedRequest[JSONSchemaRef], WithJSONSchemaContext):
class RefSourceRequest(LocatedRequest[RefSource], WithJSONSchemaContext):
json_schema: JSONSchema


Expand Down
Loading

0 comments on commit 7ab72d6

Please sign in to comment.