Skip to content

Commit

Permalink
Resolve #306 Lookup xsi type for xs:anyType derived elements
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Nov 5, 2020
1 parent fbb9a9b commit 7b6c022
Show file tree
Hide file tree
Showing 11 changed files with 363 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.3.5
rev: v2.3.6
hooks:
- id: reorder-python-imports
- repo: https://github.com/ambv/black
Expand Down
114 changes: 113 additions & 1 deletion tests/formats/dataclass/parsers/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.mixins import XmlHandler
from xsdata.formats.dataclass.parsers.nodes import AnyTypeNode
from xsdata.formats.dataclass.parsers.nodes import ElementNode
from xsdata.formats.dataclass.parsers.nodes import NodeParser
from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode
from xsdata.formats.dataclass.parsers.nodes import SkipNode
from xsdata.formats.dataclass.parsers.nodes import UnionNode
from xsdata.formats.dataclass.parsers.nodes import WildcardNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.models.enums import DataType
from xsdata.models.enums import QNames
from xsdata.models.mixins import attribute
from xsdata.models.mixins import element
Expand Down Expand Up @@ -355,7 +357,43 @@ def test_build_node_with_dataclass_var_validates_nillable(self, mock_ctx_fetch):
attrs = {QNames.XSI_NIL: "false"}
self.assertIsNone(self.node.build_node(var, attrs, ns_map, 10))

def test_build_node_with_any_type_var(self):
def test_build_node_with_any_type_var_with_matching_xsi_type(self):
var = XmlElement(name="a", qname="a", types=[object])

actual = self.node.build_node(var, {QNames.XSI_TYPE: "Foo"}, {}, 10)

self.assertIsInstance(actual, ElementNode)
self.assertEqual(10, actual.position)
self.assertEqual(self.context.build(Foo), actual.meta)
self.assertEqual({QNames.XSI_TYPE: "Foo"}, actual.attrs)
self.assertEqual({}, actual.ns_map)
self.assertFalse(actual.mixed)

def test_build_node_with_any_type_var_with_no_matching_xsi_type(self):
var = XmlElement(name="a", qname="a", types=[object])
attrs = {QNames.XSI_TYPE: "bar"}
actual = self.node.build_node(var, attrs, {}, 10)

self.assertIsInstance(actual, AnyTypeNode)
self.assertEqual(10, actual.position)
self.assertEqual(var, actual.var)
self.assertEqual(attrs, actual.attrs)
self.assertEqual({}, actual.ns_map)
self.assertFalse(actual.mixed)

def test_build_node_with_any_type_var_with_no_xsi_type(self):
var = XmlElement(name="a", qname="a", types=[object])
attrs = {}
actual = self.node.build_node(var, attrs, {}, 10)

self.assertIsInstance(actual, AnyTypeNode)
self.assertEqual(10, actual.position)
self.assertEqual(var, actual.var)
self.assertEqual(attrs, actual.attrs)
self.assertEqual({}, actual.ns_map)
self.assertFalse(actual.mixed)

def test_build_node_with_wildcard_var(self):
var = XmlWildcard(name="a", qname="a", types=[], dataclass=False)

actual = self.node.build_node(var, {}, {}, 10)
Expand All @@ -375,6 +413,80 @@ def test_build_node_with_primitive_var(self):
self.assertEqual(ns_map, actual.ns_map)


class AnyTypeNodeTests(TestCase):
def setUp(self) -> None:
self.var = XmlElement(name="a", qname="a", types=[object])
self.node = AnyTypeNode(position=0, var=self.var, attrs={}, ns_map={})

def test_child(self):
self.assertFalse(self.node.has_children)

attrs = {"a": 1}
ns_map = {"ns0": "b"}
actual = self.node.child("foo", attrs, ns_map, 10)

self.assertIsInstance(actual, WildcardNode)
self.assertEqual(10, actual.position)
self.assertEqual(self.var, actual.var)
self.assertEqual(attrs, actual.attrs)
self.assertEqual(ns_map, actual.ns_map)
self.assertTrue(self.node.has_children)

def test_bind_with_children(self):
text = "\n "
tail = "bar"
generic = AnyElement(
qname="a",
text=None,
tail="bar",
ns_map={},
attributes={},
children=[1, 2, 3],
)

objects = [("a", 1), ("b", 2), ("c", 3)]

self.node.has_children = True
self.assertTrue(self.node.bind("a", text, tail, objects))
self.assertEqual(self.var.qname, objects[-1][0])
self.assertEqual(generic, objects[-1][1])

def test_bind_with_simple_type(self):
objects = []

self.node.attrs[QNames.XSI_TYPE] = DataType.FLOAT.qname

self.assertTrue(self.node.bind("a", "10", None, objects))
self.assertEqual(self.var.qname, objects[-1][0])
self.assertEqual(10.0, objects[-1][1])

def test_bind_with_simple_type_derived(self):
objects = []

self.node.var = XmlElement(name="a", qname="a", types=[object], derived=True)
self.node.attrs[QNames.XSI_TYPE] = DataType.FLOAT.qname

self.assertTrue(self.node.bind("a", "10", None, objects))
self.assertEqual(self.var.qname, objects[-1][0])
self.assertEqual(DerivedElement(qname="a", value=10.0), objects[-1][1])

def test_bind_with_simple_type_with_mixed_content(self):
objects = []

self.node.mixed = True
self.node.attrs[QNames.XSI_TYPE] = DataType.FLOAT.qname

self.assertTrue(self.node.bind("a", "10", "pieces", objects))
self.assertEqual(self.var.qname, objects[-2][0])
self.assertEqual(10.0, objects[-2][1])
self.assertIsNone(objects[-1][0])
self.assertEqual("pieces", objects[-1][1])

self.assertTrue(self.node.bind("a", "10", "\n", objects))
self.assertEqual(self.var.qname, objects[-1][0])
self.assertEqual(10.0, objects[-1][1])


class WildcardNodeTests(TestCase):
def test_bind(self):
text = "\n "
Expand Down
16 changes: 15 additions & 1 deletion tests/formats/dataclass/parsers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from xsdata.formats.dataclass.models.generics import AnyElement
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.models.enums import DataType
from xsdata.models.enums import Namespace
from xsdata.models.enums import QNames

Expand All @@ -37,6 +38,18 @@ def test_xsi_type(self):
attrs = {QNames.XSI_TYPE: "bar:foo"}
self.assertEqual("{xsdata}foo", ParserUtils.xsi_type(attrs, ns_map))

def test_data_type(self):
ns_map = {"bar": "xsdata"}
attrs = {}
self.assertEqual(DataType.STRING, ParserUtils.data_type(attrs, ns_map))

ns_map = {"xs": Namespace.XS.uri}
attrs = {QNames.XSI_TYPE: "xs:foo"}
self.assertEqual(DataType.STRING, ParserUtils.data_type(attrs, ns_map))

attrs = {QNames.XSI_TYPE: "xs:float"}
self.assertEqual(DataType.FLOAT, ParserUtils.data_type(attrs, ns_map))

@mock.patch.object(ConverterAdapter, "from_string", return_value=2)
def test_parse_value(self, mock_from_string):
self.assertEqual(1, ParserUtils.parse_value(None, [int], 1))
Expand Down Expand Up @@ -122,13 +135,14 @@ def test_bind_mixed_objects(self):
("b", None),
("d", data_class),
("foo", generic),
(None, "foo"),
]

var = XmlWildcard(name="foo", qname="{any}foo")
params = {}
ParserUtils.bind_mixed_objects(params, var, 1, objects)

expected = {var.name: [AnyElement(qname="b", text=""), derived, generic]}
expected = {var.name: [AnyElement(qname="b", text=""), derived, generic, "foo"]}
self.assertEqual(expected, params)

def test_fetch_any_children(self):
Expand Down
38 changes: 35 additions & 3 deletions tests/formats/dataclass/serializers/test_xml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from dataclasses import field
from dataclasses import make_dataclass
from decimal import Decimal
from typing import Dict
from typing import Generator
from typing import List
Expand All @@ -20,6 +21,7 @@
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.serializers import XmlSerializer
from xsdata.formats.dataclass.serializers.mixins import XmlWriterEvent
from xsdata.models.enums import DataType
from xsdata.models.enums import QNames


Expand Down Expand Up @@ -117,7 +119,6 @@ def test_write_mixed_content(self):
(XmlWriterEvent.START, "br"),
(XmlWriterEvent.DATA, None),
(XmlWriterEvent.END, "br"),
(XmlWriterEvent.DATA, None),
(XmlWriterEvent.START, "{xsdata}BookForm"),
(XmlWriterEvent.ATTR, "id", "123"),
(XmlWriterEvent.ATTR, "lang", "en"),
Expand Down Expand Up @@ -202,7 +203,6 @@ def test_write_any_type_with_generic_object(self):
(XmlWriterEvent.ATTR, "e", 2),
(XmlWriterEvent.DATA, "b"),
(XmlWriterEvent.DATA, "g"),
(XmlWriterEvent.DATA, None),
(XmlWriterEvent.DATA, "h"),
(XmlWriterEvent.END, "a"),
(XmlWriterEvent.DATA, "c"),
Expand Down Expand Up @@ -279,6 +279,31 @@ def test_write_element_with_nillable_true(self):
self.assertIsInstance(result, Generator)
self.assertEqual(expected, list(result))

def test_write_element_with_any_type_var(self):
var = XmlElement(qname="a", name="a", types=[object])
expected = [
(XmlWriterEvent.START, "a"),
(XmlWriterEvent.ATTR, QNames.XSI_TYPE, DataType.INT.qname),
(XmlWriterEvent.DATA, 123),
(XmlWriterEvent.END, "a"),
]

result = self.serializer.write_value(123, var, "xsdata")
self.assertIsInstance(result, Generator)
self.assertEqual(expected, list(result))

def test_write_element_with_any_type_var_ignore_xs_string(self):
var = XmlElement(qname="a", name="a", types=[object])
expected = [
(XmlWriterEvent.START, "a"),
(XmlWriterEvent.DATA, "123"),
(XmlWriterEvent.END, "a"),
]

result = self.serializer.write_value("123", var, "xsdata")
self.assertIsInstance(result, Generator)
self.assertEqual(expected, list(result))

def test_write_choice_with_derived_primitive_value(self):
var = XmlElements(
name="compound",
Expand Down Expand Up @@ -324,7 +349,6 @@ def test_write_choice_with_generic_object(self):
(XmlWriterEvent.START, "a"),
(XmlWriterEvent.DATA, "1"),
(XmlWriterEvent.END, "a"),
(XmlWriterEvent.DATA, None),
]

result = self.serializer.write_value(value, var, "xsdata")
Expand Down Expand Up @@ -478,3 +502,11 @@ class Meta:
obj.content.append("!")
result = self.serializer.render(obj).split("\n")
self.assertEqual("<p>Hi <b>Mr.</b><span>chris</span>!</p>", result[1])

