Skip to content

Commit

Permalink
Unmarshaller format refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Oct 11, 2022
1 parent 3ffe2bc commit 692a915
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 68 deletions.
2 changes: 1 addition & 1 deletion docs/customizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Here's how you could add support for a ``usdate`` format that handles dates of t
def validate(self, value) -> bool:
return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value))
def unmarshal(self, value):
def format(self, value):
return datetime.strptime(value, "%m/%d/%y").date
Expand Down
8 changes: 0 additions & 8 deletions openapi_core/unmarshalling/schemas/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,6 @@ def create(
klass = self.UNMARSHALLERS[schema_type]
return klass(schema, validator, formatter)

def get_formatter(
self, type_format: str, default_formatters: FormattersDict
) -> Optional[Formatter]:
try:
return self.custom_formatters[type_format]
except KeyError:
return default_formatters.get(type_format)

def get_validator(self, schema: Spec) -> Validator:
resolver = schema.accessor.resolver # type: ignore
custom_format_checks = {
Expand Down
40 changes: 35 additions & 5 deletions openapi_core/unmarshalling/schemas/formatters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Any
from typing import Callable
from typing import Optional
Expand All @@ -8,20 +9,49 @@ class Formatter:
def validate(self, value: Any) -> bool:
return True

def unmarshal(self, value: Any) -> Any:
def format(self, value: Any) -> Any:
return value

def __getattribute__(self, name: str) -> Any:
if name == "unmarshal":
warnings.warn(
"Unmarshal method is deprecated. " "Use format instead.",
DeprecationWarning,
)
return super().__getattribute__("format")
if name == "format":
try:
attr = super().__getattribute__("unmarshal")
except AttributeError:
return super().__getattribute__("format")
else:
warnings.warn(
"Unmarshal method is deprecated. "
"Rename unmarshal method to format instead.",
DeprecationWarning,
)
return attr
return super().__getattribute__(name)

@classmethod
def from_callables(
cls,
validate: Optional[Callable[[Any], Any]] = None,
validate_callable: Optional[Callable[[Any], Any]] = None,
format_callable: Optional[Callable[[Any], Any]] = None,
unmarshal: Optional[Callable[[Any], Any]] = None,
) -> "Formatter":
attrs = {}
if validate is not None:
attrs["validate"] = staticmethod(validate)
if validate_callable is not None:
attrs["validate"] = staticmethod(validate_callable)
if format_callable is not None:
attrs["format"] = staticmethod(format_callable)
if unmarshal is not None:
attrs["unmarshal"] = staticmethod(unmarshal)
warnings.warn(
"Unmarshal parameter is deprecated. "
"Use format_callable instead.",
DeprecationWarning,
)
attrs["format"] = staticmethod(unmarshal)

klass: Type[Formatter] = type("Formatter", (cls,), attrs)
return klass()
107 changes: 54 additions & 53 deletions openapi_core/unmarshalling/schemas/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import cast
Expand Down Expand Up @@ -31,6 +32,7 @@
)
from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue
from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError
from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError
from openapi_core.unmarshalling.schemas.exceptions import ValidateError
from openapi_core.unmarshalling.schemas.formatters import Formatter
from openapi_core.unmarshalling.schemas.util import format_byte
Expand Down Expand Up @@ -61,24 +63,25 @@ def __init__(
):
self.schema = schema
self.validator = validator
self.format = schema.getkey("format")
self.schema_format = schema.getkey("format")

if formatter is None:
if self.format not in self.FORMATTERS:
raise FormatterNotFoundError(self.format)
self.formatter = self.FORMATTERS[self.format]
if self.schema_format not in self.FORMATTERS:
raise FormatterNotFoundError(self.schema_format)
self.formatter = self.FORMATTERS[self.schema_format]
else:
self.formatter = formatter

def __call__(self, value: Any) -> Any:
if value is None:
return

self.validate(value)

# skip unmarshalling for nullable in OpenAPI 3.0
if value is None and self.schema.getkey("nullable", False):
return value

return self.unmarshal(value)

def _formatter_validate(self, value: Any) -> None:
def _validate_format(self, value: Any) -> None:
result = self.formatter.validate(value)
if not result:
schema_type = self.schema.getkey("type", "any")
Expand All @@ -91,11 +94,14 @@ def validate(self, value: Any) -> None:
schema_type = self.schema.getkey("type", "any")
raise InvalidSchemaValue(value, schema_type, schema_errors=errors)

def unmarshal(self, value: Any) -> Any:
def format(self, value: Any) -> Any:
try:
return self.formatter.unmarshal(value)
except ValueError as exc:
raise InvalidSchemaFormatValue(value, self.format, exc)
return self.formatter.format(value)
except (ValueError, TypeError) as exc:
raise InvalidSchemaFormatValue(value, self.schema_format, exc)

def unmarshal(self, value: Any) -> Any:
return self.format(value)


class StringUnmarshaller(BaseSchemaUnmarshaller):
Expand Down Expand Up @@ -192,10 +198,8 @@ def items_unmarshaller(self) -> "BaseSchemaUnmarshaller":
items_schema = self.schema.get("items", Spec.from_dict({}))
return self.unmarshallers_factory.create(items_schema)

def __call__(self, value: Any) -> Optional[List[Any]]:
value = super().__call__(value)
if value is None and self.schema.getkey("nullable", False):
return None
def unmarshal(self, value: Any) -> Optional[List[Any]]:
value = super().unmarshal(value)
return list(map(self.items_unmarshaller, value))


Expand All @@ -210,38 +214,31 @@ def object_class_factory(self) -> ModelPathFactory:
return ModelPathFactory()

def unmarshal(self, value: Any) -> Any:
properties = self.unmarshal_raw(value)
properties = self.format(value)

fields: Iterable[str] = properties and properties.keys() or []
object_class = self.object_class_factory.create(self.schema, fields)

return object_class(**properties)

def unmarshal_raw(self, value: Any) -> Any:
try:
value = self.formatter.unmarshal(value)
except ValueError as exc:
schema_format = self.schema.getkey("format")
raise InvalidSchemaFormatValue(value, schema_format, exc)
else:
return self._unmarshal_object(value)
def format(self, value: Any) -> Any:
formatted = super().format(value)
return self._unmarshal_properties(formatted)

def _clone(self, schema: Spec) -> "ObjectUnmarshaller":
return cast(
"ObjectUnmarshaller",
self.unmarshallers_factory.create(schema, "object"),
)

def _unmarshal_object(self, value: Any) -> Any:
def _unmarshal_properties(self, value: Any) -> Any:
properties = {}

if "oneOf" in self.schema:
one_of_properties = None
for one_of_schema in self.schema / "oneOf":
try:
unmarshalled = self._clone(one_of_schema).unmarshal_raw(
value
)
unmarshalled = self._clone(one_of_schema).format(value)
except (UnmarshalError, ValueError):
pass
else:
Expand All @@ -259,9 +256,7 @@ def _unmarshal_object(self, value: Any) -> Any:
any_of_properties = None
for any_of_schema in self.schema / "anyOf":
try:
unmarshalled = self._clone(any_of_schema).unmarshal_raw(
value
)
unmarshalled = self._clone(any_of_schema).format(value)
except (UnmarshalError, ValueError):
pass
else:
Expand Down Expand Up @@ -319,21 +314,36 @@ def types_unmarshallers(self) -> List["BaseSchemaUnmarshaller"]:
unmarshaller = partial(self.unmarshallers_factory.create, self.schema)
return list(map(unmarshaller, types))

def unmarshal(self, value: Any) -> Any:
for unmarshaller in self.types_unmarshallers:
@property
def type(self) -> List[str]:
types = self.schema.getkey("type", ["any"])
assert isinstance(types, list)
return types

def _get_unmarshallers_iter(self) -> Iterator["BaseSchemaUnmarshaller"]:
for schema_type in self.type:
yield self.unmarshallers_factory.create(
self.schema, type_override=schema_type
)

def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller":
for unmarshaller in self._get_unmarshallers_iter():
# validate with validator of formatter (usualy type validator)
try:
unmarshaller._formatter_validate(value)
unmarshaller._validate_format(value)
except ValidateError:
continue
else:
return unmarshaller(value)
return unmarshaller

log.warning("failed to unmarshal multi type")
return value
raise UnmarshallerError("Unmarshaller not found for type(s)")

def unmarshal(self, value: Any) -> Any:
unmarshaller = self._get_best_unmarshaller(value)
return unmarshaller(value)


class AnyUnmarshaller(ComplexUnmarshaller):
class AnyUnmarshaller(MultiTypeUnmarshaller):

SCHEMA_TYPES_ORDER = [
"object",
Expand All @@ -344,6 +354,10 @@ class AnyUnmarshaller(ComplexUnmarshaller):
"string",
]

@property
def type(self) -> List[str]:
return self.SCHEMA_TYPES_ORDER

def unmarshal(self, value: Any) -> Any:
one_of_schema = self._get_one_of_schema(value)
if one_of_schema:
Expand All @@ -357,20 +371,7 @@ def unmarshal(self, value: Any) -> Any:
if all_of_schema:
return self.unmarshallers_factory.create(all_of_schema)(value)

for schema_type in self.SCHEMA_TYPES_ORDER:
unmarshaller = self.unmarshallers_factory.create(
self.schema, type_override=schema_type
)
# validate with validator of formatter (usualy type validator)
try:
unmarshaller._formatter_validate(value)
except ValidateError:
continue
else:
return unmarshaller(value)

log.warning("failed to unmarshal any type")
return value
return super().unmarshal(value)

def _get_one_of_schema(self, value: Any) -> Optional[Spec]:
if "oneOf" not in self.schema:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/unmarshalling/test_unmarshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_array_null(self, unmarshaller_factory):
spec = Spec.from_dict(schema)
value = None

with pytest.raises(TypeError):
with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(spec)(value)

def test_array_nullable(self, unmarshaller_factory):
Expand Down

0 comments on commit 692a915

Please sign in to comment.