Skip to content

Commit

Permalink
Resolve #306 Lookup xsi type for xs:anyType derived elements
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Nov 4, 2020
1 parent fbb9a9b commit 09e38cb
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.3.5
rev: v2.3.6
hooks:
- id: reorder-python-imports
- repo: https://github.com/ambv/black
Expand Down
114 changes: 113 additions & 1 deletion tests/formats/dataclass/parsers/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.mixins import XmlHandler
from xsdata.formats.dataclass.parsers.nodes import AnyTypeNode
from xsdata.formats.dataclass.parsers.nodes import ElementNode
from xsdata.formats.dataclass.parsers.nodes import NodeParser
from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode
from xsdata.formats.dataclass.parsers.nodes import SkipNode
from xsdata.formats.dataclass.parsers.nodes import UnionNode
from xsdata.formats.dataclass.parsers.nodes import WildcardNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.models.enums import DataType
from xsdata.models.enums import QNames
from xsdata.models.mixins import attribute
from xsdata.models.mixins import element
Expand Down Expand Up @@ -355,7 +357,43 @@ def test_build_node_with_dataclass_var_validates_nillable(self, mock_ctx_fetch):
attrs = {QNames.XSI_NIL: "false"}
self.assertIsNone(self.node.build_node(var, attrs, ns_map, 10))

def test_build_node_with_any_type_var(self):
def test_build_node_with_any_type_var_with_matching_xsi_type(self):
    """An any-type var with xsi:type "Foo" is expected to yield an ElementNode.

    The resulting node must carry the meta built for ``Foo`` together with
    the attributes, namespace map and position passed to ``build_node``.
    """
    any_type_var = XmlElement(name="a", qname="a", types=[object])
    attrs = {QNames.XSI_TYPE: "Foo"}
    ns_map = {}

    node = self.node.build_node(any_type_var, attrs, ns_map, 10)

    self.assertIsInstance(node, ElementNode)
    self.assertEqual(10, node.position)
    self.assertEqual(self.context.build(Foo), node.meta)
    self.assertEqual({QNames.XSI_TYPE: "Foo"}, node.attrs)
    self.assertEqual({}, node.ns_map)
    self.assertFalse(node.mixed)

def test_build_node_with_any_type_var_with_no_matching_xsi_type(self):
    """An any-type var with xsi:type "bar" is expected to yield an AnyTypeNode.

    The node must keep a reference to the var and to the attributes,
    namespace map and position passed to ``build_node``.
    """
    any_type_var = XmlElement(name="a", qname="a", types=[object])
    xsi_attrs = {QNames.XSI_TYPE: "bar"}

    node = self.node.build_node(any_type_var, xsi_attrs, {}, 10)

    self.assertIsInstance(node, AnyTypeNode)
    self.assertEqual(10, node.position)
    self.assertEqual(any_type_var, node.var)
    self.assertEqual(xsi_attrs, node.attrs)
    self.assertEqual({}, node.ns_map)
    self.assertFalse(node.mixed)

def test_build_node_with_any_type_var_with_no_xsi_type(self):
    """An any-type var without any xsi:type attribute yields an AnyTypeNode."""
    any_type_var = XmlElement(name="a", qname="a", types=[object])
    no_attrs = {}

    node = self.node.build_node(any_type_var, no_attrs, {}, 10)

    self.assertIsInstance(node, AnyTypeNode)
    self.assertEqual(10, node.position)
    self.assertEqual(any_type_var, node.var)
    self.assertEqual(no_attrs, node.attrs)
    self.assertEqual({}, node.ns_map)
    self.assertFalse(node.mixed)

def test_build_node_with_wildcard_var(self):
var = XmlWildcard(name="a", qname="a", types=[], dataclass=False)

actual = self.node.build_node(var, {}, {}, 10)
Expand All @@ -375,6 +413,80 @@ def test_build_node_with_primitive_var(self):
self.assertEqual(ns_map, actual.ns_map)


class AnyTypeNodeTests(TestCase):
    """Tests for AnyTypeNode: child node creation and value binding."""

    def setUp(self) -> None:
        # Every test starts from a childless, non-mixed node at position 0.
        self.var = XmlElement(name="a", qname="a", types=[object])
        self.node = AnyTypeNode(position=0, var=self.var, attrs={}, ns_map={})

    def test_child(self):
        """Requesting a child returns a WildcardNode and flags the parent."""
        self.assertFalse(self.node.has_children)

        child_attrs = {"a": 1}
        child_ns_map = {"ns0": "b"}
        child = self.node.child("foo", child_attrs, child_ns_map, 10)

        self.assertIsInstance(child, WildcardNode)
        self.assertEqual(10, child.position)
        self.assertEqual(self.var, child.var)
        self.assertEqual(child_attrs, child.attrs)
        self.assertEqual(child_ns_map, child.ns_map)
        self.assertTrue(self.node.has_children)

    def test_bind_with_children(self):
        """A node that saw children binds to a generic AnyElement."""
        text = "\n "
        tail = "bar"
        expected = AnyElement(
            qname="a",
            text=None,
            tail="bar",
            ns_map={},
            attributes={},
            children=[1, 2, 3],
        )

        objects = [("a", 1), ("b", 2), ("c", 3)]

        self.node.has_children = True
        self.assertTrue(self.node.bind("a", text, tail, objects))

        bound_qname, bound_value = objects[-1]
        self.assertEqual(self.var.qname, bound_qname)
        self.assertEqual(expected, bound_value)

    def test_bind_with_simple_type(self):
        """Without children the text is parsed using the xsi:type data type."""
        objects = []

        self.node.attrs[QNames.XSI_TYPE] = DataType.FLOAT.qname

        self.assertTrue(self.node.bind("a", "10", None, objects))

        bound_qname, bound_value = objects[-1]
        self.assertEqual(self.var.qname, bound_qname)
        self.assertEqual(10.0, bound_value)

    def test_bind_with_simple_type_derived(self):
        """A derived var wraps the parsed value in a DerivedElement."""
        objects = []

        self.node.var = XmlElement(name="a", qname="a", types=[object], derived=True)
        self.node.attrs[QNames.XSI_TYPE] = DataType.FLOAT.qname

        self.assertTrue(self.node.bind("a", "10", None, objects))

        bound_qname, bound_value = objects[-1]
        self.assertEqual(self.var.qname, bound_qname)
        self.assertEqual(DerivedElement(qname="a", value=10.0), bound_value)

    def test_bind_with_simple_type_with_mixed_content(self):
        """In mixed mode a non-blank tail is appended as a separate entry."""
        objects = []

        self.node.mixed = True
        self.node.attrs[QNames.XSI_TYPE] = DataType.FLOAT.qname

        # Meaningful tail: the value is followed by an unnamed text entry.
        self.assertTrue(self.node.bind("a", "10", "pieces", objects))
        value_entry, tail_entry = objects[-2], objects[-1]
        self.assertEqual(self.var.qname, value_entry[0])
        self.assertEqual(10.0, value_entry[1])
        self.assertIsNone(tail_entry[0])
        self.assertEqual("pieces", tail_entry[1])

        # Whitespace-only tail: the value stays the last entry.
        self.assertTrue(self.node.bind("a", "10", "\n", objects))
        last_qname, last_value = objects[-1]
        self.assertEqual(self.var.qname, last_qname)
        self.assertEqual(10.0, last_value)


class WildcardNodeTests(TestCase):
def test_bind(self):
text = "\n "
Expand Down
3 changes: 0 additions & 3 deletions tests/formats/dataclass/serializers/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def test_write_mixed_content(self):
(XmlWriterEvent.START, "br"),
(XmlWriterEvent.DATA, None),
(XmlWriterEvent.END, "br"),
(XmlWriterEvent.DATA, None),
(XmlWriterEvent.START, "{xsdata}BookForm"),
(XmlWriterEvent.ATTR, "id", "123"),
(XmlWriterEvent.ATTR, "lang", "en"),
Expand Down Expand Up @@ -202,7 +201,6 @@ def test_write_any_type_with_generic_object(self):
(XmlWriterEvent.ATTR, "e", 2),
(XmlWriterEvent.DATA, "b"),
(XmlWriterEvent.DATA, "g"),
(XmlWriterEvent.DATA, None),
(XmlWriterEvent.DATA, "h"),
(XmlWriterEvent.END, "a"),
(XmlWriterEvent.DATA, "c"),
Expand Down Expand Up @@ -324,7 +322,6 @@ def test_write_choice_with_generic_object(self):
(XmlWriterEvent.START, "a"),
(XmlWriterEvent.DATA, "1"),
(XmlWriterEvent.END, "a"),
(XmlWriterEvent.DATA, None),
]

result = self.serializer.write_value(value, var, "xsdata")
Expand Down
4 changes: 0 additions & 4 deletions tests/formats/dataclass/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ def test_property_is_wildcard(self):
self.assertIsInstance(var, XmlVar)
self.assertTrue(var.is_wildcard)

def test_property_is_any_type(self):
    """An XmlWildcard var reports itself as an any-type var."""
    wildcard = XmlWildcard(name="foo", qname="foo")
    flag = wildcard.is_any_type
    self.assertTrue(flag)

def test_matches(self):
var = XmlWildcard(name="foo", qname="foo")
self.assertTrue(var.matches("*"))
Expand Down
23 changes: 19 additions & 4 deletions xsdata/formats/dataclass/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import sys
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -57,13 +58,24 @@ def fetch(
lookup procedure needs to check and match a dataclass model to the qualified
name instead.
"""

meta = self.build(clazz, parent_ns)
subclass = None
if xsi_type and meta.source_qname != xsi_type:
subclass = self.find_subclass(clazz, xsi_type)

return self.build(subclass, parent_ns) if subclass else meta

subclass = self.find_subclass(clazz, xsi_type) if xsi_type else None
if subclass:
meta = self.build(subclass, parent_ns)
def find_type(self, clazz: Type, xsi_type: str) -> Optional[Type]:
"""Scan the clazz module for all dataclasses and match against the
given xsi type."""
module = inspect.getmodule(clazz)
for name in dir(module):
member = getattr(module, name)
if self.match_class_source_qname(member, xsi_type):
return member

return meta
return None

def find_subclass(self, clazz: Type, xsi_type: str) -> Optional[Type]:
"""
Expand Down Expand Up @@ -199,6 +211,9 @@ def build_choices(
qname = build_qname(default_namespace, choice.get("name", "any"))
nillable = choice.get("nillable", False)

if xml_type == XmlType.ELEMENT and len(types) == 1 and types[0] == object:
derived = True

yield xml_clazz(
name=parent_name,
qname=qname,
Expand Down
7 changes: 3 additions & 4 deletions xsdata/formats/dataclass/models/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def is_mixed_content(self) -> bool:
def is_wildcard(self) -> bool:
return True

@property
def is_any_type(self) -> bool:
    """Report this var as an any-type var; unconditionally True here."""
    return True

def matches(self, qname: str) -> bool:
"""Match the given qname to the wildcard allowed namespaces."""

Expand Down Expand Up @@ -304,6 +300,9 @@ class XmlMeta:
def namespace(self) -> Optional[str]:
return split_qname(self.qname)[0]

def has_var(self, qname: str = "*", mode: FindMode = FindMode.ALL) -> bool:
    """Return True when ``find_var`` locates a var for the qname and mode.

    :param qname: Qualified name to look up; defaults to ``"*"``.
    :param mode: Lookup mode forwarded unchanged to ``find_var``.
    """
    located = self.find_var(qname, mode)
    return located is not None

def find_var(
self, qname: str = "*", mode: FindMode = FindMode.ALL
) -> Optional[XmlVar]:
Expand Down
113 changes: 96 additions & 17 deletions xsdata/formats/dataclass/parsers/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,23 @@ def build_node(
)

if var.clazz:
xsi_type = ParserUtils.xsi_type(attrs, ns_map)
is_nillable = ParserUtils.is_nillable(attrs)
meta = self.context.fetch(var.clazz, self.meta.namespace, xsi_type)
mixed = self.meta.find_var(mode=FindMode.MIXED_CONTENT)

if not is_nillable and meta.nillable:
return None

return ElementNode(
meta=meta,
config=self.config,
attrs=attrs,
ns_map=ns_map,
context=self.context,
position=position,
derived=var.derived,
mixed=mixed is not None,
return self.build_element_node(
var.clazz, attrs, ns_map, position, var.derived
)

if var.is_any_type:
node = self.build_element_node(None, attrs, ns_map, position, var.derived)
if not node:
node = AnyTypeNode(
var=var,
attrs=attrs,
ns_map=ns_map,
position=position,
mixed=self.meta.has_var(mode=FindMode.MIXED_CONTENT),
)
return node

if var.is_wildcard:
return WildcardNode(var=var, attrs=attrs, ns_map=ns_map, position=position)

return PrimitiveNode(var=var, ns_map=ns_map)
Expand All @@ -170,6 +167,88 @@ def fetch_vars(self, qname: str) -> Iterator[XmlVar]:
else:
yield var

def build_element_node(
    self,
    clazz: Optional[Type],
    attrs: Dict,
    ns_map: Dict,
    position: int,
    derived: bool,
) -> Optional[XmlNode]:
    """Build an ElementNode for the given class or for the element xsi:type.

    When ``clazz`` is None the class is looked up through
    ``context.find_type`` using the xsi:type read from ``attrs``; the
    xsi:type is then dropped so that ``context.fetch`` receives None for it.

    :param clazz: Target class, or None to resolve it from the xsi:type.
    :param attrs: Attributes of the element being parsed.
    :param ns_map: Namespace prefix map of the element being parsed.
    :param position: Position value stored on the created node.
    :param derived: Derived flag stored on the created node.
    :return: The element node, or None when no class could be resolved or
        when the fetched meta is nillable while
        ``ParserUtils.is_nillable(attrs)`` is falsy.
    """
    xsi_type = ParserUtils.xsi_type(attrs, ns_map)

    if clazz is None and xsi_type:
        # No declared class: fall back to a lookup by xsi:type. The type has
        # served its purpose once the lookup ran, so it is cleared.
        clazz = self.context.find_type(self.meta.clazz, xsi_type)
        xsi_type = None

    if clazz is None:
        return None

    nil_flag = ParserUtils.is_nillable(attrs)
    target_meta = self.context.fetch(clazz, self.meta.namespace, xsi_type)

    if not nil_flag and target_meta.nillable:
        return None

    node_kwargs = dict(
        meta=target_meta,
        config=self.config,
        attrs=attrs,
        ns_map=ns_map,
        context=self.context,
        position=position,
        derived=derived,
        mixed=self.meta.has_var(mode=FindMode.MIXED_CONTENT),
    )
    return ElementNode(**node_kwargs)


@dataclass
class AnyTypeNode(XmlNode):
    """Parser node for an element bound to a var without a concrete class.

    Depending on whether child elements are encountered, the element is
    bound either to a generic ``AnyElement`` or to a simple value parsed
    with the data type resolved from the element attributes.
    """

    # Var the parsed value is bound to.
    var: XmlVar
    # Attributes of the element being parsed.
    attrs: Dict
    # Namespace prefix map of the element being parsed.
    ns_map: Dict
    # Position used to collect the children from the objects list.
    position: int
    # Whether a non-blank tail is appended as a separate entry.
    mixed: bool = False
    # Flipped to True as soon as a child node is requested.
    has_children: bool = field(init=False, default=False)

    def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> "XmlNode":
        """Record that the element has children and return a WildcardNode."""
        self.has_children = True
        return WildcardNode(var=self.var, attrs=attrs, ns_map=ns_map, position=position)

    def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool:
        """Append the bound value for this element to ``objects``.

        :param qname: Qualified name of the element being closed.
        :param text: Text content of the element, if any.
        :param tail: Tail content following the element, if any.
        :param objects: Running list of ``(qname, value)`` entries.
        :return: Always True.
        """
        if self.has_children:
            # Complex content: wrap everything into a generic element.
            generic = AnyElement(
                qname=qname,
                text=ParserUtils.normalize_content(text),
                tail=ParserUtils.normalize_content(tail),
                ns_map=self.ns_map,
                attributes=ParserUtils.parse_any_attributes(self.attrs, self.ns_map),
                children=ParserUtils.fetch_any_children(self.position, objects),
            )
            objects.append((self.var.qname, generic))
            return True

        # Simple content: parse the text with the resolved data type.
        datatype = ParserUtils.data_type(self.attrs, self.ns_map)
        value: Any = ParserUtils.parse_value(
            text, [datatype.local], self.var.default, self.ns_map
        )
        if self.var.derived:
            value = DerivedElement(qname=qname, value=value)
        objects.append((qname, value))

        if self.mixed:
            # Keep meaningful trailing text as its own unnamed entry.
            normalized_tail = ParserUtils.normalize_content(tail)
            if normalized_tail:
                objects.append((None, normalized_tail))

        return True


@dataclass
class WildcardNode(XmlNode):
Expand Down
Loading

0 comments on commit 09e38cb

Please sign in to comment.