def test_value_datatype(self):
self.assertEqual(DataType.BOOLEAN, XmlSerializer.value_datatype(True))
self.assertEqual(DataType.INT, XmlSerializer.value_datatype(1))
self.assertEqual(DataType.FLOAT, XmlSerializer.value_datatype(1.1))
self.assertEqual(DataType.DECIMAL, XmlSerializer.value_datatype(Decimal(1.1)))
self.assertEqual(DataType.QNAME, XmlSerializer.value_datatype(QName("a")))
self.assertEqual(DataType.STRING, XmlSerializer.value_datatype("a"))
11 changes: 10 additions & 1 deletion tests/formats/dataclass/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,20 @@ def test_get_type_hints_with_choices(self):
namespaces=["bar"],
derived=True,
),
XmlElement(
name="compound",
qname="{bar}o",
nillable=False,
types=[object],
namespaces=["bar"],
derived=True,
),
XmlWildcard(
name="compound",
qname="{http://www.w3.org/1999/xhtml}any",
types=[object],
namespaces=["http://www.w3.org/1999/xhtml"],
derived=False,
derived=True,
),
],
types=[object],
Expand Down Expand Up @@ -481,6 +489,7 @@ class Node:
{"name": "x", "type": List[str], "tokens": True},
{"name": "y", "type": List[int], "nillable": True},
{"name": "z", "type": List[int]},
{"name": "o", "type": object},
{
"wildcard": True,
"type": object,
Expand Down
4 changes: 0 additions & 4 deletions tests/formats/dataclass/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ def test_property_is_wildcard(self):
self.assertIsInstance(var, XmlVar)
self.assertTrue(var.is_wildcard)

def test_property_is_any_type(self):
var = XmlWildcard(name="foo", qname="foo")
self.assertTrue(var.is_any_type)

def test_matches(self):
var = XmlWildcard(name="foo", qname="foo")
self.assertTrue(var.matches("*"))
Expand Down
31 changes: 25 additions & 6 deletions xsdata/formats/dataclass/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import sys
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -57,13 +58,24 @@ def fetch(
lookup procedure needs to check and match a dataclass model to the qualified
name instead.
"""

meta = self.build(clazz, parent_ns)
subclass = None
if xsi_type and meta.source_qname != xsi_type:
subclass = self.find_subclass(clazz, xsi_type)

return self.build(subclass, parent_ns) if subclass else meta

subclass = self.find_subclass(clazz, xsi_type) if xsi_type else None
if subclass:
meta = self.build(subclass, parent_ns)
def find_type(self, clazz: Type, xsi_type: str) -> Optional[Type]:
"""Scan the clazz module for all dataclasses and match against the
given xsi type."""
module = inspect.getmodule(clazz)
for name in dir(module):
member = getattr(module, name)
if self.match_class_source_qname(member, xsi_type):
return member

return meta
return None

def find_subclass(self, clazz: Type, xsi_type: str) -> Optional[Type]:
"""
Expand Down Expand Up @@ -93,8 +105,12 @@ class as the given class.
def match_class_source_qname(self, clazz: Type, xsi_type: str) -> bool:
"""Match a given source qualified name with the given xsi type."""
if is_dataclass(clazz):
meta = self.build(clazz)
return meta.source_qname == xsi_type
meta = clazz.Meta if "Meta" in clazz.__dict__ else None
name = getattr(meta, "name", None) or self.local_name(clazz.__name__)
module = sys.modules[clazz.__module__]
source_namespace = getattr(module, "__NAMESPACE__", None)
source_qname = build_qname(source_namespace, name)
return source_qname == xsi_type

return False

Expand Down Expand Up @@ -199,6 +215,9 @@ def build_choices(
qname = build_qname(default_namespace, choice.get("name", "any"))
nillable = choice.get("nillable", False)

if xml_type == XmlType.ELEMENT and len(types) == 1 and types[0] == object:
derived = True

yield xml_clazz(
name=parent_name,
qname=qname,
Expand Down
7 changes: 3 additions & 4 deletions xsdata/formats/dataclass/models/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def is_mixed_content(self) -> bool:
def is_wildcard(self) -> bool:
return True

@property
def is_any_type(self) -> bool:
return True

def matches(self, qname: str) -> bool:
"""Match the given qname to the wildcard allowed namespaces."""

Expand Down Expand Up @@ -304,6 +300,9 @@ class XmlMeta:
def namespace(self) -> Optional[str]:
return split_qname(self.qname)[0]

def has_var(self, qname: str = "*", mode: FindMode = FindMode.ALL) -> bool:
return self.find_var(qname, mode) is not None

def find_var(
self, qname: str = "*", mode: FindMode = FindMode.ALL
) -> Optional[XmlVar]:
Expand Down
Loading

0 comments on commit 7b6c022

Please sign in to comment.