Skip to content

Commit

Permalink
Add ParserConfig.class_factory (#549)
Browse files Browse the repository at this point in the history
Closes #548
  • Loading branch information
tefra authored Jul 8, 2021
1 parent 4eadb4a commit 02247bd
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: tests/fixtures

repos:
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
rev: v2.20.0
hooks:
- id: pyupgrade
args: [--py37-plus]
Expand Down
1 change: 1 addition & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Advance Topics
:maxdepth: 1

examples/custom-property-names
examples/custom-class-factory


Test Suites
Expand Down
35 changes: 35 additions & 0 deletions docs/examples/custom-class-factory.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
====================
Custom class factory
====================


It's not recommended to modify the generated models. If you need to add any pre/post
initialization logic or even validations you can use the parser config to override the
default class factory.

.. doctest::

>>> from dataclasses import dataclass
>>> from xsdata.formats.dataclass.parsers import JsonParser
>>> from xsdata.formats.dataclass.parsers.config import ParserConfig
...
>>> def custom_class_factory(clazz, params):
... if clazz.__name__ == "Person":
... return clazz(**{k: v.upper() for k, v in params.items()})
...
... return clazz(**params)
...

>>> config = ParserConfig(class_factory=custom_class_factory)
>>> parser = JsonParser(config=config)
...
>>> @dataclass
... class Person:
... first_name: str
... last_name: str
...
>>> json_str = """{"first_name": "chris", "last_name": "foo"}"""
...
...
>>> print(parser.from_string(json_str, Person))
Person(first_name='CHRIS', last_name='FOO')
51 changes: 41 additions & 10 deletions tests/formats/dataclass/parsers/nodes/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,47 @@
from xsdata.exceptions import ParserError
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.models.elements import XmlType
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.nodes import UnionNode
from xsdata.models.mixins import attribute
from xsdata.utils.testing import XmlVarFactory


class UnionNodeTests(TestCase):
def setUp(self) -> None:
super().setUp()

self.context = XmlContext()
self.config = ParserConfig()

def test_child(self):
attrs = {"id": "1"}
ns_map = {"ns0": "xsdata"}
ctx = XmlContext()
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo")
node = UnionNode(position=0, var=var, context=ctx, attrs={}, ns_map={})
node = UnionNode(
position=0,
var=var,
config=self.config,
context=self.context,
attrs={},
ns_map={},
)
self.assertEqual(node, node.child("foo", attrs, ns_map, 10))

self.assertEqual(1, node.level)
self.assertEqual([("start", "foo", attrs, ns_map)], node.events)
self.assertIsNot(attrs, node.events[0][2])

def test_bind_appends_end_event_when_level_not_zero(self):
ctx = XmlContext()
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo")
node = UnionNode(position=0, var=var, context=ctx, attrs={}, ns_map={})
node = UnionNode(
position=0,
var=var,
config=self.config,
context=self.context,
attrs={},
ns_map={},
)
node.level = 1
objects = []

Expand All @@ -43,12 +62,18 @@ def test_bind_returns_best_matching_object(self):
item2 = make_dataclass("Item2", [("a", int, attribute())])
root = make_dataclass("Root", [("item", Union[str, int, item2, item])])

ctx = XmlContext()
meta = ctx.build(root)
meta = self.context.build(root)
var = next(meta.find_children("item"))
attrs = {"a": "1", "b": 2}
ns_map = {}
node = UnionNode(position=0, var=var, context=ctx, attrs=attrs, ns_map=ns_map)
node = UnionNode(
position=0,
var=var,
config=self.config,
context=self.context,
attrs=attrs,
ns_map=ns_map,
)
objects = []

self.assertTrue(node.bind("item", "1", None, objects))
Expand All @@ -73,11 +98,17 @@ def test_bind_returns_best_matching_object(self):
self.assertEqual("a", objects[-1][1])

def test_bind_raises_parser_error_on_failure(self):
ctx = XmlContext()
meta = ctx.build(UnionType)
meta = self.context.build(UnionType)
var = next(meta.find_children("element"))

node = UnionNode(position=0, var=var, context=ctx, attrs={}, ns_map={})
node = UnionNode(
position=0,
var=var,
config=self.config,
context=self.context,
attrs={},
ns_map={},
)

with self.assertRaises(ParserError) as cm:
node.bind("element", None, None, [])
Expand Down
12 changes: 12 additions & 0 deletions xsdata/formats/dataclass/parsers/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Type
from typing import TypeVar

T = TypeVar("T")


def default_class_factory(cls: Type[T], params: Dict) -> T:
return cls(**params) # type: ignore


@dataclass
Expand All @@ -10,6 +20,7 @@ class ParserConfig:
:param base_url: Specify a base URL when parsing from memory and
you need support for relative links eg xinclude
:param process_xinclude: Enable xinclude statements processing
:param class_factory: Override default object instantiation
:param fail_on_unknown_properties: Skip unknown properties or
fail with exception
:param fail_on_converter_warnings: Turn converter warnings to
Expand All @@ -18,5 +29,6 @@ class ParserConfig:

base_url: Optional[str] = None
process_xinclude: bool = False
class_factory: Callable[[Type[T], Dict], T] = default_class_factory
fail_on_unknown_properties: bool = True
fail_on_converter_warnings: bool = False
25 changes: 12 additions & 13 deletions xsdata/formats/dataclass/parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ def verify_type(self, clazz: Optional[Type[T]], data: Union[Dict, List]) -> Type
if list_type != isinstance(data, list):
if list_type:
raise ParserError("Document is object, expected array")
else:
raise ParserError("Document is array, expected object")
raise ParserError("Document is array, expected object")

return clazz # type: ignore

Expand Down Expand Up @@ -128,7 +127,7 @@ def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T:
params[var.name] = self.bind_value(meta, var, value)

try:
return clazz(**params) # type: ignore
return self.config.class_factory(clazz, params)
except TypeError as e:
raise ParserError(e)

Expand Down Expand Up @@ -257,22 +256,22 @@ def bind_complex_type(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any:
if var.is_clazz_union:
# Union of dataclasses
return self.bind_best_dataclass(data, var.types)
elif var.elements:
if var.elements:
# Compound field with multiple choices
return self.bind_best_dataclass(data, var.element_types)
elif var.any_type or var.is_wildcard:
if var.any_type or var.is_wildcard:
# xs:anyType element, check all meta classes
return self.bind_best_dataclass(data, meta.element_types)
else:
assert var.clazz is not None

subclasses = set(self.context.get_subclasses(var.clazz))
if subclasses:
# field annotation is an abstract/base type
subclasses.add(var.clazz)
return self.bind_best_dataclass(data, subclasses)
assert var.clazz is not None

subclasses = set(self.context.get_subclasses(var.clazz))
if subclasses:
# field annotation is an abstract/base type
subclasses.add(var.clazz)
return self.bind_best_dataclass(data, subclasses)

return self.bind_dataclass(data, var.clazz)
return self.bind_dataclass(data, var.clazz)

def bind_derived_value(self, meta: XmlMeta, var: XmlVar, data: Dict) -> T:
"""Bind derived element entry point."""
Expand Down
5 changes: 3 additions & 2 deletions xsdata/formats/dataclass/parsers/nodes/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ElementNode(XmlNode):
:param context: Model context provider
:param position: The node position of objects cache
:param mixed: The node supports mixed content
:param derived: The xml element is derived from a base type
:param derived_factory: Derived element factory
:param xsi_type: The xml type substitution
"""

Expand Down Expand Up @@ -92,7 +92,7 @@ def bind(
if isinstance(params[key], PendingCollection):
params[key] = params[key].evaluate()

obj = self.meta.clazz(**params)
obj = self.config.class_factory(self.meta.clazz, params)
if self.derived_factory:
obj = self.derived_factory(qname=qname, value=obj, type=self.xsi_type)

Expand Down Expand Up @@ -330,6 +330,7 @@ def build_node(
var=var,
attrs=attrs,
ns_map=ns_map,
config=self.config,
context=self.context,
position=position,
)
Expand Down
1 change: 1 addition & 0 deletions xsdata/formats/dataclass/parsers/nodes/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class PrimitiveNode(XmlNode):
:param var: Class field xml var instance
:param ns_map: Namespace prefix-URI map
:param derived_factory: Derived element factory
"""

__slots__ = "var", "ns_map", "derived_factory"
Expand Down
3 changes: 1 addition & 2 deletions xsdata/formats/dataclass/parsers/nodes/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ class StandardNode(XmlNode):
:param datatype: Standard xsi data type
:param ns_map: Namespace prefix-URI map
:param derived: Specify whether the value needs to be wrapped with
:class:`~xsdata.formats.dataclass.models.generics.DerivedElement`
:param nillable: Specify whether the node supports nillable content
:param derived_factory: Optional derived element factory
"""

__slots__ = "datatype", "ns_map", "nillable", "derived_factory"
Expand Down
26 changes: 23 additions & 3 deletions xsdata/formats/dataclass/parsers/nodes/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.models.elements import XmlVar
from xsdata.formats.dataclass.parsers.bases import NodeParser
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.mixins import EventsHandler
from xsdata.formats.dataclass.parsers.mixins import XmlNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils
Expand All @@ -32,18 +33,35 @@ class UnionNode(XmlNode):
:param attrs: Key-value attribute mapping
:param ns_map: Namespace prefix-URI map
:param position: The node position of objects cache
:param config: Parser configuration
:param context: Model context provider
"""

__slots__ = "var", "attrs", "ns_map", "position", "context", "level", "events"
__slots__ = (
"var",
"attrs",
"ns_map",
"position",
"config",
"context",
"level",
"events",
)

def __init__(
self, var: XmlVar, attrs: Dict, ns_map: Dict, position: int, context: XmlContext
self,
var: XmlVar,
attrs: Dict,
ns_map: Dict,
position: int,
config: ParserConfig,
context: XmlContext,
):
self.var = var
self.attrs = attrs
self.ns_map = ns_map
self.position = position
self.config = config
self.context = context
self.level = 0
self.events: List[Tuple[str, str, Any, Any]] = []
Expand Down Expand Up @@ -94,7 +112,9 @@ def parse_class(self, clazz: Type[T]) -> Optional[T]:
with warnings.catch_warnings():
warnings.filterwarnings("error", category=ConverterWarning)

parser = NodeParser(context=self.context, handler=EventsHandler)
parser = NodeParser(
config=self.config, context=self.context, handler=EventsHandler
)
return parser.parse(self.events, clazz)
except Exception:
return None
Expand Down
1 change: 1 addition & 0 deletions xsdata/formats/dataclass/parsers/nodes/wildcard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class WildcardNode(XmlNode):
:param attrs: Key-value attribute mapping
:param ns_map: Namespace prefix-URI map
:param position: The node position of objects cache
:param factory: Wildcard element factory
"""

__slots__ = "var", "attrs", "ns_map", "position", "factory"
Expand Down

0 comments on commit 02247bd

Please sign in to comment.