diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 268140cc2..c0f54f6fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,12 +2,12 @@ exclude: tests/fixtures|docs/examples repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.9.0 + rev: v2.10.0 hooks: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/asottile/reorder_python_imports - rev: v2.3.6 + rev: v2.4.0 hooks: - id: reorder-python-imports - repo: https://github.com/ambv/black diff --git a/tests/fixtures/defxmlschema/chapter12.json b/tests/fixtures/defxmlschema/chapter12.json index 357f1313c..248d235b9 100644 --- a/tests/fixtures/defxmlschema/chapter12.json +++ b/tests/fixtures/defxmlschema/chapter12.json @@ -48,7 +48,8 @@ "other_attributes": { "{http://example.org/oth}custom": "12" } - } + }, + "substituted": false } ] } \ No newline at end of file diff --git a/tests/fixtures/defxmlschema/chapter17.json b/tests/fixtures/defxmlschema/chapter17.json index 857e12444..68b5372dd 100644 --- a/tests/fixtures/defxmlschema/chapter17.json +++ b/tests/fixtures/defxmlschema/chapter17.json @@ -22,7 +22,8 @@ "quantity": 1, "color": null, "number": 563 - } + }, + "substituted": false } ] }, diff --git a/tests/formats/dataclass/parsers/test_json.py b/tests/formats/dataclass/parsers/test_json.py index 34a2e686b..6bcfff75f 100644 --- a/tests/formats/dataclass/parsers/test_json.py +++ b/tests/formats/dataclass/parsers/test_json.py @@ -167,10 +167,11 @@ def test_bind_choice_generic_with_derived(self): XmlVar(element=True, name="b", qname="b", types=[float]), ], ) + data = {"qname": "a", "value": 1, "substituted": True} self.assertEqual( - DerivedElement(qname="a", value=1), - self.parser.bind_choice({"qname": "a", "value": 1}, var), + DerivedElement(qname="a", value=1, substituted=True), + self.parser.bind_choice(data, var), ) def test_bind_choice_generic_with_wildcard(self): @@ -199,3 +200,40 @@ def test_bind_choice_generic_with_unknown_qname(self): "XmlElements undefined choice: `compound` for qname `foo`", str(cm.exception), ) + + def test_bind_wildcard_with_any_element(self): + var = XmlVar( + wildcard=True, + name="any_element", + qname="any_element", + types=[object], + ) + + self.assertEqual( + AnyElement(qname="a", text="1"), + self.parser.bind_value(var, {"qname": "a", "text": 1}), + ) + + def test_bind_wildcard_with_derived_element(self): + var = XmlVar( + any_type=True, + name="a", + qname="a", + types=[object], + ) + actual = DerivedElement(qname="a", value=Books(book=[]), substituted=True) + data = {"qname": "a", "value": {"book": []}, "substituted": True} + + self.assertEqual(actual, self.parser.bind_value(var, data)) + + def test_bind_wildcard_with_no_matching_value(self): + var = XmlVar( + any_type=True, + name="a", + qname="a", + types=[object], + ) + + data = {"test_bind_wildcard_with_no_matching_value": False} + self.assertEqual(data, self.parser.bind_value(var, data)) + self.assertEqual(1, self.parser.bind_value(var, 1)) diff --git a/tests/formats/dataclass/parsers/test_nodes.py b/tests/formats/dataclass/parsers/test_nodes.py index 20b2a5c3e..1badf55f1 100644 --- a/tests/formats/dataclass/parsers/test_nodes.py +++ b/tests/formats/dataclass/parsers/test_nodes.py @@ -20,7 +20,6 @@ from xsdata.formats.dataclass.models.generics import DerivedElement from xsdata.formats.dataclass.parsers.config import ParserConfig from xsdata.formats.dataclass.parsers.mixins import XmlHandler -from xsdata.formats.dataclass.parsers.nodes import AnyTypeNode from xsdata.formats.dataclass.parsers.nodes import ElementNode from xsdata.formats.dataclass.parsers.nodes import NodeParser from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode @@ -392,29 +391,42 @@ def test_build_node_with_any_type_var_with_matching_xsi_type(self): self.assertEqual(ns_map, actual.ns_map) self.assertFalse(actual.mixed) + def test_build_node_with_any_type_var_with_datatype(self): + var = XmlVar(element=True, name="a", qname="a", types=[object], any_type=True) + attrs = {QNames.XSI_TYPE: "xs:hexBinary"} + ns_map = {Namespace.XS.prefix: Namespace.XS.uri} + actual = self.node.build_node(var, attrs, ns_map, 10) + + self.assertIsInstance(actual, PrimitiveNode) + self.assertEqual(ns_map, actual.ns_map) + self.assertEqual([DataType.HEX_BINARY.type], actual.types) + self.assertIsNone(actual.default) + self.assertFalse(actual.tokens) + self.assertEqual(DataType.HEX_BINARY.format, actual.format) + self.assertEqual(var.derived, actual.derived) + self.assertEqual(DataType.HEX_BINARY.wrapper, actual.wrapper) + def test_build_node_with_any_type_var_with_no_matching_xsi_type(self): var = XmlVar(element=True, name="a", qname="a", types=[object], any_type=True) attrs = {QNames.XSI_TYPE: "noMatch"} actual = self.node.build_node(var, attrs, {}, 10) - self.assertIsInstance(actual, AnyTypeNode) + self.assertIsInstance(actual, WildcardNode) self.assertEqual(10, actual.position) self.assertEqual(var, actual.var) self.assertEqual(attrs, actual.attrs) self.assertEqual({}, actual.ns_map) - self.assertFalse(actual.mixed) def test_build_node_with_any_type_var_with_no_xsi_type(self): var = XmlVar(element=True, name="a", qname="a", types=[object], any_type=True) attrs = {} actual = self.node.build_node(var, attrs, {}, 10) - self.assertIsInstance(actual, AnyTypeNode) + self.assertIsInstance(actual, WildcardNode) self.assertEqual(10, actual.position) self.assertEqual(var, actual.var) self.assertEqual(attrs, actual.attrs) self.assertEqual({}, actual.ns_map) - self.assertFalse(actual.mixed) def test_build_node_with_wildcard_var(self): var = XmlVar(wildcard=True, name="a", qname="a", types=[], dataclass=False) @@ -432,95 +444,13 @@ def test_build_node_with_primitive_var(self): actual = self.node.build_node(var, attrs, ns_map, 10) self.assertIsInstance(actual, PrimitiveNode) - self.assertEqual(var, actual.var) self.assertEqual(ns_map, actual.ns_map) - - -class AnyTypeNodeTests(TestCase): - def setUp(self) -> None: - self.var = XmlVar(element=True, name="a", qname="a", types=[object]) - self.node = AnyTypeNode(position=0, var=self.var, attrs={}, ns_map={}) - - def test_child(self): - self.assertFalse(self.node.has_children) - - attrs = {"a": 1} - ns_map = {"ns0": "b"} - actual = self.node.child("foo", attrs, ns_map, 10) - - self.assertIsInstance(actual, WildcardNode) - self.assertEqual(10, actual.position) - self.assertEqual(self.var, actual.var) - self.assertEqual(attrs, actual.attrs) - self.assertEqual(ns_map, actual.ns_map) - self.assertTrue(self.node.has_children) - - def test_bind_with_children(self): - text = "\n " - tail = "bar" - generic = AnyElement( - qname="a", - text=None, - tail="bar", - attributes={}, - children=[1, 2, 3], - ) - - objects = [("a", 1), ("b", 2), ("c", 3)] - - self.node.has_children = True - self.assertTrue(self.node.bind("a", text, tail, objects)) - self.assertEqual(self.var.qname, objects[-1][0]) - self.assertEqual(generic, objects[-1][1]) - - def test_bind_with_simple_type(self): - objects = [] - - self.node.attrs[QNames.XSI_TYPE] = "xs:float" - self.node.ns_map["xs"] = Namespace.XS.uri - - self.assertTrue(self.node.bind("a", "10", None, objects)) - self.assertEqual(self.var.qname, objects[-1][0]) - self.assertEqual(10.0, objects[-1][1]) - - def test_bind_with_simple_type_that_has_wrapper_class(self): - objects = [] - - self.node.attrs[QNames.XSI_TYPE] = "xs:hexBinary" - self.node.ns_map["xs"] = Namespace.XS.uri - - self.assertTrue(self.node.bind("a", "4368726973", None, objects)) - self.assertEqual(self.var.qname, objects[-1][0]) - self.assertEqual(b"Chris", objects[-1][1]) - self.assertIsInstance(objects[-1][1], XmlHexBinary) - - def test_bind_with_simple_type_derived(self): - objects = [] - - self.node.var = XmlVar( - element=True, name="a", qname="a", types=[object], derived=True - ) - self.node.attrs[QNames.XSI_TYPE] = str(DataType.FLOAT) - - self.assertTrue(self.node.bind("a", "10", None, objects)) - self.assertEqual(self.var.qname, objects[-1][0]) - self.assertEqual(DerivedElement(qname="a", value=10.0), objects[-1][1]) - - def test_bind_with_simple_type_with_mixed_content(self): - objects = [] - - self.node.mixed = True - self.node.attrs[QNames.XSI_TYPE] = str(DataType.FLOAT) - - self.assertTrue(self.node.bind("a", "10", "pieces", objects)) - self.assertEqual(self.var.qname, objects[-2][0]) - self.assertEqual(10.0, objects[-2][1]) - self.assertIsNone(objects[-1][0]) - self.assertEqual("pieces", objects[-1][1]) - - self.assertTrue(self.node.bind("a", "10", "\n", objects)) - self.assertEqual(self.var.qname, objects[-1][0]) - self.assertEqual(10.0, objects[-1][1]) + self.assertEqual(actual.types, var.types) + self.assertEqual(actual.tokens, var.tokens) + self.assertEqual(actual.format, var.format) + self.assertEqual(actual.derived, var.derived) + self.assertEqual(actual.default, var.default) + self.assertIsNone(actual.wrapper) class WildcardNodeTests(TestCase): @@ -546,6 +476,21 @@ def test_bind(self): self.assertEqual(var.qname, objects[-1][0]) self.assertEqual(generic, objects[-1][1]) + # Preserve whitespace if no children + node.position = 1 + node.bind("foo", text, tail, objects) + generic.text = text + generic.children = [] + self.assertEqual(generic, objects[-1][1]) + + # Not a wildcard, no tail/attrs/children skip wrapper + tail = None + text = "1" + node.attrs = {} + node.position = 2 + node.bind("a", text, tail, objects) + self.assertEqual("1", objects[-1][1]) + def test_child(self): attrs = {"id": "1"} ns_map = {"ns0": "xsdata"} @@ -648,7 +593,7 @@ def test_bind(self, mock_parse_value): mock_parse_value.return_value = 13 var = XmlVar(text=True, name="foo", qname="foo", types=[int], format="Nope") ns_map = {"foo": "bar"} - node = PrimitiveNode(var=var, ns_map=ns_map) + node = PrimitiveNode.from_var(var, ns_map) objects = [] self.assertTrue(node.bind("foo", "13", "Impossible", objects)) @@ -658,17 +603,27 @@ def test_bind(self, mock_parse_value): "13", var.types, var.default, ns_map, var.tokens, var.format ) - def test_bind_derived_var(self): + def test_bind_derived_mode(self): var = XmlVar(text=True, name="foo", qname="foo", types=[int], derived=True) ns_map = {"foo": "bar"} - node = PrimitiveNode(var=var, ns_map=ns_map) + node = PrimitiveNode.from_var(var, ns_map) objects = [] self.assertTrue(node.bind("foo", "13", "Impossible", objects)) self.assertEqual(DerivedElement("foo", 13), objects[-1][1]) + def test_bind_wrapper_mode(self): + datatype = DataType.HEX_BINARY + ns_map = {"foo": "bar"} + node = PrimitiveNode.from_datatype(datatype, True, ns_map) + objects = [] + + self.assertTrue(node.bind("foo", "13", "Impossible", objects)) + self.assertEqual(DerivedElement("foo", XmlHexBinary(b"\x13")), objects[-1][1]) + def test_child(self): - node = PrimitiveNode(var=XmlVar(text=True, name="foo", qname="foo"), ns_map={}) + var = XmlVar(text=True, name="foo", qname="foo") + node = PrimitiveNode.from_var(var, {}) with self.assertRaises(XmlContextError): node.child("foo", {}, {}, 0) @@ -809,7 +764,7 @@ def test_end(self, mock_assemble): objects = [("q", "result")] queue = [] var = XmlVar(text=True, name="foo", qname="foo") - queue.append(PrimitiveNode(var=var, ns_map={})) + queue.append(PrimitiveNode.from_var(var, ns_map={})) result = parser.end(queue, objects, "author", "foobar", None) self.assertEqual("result", result) diff --git a/tests/formats/dataclass/parsers/test_xml.py b/tests/formats/dataclass/parsers/test_xml.py index 62670c6f8..e3a0ac473 100644 --- a/tests/formats/dataclass/parsers/test_xml.py +++ b/tests/formats/dataclass/parsers/test_xml.py @@ -32,7 +32,7 @@ def test_end(self, mock_emit_event): objects = [] queue = [] var = XmlVar(text=True, name="foo", qname="foo", types=[bool]) - queue.append(PrimitiveNode(var=var, ns_map={})) + queue.append(PrimitiveNode.from_var(var, {})) result = self.parser.end(queue, objects, "enabled", "true", None) self.assertTrue(result) diff --git a/tests/formats/dataclass/serializers/test_xml.py b/tests/formats/dataclass/serializers/test_xml.py index def3fe2f1..41e1ff405 100644 --- a/tests/formats/dataclass/serializers/test_xml.py +++ b/tests/formats/dataclass/serializers/test_xml.py @@ -194,7 +194,7 @@ def test_write_any_type_with_primitive_element(self): self.assertIsInstance(result, Generator) self.assertEqual(expected, list(result)) - def test_write_any_type_with_generic_object(self): + def test_write_any_type_with_any_element(self): var = XmlVar(wildcard=True, qname="a", name="a") value = AnyElement( qname="a", @@ -218,6 +218,37 @@ def test_write_any_type_with_generic_object(self): self.assertIsInstance(result, Generator) self.assertEqual(expected, list(result)) + def test_write_any_type_with_derived_element_primitive(self): + var = XmlVar(wildcard=True, qname="a", name="a") + value = DerivedElement(qname="a", value=1) + expected = [ + (XmlWriterEvent.START, "a"), + (XmlWriterEvent.ATTR, QNames.XSI_TYPE, QName(str(DataType.SHORT))), + (XmlWriterEvent.DATA, 1), + (XmlWriterEvent.END, "a"), + ] + + result = self.serializer.write_value(value, var, "xsdata") + self.assertIsInstance(result, Generator) + self.assertEqual(expected, list(result)) + + def test_write_any_type_with_derived_element_dataclass(self): + var = XmlVar(wildcard=True, qname="a", name="a") + value = DerivedElement(qname="a", value=BookForm(title="def"), substituted=True) + expected = [ + (XmlWriterEvent.START, "a"), + (XmlWriterEvent.ATTR, "lang", "en"), + (XmlWriterEvent.ATTR, QNames.XSI_TYPE, QName("{urn:books}BookForm")), + (XmlWriterEvent.START, "title"), + (XmlWriterEvent.DATA, "def"), + (XmlWriterEvent.END, "title"), + (XmlWriterEvent.END, "a"), + ] + + result = self.serializer.write_value(value, var, "xsdata") + self.assertIsInstance(result, Generator) + self.assertEqual(expected, list(result)) + def test_write_xsi_type(self): var = XmlVar( element=True, qname="a", name="a", dataclass=True, types=[BookForm] diff --git a/tests/formats/dataclass/test_context.py b/tests/formats/dataclass/test_context.py index 42344ee77..fd9d43ede 100644 --- a/tests/formats/dataclass/test_context.py +++ b/tests/formats/dataclass/test_context.py @@ -10,6 +10,7 @@ from typing import Iterator from typing import List from typing import Type +from typing import TypeVar from typing import Union from unittest import mock from unittest import TestCase @@ -397,6 +398,29 @@ def test_get_type_hints_with_choices(self): ) self.assertEqual(expected, list(actual)[0]) + def test_get_type_hints_with_typevars(self): + + A = TypeVar("A", str, int) + B = TypeVar("B", bound=object) + + foo = make_dataclass("Foo", [("a", A), ("b", B), ("c", List[B])]) + + actual = self.ctx.get_type_hints(foo, None, return_input, return_input) + expected = [ + XmlVar(name="a", qname="a", element=True, types=[int, str]), + XmlVar(name="b", qname="b", any_type=True, element=True, types=[object]), + XmlVar( + name="c", + qname="c", + any_type=True, + list_element=True, + element=True, + types=[object], + ), + ] + + self.assertEqual(expected, list(actual)) + def test_get_type_hints_with_no_dataclass(self): with self.assertRaises(TypeError): list(self.ctx.get_type_hints(self.__class__, None)) diff --git a/xsdata/formats/dataclass/context.py b/xsdata/formats/dataclass/context.py index c1dbeb9cd..96a8ca46b 100644 --- a/xsdata/formats/dataclass/context.py +++ b/xsdata/formats/dataclass/context.py @@ -16,6 +16,7 @@ from typing import Optional from typing import Set from typing import Type +from typing import TypeVar from xsdata.exceptions import XmlContextError from xsdata.formats.bindings import T @@ -384,18 +385,28 @@ def real_types(cls, type_hint: Any) -> List: :param type_hint: A typing declaration """ - types = [] + type_vars = [] if type_hint is Dict: - types.append(type_hint) + type_vars.append(type_hint) elif hasattr(type_hint, "__origin__"): while len(type_hint.__args__) == 1 and hasattr( type_hint.__args__[0], "__origin__" ): type_hint = type_hint.__args__[0] - types = [x for x in type_hint.__args__ if x is not None.__class__] + type_vars = [x for x in type_hint.__args__ if x is not None.__class__] else: - types.append(type_hint) + type_vars.append(type_hint) + + types = [] + for type_var in type_vars: + if isinstance(type_var, TypeVar): + if type_var.__bound__: + types.append(type_var.__bound__) + else: + types.extend(type_var.__constraints__) + else: + types.append(type_var) return converter.sort_types(types) diff --git a/xsdata/formats/dataclass/models/generics.py b/xsdata/formats/dataclass/models/generics.py index ce76b762e..5cad0fc68 100644 --- a/xsdata/formats/dataclass/models/generics.py +++ b/xsdata/formats/dataclass/models/generics.py @@ -8,7 +8,7 @@ from xsdata.formats.dataclass.models.elements import XmlType -T = TypeVar("T") +T = TypeVar("T", bound=object) @dataclass @@ -42,7 +42,9 @@ class DerivedElement(Generic[T]): :param qname: The element's qualified name :param value: The wrapped value + :param substituted: Specify whether the value is a type substitution """ qname: str value: T + substituted: bool = False diff --git a/xsdata/formats/dataclass/parsers/json.py b/xsdata/formats/dataclass/parsers/json.py index 084f22beb..3cf147c77 100644 --- a/xsdata/formats/dataclass/parsers/json.py +++ b/xsdata/formats/dataclass/parsers/json.py @@ -4,6 +4,7 @@ import warnings from dataclasses import dataclass from dataclasses import field +from dataclasses import fields from dataclasses import is_dataclass from typing import Any from typing import Dict @@ -23,6 +24,10 @@ from xsdata.utils.constants import EMPTY_MAP +ANY_KEYS = {f.name for f in fields(AnyElement)} +DERIVED_KEYS = {f.name for f in fields(DerivedElement)} + + @dataclass class JsonParser(AbstractParser): """ @@ -66,12 +71,12 @@ def bind_value(self, var: XmlVar, value: Any) -> Any: if var.clazz: return self.bind_dataclass(value, var.clazz) - if var.wildcard: - return self.bind_wildcard(value) - if var.elements: return self.bind_choice(value, var) + if var.wildcard or var.any_type: + return self.bind_wildcard(value) + return self.parse_value(value, var.types, var.default, var.tokens, var.format) def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T: @@ -126,7 +131,17 @@ def bind_type_union(self, value: Any, var: XmlVar) -> Any: def bind_wildcard(self, value: Any) -> Any: """Bind data to a wildcard model.""" if isinstance(value, Dict): - return self.bind_dataclass(value, AnyElement) + keys = set(value.keys()) + + if not (keys - ANY_KEYS): + return self.bind_dataclass(value, AnyElement) + + if not (keys - DERIVED_KEYS): + return self.bind_dataclass(value, DerivedElement) + + clazz: Optional[Type] = self.context.find_type_by_fields(keys) + if clazz: + return self.bind_dataclass(value, clazz) return value @@ -171,7 +186,10 @@ def bind_choice_generic(self, value: Dict, var: XmlVar) -> Any: ) if "value" in value: - return DerivedElement(qname, self.bind_value(choice, value["value"])) + obj = self.bind_value(choice, value["value"]) + substituted = value.get("substituted", False) + + return DerivedElement(qname=qname, value=obj, substituted=substituted) return self.bind_dataclass(value, AnyElement) diff --git a/xsdata/formats/dataclass/parsers/nodes.py b/xsdata/formats/dataclass/parsers/nodes.py index 3ae599220..0a51fdb3c 100644 --- a/xsdata/formats/dataclass/parsers/nodes.py +++ b/xsdata/formats/dataclass/parsers/nodes.py @@ -27,6 +27,7 @@ from xsdata.formats.dataclass.parsers.mixins import XmlHandler from xsdata.formats.dataclass.parsers.mixins import XmlNode from xsdata.formats.dataclass.parsers.utils import ParserUtils +from xsdata.models.enums import DataType from xsdata.models.enums import EventType Parsed = Tuple[Optional[str], Any] @@ -57,6 +58,7 @@ class ElementNode(XmlNode): position: int mixed: bool = False derived: bool = False + substituted: bool = False assigned: Set = field(default_factory=set) def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: @@ -79,7 +81,7 @@ def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: obj = self.meta.clazz(**params) if self.derived: - obj = DerivedElement(qname=qname, value=obj) + obj = DerivedElement(qname=qname, value=obj, substituted=self.substituted) objects.append((qname, obj)) @@ -118,27 +120,40 @@ def build_node( position=position, ) + xsi_type = ParserUtils.xsi_type(attrs, ns_map) + if var.clazz: return self.build_element_node( - var.clazz, attrs, ns_map, position, var.derived + var.clazz, + attrs, + ns_map, + position, + var.derived, + xsi_type, ) - if var.any_type: - node = self.build_element_node(None, attrs, ns_map, position, var.derived) - if not node: - node = AnyTypeNode( - var=var, - attrs=attrs, - ns_map=ns_map, - position=position, - mixed=self.meta.has_var(mode=FindMode.MIXED_CONTENT), - ) - return node + if not var.any_type and not var.wildcard: + return PrimitiveNode.from_var(var, ns_map) - if var.wildcard: - return WildcardNode(var=var, attrs=attrs, ns_map=ns_map, position=position) + datatype = DataType.from_qname(xsi_type) if xsi_type else None + derived = var.derived or var.wildcard + if datatype: + return PrimitiveNode.from_datatype(datatype, derived, ns_map) - return PrimitiveNode(var=var, ns_map=ns_map) + node = None + clazz = None + if xsi_type: + clazz = self.context.find_type(xsi_type) + + if clazz: + node = self.build_element_node( + clazz, attrs, ns_map, position, derived, xsi_type + ) + + if node: + return node + + return WildcardNode(var=var, attrs=attrs, ns_map=ns_map, position=position) def fetch_vars(self, qname: str) -> Iterator[Tuple[Any, XmlVar]]: for mode in FIND_MODES: @@ -158,28 +173,17 @@ def fetch_vars(self, qname: str) -> Iterator[Tuple[Any, XmlVar]]: def build_element_node( self, - clazz: Optional[Type], + clazz: Type, attrs: Dict, ns_map: Dict, position: int, derived: bool, + xsi_type: Optional[str] = None, ) -> Optional[XmlNode]: - xsi_type = ParserUtils.xsi_type(attrs, ns_map) - - if clazz is None: - if not xsi_type: - return None - - clazz = self.context.find_type(xsi_type) - xsi_type = None - - if clazz is None: - return None - is_nillable = ParserUtils.is_nillable(attrs) meta = self.context.fetch(clazz, self.meta.namespace, xsi_type) - if not is_nillable and meta.nillable: + if not meta or (meta.nillable and not ParserUtils.is_nillable(attrs)): return None return ElementNode( @@ -190,76 +194,11 @@ def build_element_node( context=self.context, position=position, derived=derived, + substituted=xsi_type is not None, mixed=self.meta.has_var(mode=FindMode.MIXED_CONTENT), ) -@dataclass -class AnyTypeNode(XmlNode): - """ - XmlNode for elements with an inline datatype declaration through the - xsi:type attribute. - - :param var: Class field xml var instance - :param attrs: Key-value attribute mapping - :param ns_map: Namespace prefix-URI map - :param position: The node position of objects cache - :param mixed: Specify if the parent node supports mixed content - :ivar has_children: Specifies whether the node has encounter any - children so far - """ - - var: XmlVar - attrs: Dict - ns_map: Dict - position: int - mixed: bool = False - has_children: bool = field(init=False, default=False) - - def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> "XmlNode": - self.has_children = True - return WildcardNode(position=position, var=self.var, attrs=attrs, ns_map=ns_map) - - def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: - obj: Any = None - if self.has_children: - obj = AnyElement( - qname=qname, - text=ParserUtils.normalize_content(text), - tail=ParserUtils.normalize_content(tail), - attributes=ParserUtils.parse_any_attributes(self.attrs, self.ns_map), - children=ParserUtils.fetch_any_children(self.position, objects), - ) - objects.append((self.var.qname, obj)) - else: - var = self.var - ns_map = self.ns_map - datatype = ParserUtils.data_type(self.attrs, self.ns_map) - obj = ParserUtils.parse_value( - text, - [datatype.type], - var.default, - ns_map, - var.tokens, - datatype.format, - ) - - if datatype.wrapper: - obj = datatype.wrapper(obj) - - if var.derived: - obj = DerivedElement(qname=qname, value=obj) - - objects.append((qname, obj)) - - if self.mixed: - tail = ParserUtils.normalize_content(tail) - if tail: - objects.append((None, tail)) - - return True - - @dataclass class WildcardNode(XmlNode): """ @@ -281,14 +220,23 @@ class WildcardNode(XmlNode): position: int def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: - obj = AnyElement( - qname=qname, - text=ParserUtils.normalize_content(text), - tail=ParserUtils.normalize_content(tail), - attributes=ParserUtils.parse_any_attributes(self.attrs, self.ns_map), - children=ParserUtils.fetch_any_children(self.position, objects), - ) - objects.append((self.var.qname, obj)) + children = ParserUtils.fetch_any_children(self.position, objects) + attributes = ParserUtils.parse_any_attributes(self.attrs, self.ns_map) + derived = self.var.derived or qname != self.var.qname + text = ParserUtils.normalize_content(text) if children else text + tail = ParserUtils.normalize_content(tail) + + if tail or attributes or children or self.var.wildcard or derived: + obj = AnyElement( + qname=qname, + text=text, + tail=tail, + attributes=attributes, + children=children, + ) + objects.append((self.var.qname, obj)) + else: + objects.append((self.var.qname, text)) return True @@ -377,17 +325,28 @@ class PrimitiveNode(XmlNode): :param ns_map: Namespace prefix-URI map """ - var: XmlVar + types: List[Type] + default: Any + tokens: bool + format: Optional[str] + derived: bool + wrapper: Optional[Type] ns_map: Dict def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: - var = self.var - ns_map = self.ns_map obj = ParserUtils.parse_value( - text, var.types, var.default, ns_map, var.tokens, var.format + text, + self.types, + self.default, + self.ns_map, + self.tokens, + self.format, ) - if var.derived: + if self.wrapper: + obj = self.wrapper(obj) + + if self.derived: obj = DerivedElement(qname=qname, value=obj) objects.append((qname, obj)) @@ -396,6 +355,32 @@ def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: raise XmlContextError("Primitive node doesn't support child nodes!") + @classmethod + def from_var(cls, var: XmlVar, ns_map: Dict) -> "PrimitiveNode": + return cls( + types=var.types, + default=var.default, + tokens=var.tokens, + format=var.format, + derived=var.derived, + wrapper=None, + ns_map=ns_map, + ) + + @classmethod + def from_datatype( + cls, datatype: DataType, derived: bool, ns_map: Dict + ) -> "PrimitiveNode": + return cls( + types=[datatype.type], + default=None, + tokens=False, + format=datatype.format, + derived=derived, + wrapper=datatype.wrapper, + ns_map=ns_map, + ) + @dataclass class SkipNode(XmlNode): diff --git a/xsdata/formats/dataclass/serializers/xml.py b/xsdata/formats/dataclass/serializers/xml.py index 3944cfbc3..4bd3c0de5 100644 --- a/xsdata/formats/dataclass/serializers/xml.py +++ b/xsdata/formats/dataclass/serializers/xml.py @@ -188,7 +188,7 @@ def write_any_type(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generat if isinstance(value, AnyElement): yield from self.write_wildcard(value, var, namespace) elif isinstance(value, DerivedElement): - yield from self.write_dataclass(value.value, namespace, qname=value.qname) + yield from self.write_derived_element(value, var, namespace) elif is_dataclass(value): yield from self.write_xsi_type(value, var, namespace) elif var.element: @@ -196,6 +196,26 @@ def write_any_type(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generat else: yield from self.write_data(value, var, namespace) + def write_derived_element( + self, value: DerivedElement, var: XmlVar, namespace: NoneStr + ) -> Generator: + if is_dataclass(value.value): + xsi_type = None + if value.substituted: + meta = self.context.build(value.value.__class__) + xsi_type = QName(meta.source_qname) + + yield from self.write_dataclass( + value.value, namespace, qname=value.qname, xsi_type=xsi_type + ) + else: + datatype = DataType.from_value(value.value) + + yield XmlWriterEvent.START, value.qname + yield XmlWriterEvent.ATTR, QNames.XSI_TYPE, QName(str(datatype)) + yield XmlWriterEvent.DATA, value.value + yield XmlWriterEvent.END, value.qname + def write_wildcard( self, value: AnyElement, var: XmlVar, namespace: NoneStr ) -> Generator: