From b71f0a4598035672a1b72a63330c839d3689359d Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 19 Jun 2021 10:43:16 +0300 Subject: [PATCH 1/7] AttributeDefaultValueHandler: process attr choices --- .../handlers/test_attribute_default_value.py | 29 +++++++++++++++++++ .../handlers/attribute_default_value.py | 3 ++ 2 files changed, 32 insertions(+) diff --git a/tests/codegen/handlers/test_attribute_default_value.py b/tests/codegen/handlers/test_attribute_default_value.py index 32092724e..1fcff4afb 100644 --- a/tests/codegen/handlers/test_attribute_default_value.py +++ b/tests/codegen/handlers/test_attribute_default_value.py @@ -3,6 +3,7 @@ from xsdata.codegen.container import ClassContainer from xsdata.codegen.handlers import AttributeDefaultValueHandler from xsdata.models.config import GeneratorConfig +from xsdata.models.enums import DataType from xsdata.models.enums import Namespace from xsdata.utils.testing import AttrFactory from xsdata.utils.testing import AttrTypeFactory @@ -26,6 +27,34 @@ def test_process_attribute_with_enumeration(self): self.processor.process_attribute(target, attr) self.assertTrue(attr.fixed) + @mock.patch.object(AttributeDefaultValueHandler, "process_attribute") + def test_process_with_attr_choices(self, mock_process_attribute): + choice = AttrFactory.create( + name="attr_B_Or_attr_C", + tag="Choice", + index=0, + types=[AttrTypeFactory.native(DataType.ANY_TYPE)], + choices=[ + AttrFactory.reference("one"), + AttrFactory.reference("two"), + AttrFactory.reference("three"), + ], + ) + target = ClassFactory.create() + target.attrs.append(choice) + + self.processor.process(target) + + self.assertEqual(4, mock_process_attribute.call_count) + mock_process_attribute.assert_has_calls( + [ + mock.call(target, target.attrs[0]), + mock.call(target, target.attrs[0].choices[0]), + mock.call(target, target.attrs[0].choices[1]), + mock.call(target, target.attrs[0].choices[2]), + ] + ) + def test_process_attribute_with_optional_field(self): target = ClassFactory.create() attr = AttrFactory.create(fixed=True, default=2) diff --git a/xsdata/codegen/handlers/attribute_default_value.py b/xsdata/codegen/handlers/attribute_default_value.py index d757d79fa..19053297f 100644 --- a/xsdata/codegen/handlers/attribute_default_value.py +++ b/xsdata/codegen/handlers/attribute_default_value.py @@ -25,6 +25,9 @@ def process(self, target: Class): for attr in target.attrs: self.process_attribute(target, attr) + for choice in attr.choices: + self.process_attribute(target, choice) + def process_attribute(self, target: Class, attr: Attr): if attr.is_enumeration: From 1a2b62a28d52319e34a6d56a9a7c7265e14dbccd Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 19 Jun 2021 10:45:24 +0300 Subject: [PATCH 2/7] ClassInnersHandler: evaluate attr choices as well --- tests/codegen/handlers/test_class_inners.py | 37 +++++++++++++++++++++ xsdata/codegen/handlers/class_inners.py | 5 +++ 2 files changed, 42 insertions(+) diff --git a/tests/codegen/handlers/test_class_inners.py b/tests/codegen/handlers/test_class_inners.py index bed0533a3..983f237c9 100644 --- a/tests/codegen/handlers/test_class_inners.py +++ b/tests/codegen/handlers/test_class_inners.py @@ -1,6 +1,9 @@ +from typing import Generator + from xsdata.codegen.handlers import ClassInnersHandler from xsdata.models.enums import DataType from xsdata.utils.testing import AttrFactory +from xsdata.utils.testing import AttrTypeFactory from xsdata.utils.testing import ClassFactory from xsdata.utils.testing import ExtensionFactory from xsdata.utils.testing import FactoryTestCase @@ -71,3 +74,37 @@ def test_rename_inner(self): self.assertEqual("{xsdata}foo_Inner", outer.attrs[0].types[0].qname) self.assertEqual("{xsdata}foo_Inner", outer.inner[0].qname) + + def test_find_attr_types_with_attr_choices(self): + choices = [ + AttrFactory.create( + types=[ + AttrTypeFactory.create("bar", forward=True), + AttrTypeFactory.create("foo", forward=True), + ] + ), + AttrFactory.reference("foo"), + AttrFactory.reference("foo", forward=True), + AttrFactory.reference("bar", forward=True), + ] + choice = AttrFactory.create( + name="attr_B_Or_attr_C", + tag="Choice", + index=0, + types=[AttrTypeFactory.native(DataType.ANY_TYPE)], + choices=choices, + ) + target = ClassFactory.create() + target.attrs.append(choice) + + result = self.processor.find_attr_types(target, "foo") + self.assertIsInstance(result, Generator) + + self.assertEqual(choices[0].types[1], next(result)) + self.assertEqual(choices[2].types[0], next(result)) + self.assertIsNone(next(result, None)) + + result = self.processor.find_attr_types(target, "bar") + self.assertEqual(choices[0].types[0], next(result)) + self.assertEqual(choices[3].types[0], next(result)) + self.assertIsNone(next(result, None)) diff --git a/xsdata/codegen/handlers/class_inners.py b/xsdata/codegen/handlers/class_inners.py index 20feeffff..5f17f86cd 100644 --- a/xsdata/codegen/handlers/class_inners.py +++ b/xsdata/codegen/handlers/class_inners.py @@ -64,3 +64,8 @@ def find_attr_types(cls, target: Class, qname: str) -> Iterator[AttrType]: for attr_type in attr.types: if attr_type.forward and attr_type.qname == qname: yield attr_type + + for choice in attr.choices: + for choice_type in choice.types: + if choice_type.forward and choice_type.qname == qname: + yield choice_type From 498e3f31f0ce6b85cfbf6fa5fdec5e9e10069370 Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 19 Jun 2021 00:31:22 +0300 Subject: [PATCH 3/7] XmlMetaBuilder: locate globals per field Notes: Otherwise the forward references for compound fields are failing to evaluate correctly in case of subclasses in different modules! --- tests/fixtures/models.py | 11 +- tests/fixtures/submodels.py | 14 ++ .../formats/dataclass/models/test_builders.py | 173 +++++++++--------- tests/formats/dataclass/test_context.py | 10 +- tests/formats/dataclass/test_elements.py | 4 +- xsdata/formats/dataclass/models/builders.py | 24 ++- 6 files changed, 131 insertions(+), 105 deletions(-) create mode 100644 tests/fixtures/submodels.py diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 41d020183..b11355bad 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -3,9 +3,12 @@ from typing import Dict from typing import List from typing import Optional +from typing import Type from typing import Union from xml.etree.ElementTree import QName +from xsdata.utils.constants import return_true + @dataclass class TypeA: @@ -79,7 +82,13 @@ class ChoiceType: {"name": "int2", "type": int, "nillable": True}, {"name": "float", "type": float}, {"name": "qname", "type": QName}, - {"name": "tokens", "type": List[int], "tokens": True}, + {"name": "tokens", "type": List[int], "tokens": True, "default_factory": return_true}, + {"name": "union", "type": Type["UnionType"], "namespace": "foo"}, + {"name": "p", "type": float, "fixed": True, "default": 1.1}, + {"wildcard": True, + "type": object, + "namespace": "http://www.w3.org/1999/xhtml", + }, ), } ) diff --git a/tests/fixtures/submodels.py b/tests/fixtures/submodels.py new file mode 100644 index 000000000..568af28fb --- /dev/null +++ b/tests/fixtures/submodels.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import List +from typing import Optional +from typing import Union +from xml.etree.ElementTree import QName + +from tests.fixtures.models import ChoiceType + + +@dataclass +class ChoiceTypeChild(ChoiceType): + pass \ No newline at end of file diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index 6af09f958..58885c189 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -8,15 +8,19 @@ from typing import get_type_hints from typing import Iterator from typing import List -from typing import Type from typing import Union from unittest import mock from unittest import TestCase +from xml.etree.ElementTree import QName from tests.fixtures.artists import Artist from tests.fixtures.books import BookForm +from tests.fixtures.models import ChoiceType +from tests.fixtures.models import TypeA from tests.fixtures.models import TypeB +from tests.fixtures.models import UnionType from tests.fixtures.series import Country +from tests.fixtures.submodels import ChoiceTypeChild from xsdata.exceptions import XmlContextError from xsdata.formats.dataclass.compat import class_types from xsdata.formats.dataclass.models.builders import XmlMetaBuilder @@ -103,6 +107,12 @@ def test_build_with_no_dataclass_raises_exception(self, *args): self.assertEqual(f"Type '{int}' is not a dataclass.", str(cm.exception)) + def test_build_locates_globalns_per_field(self): + actual = self.builder.build(ChoiceTypeChild, None) + self.assertEqual(1, len(actual.choices)) + self.assertEqual(9, len(actual.choices[0].elements)) + self.assertIsNone(self.builder.find_globalns(object, "foo")) + def test_target_namespace(self): class Meta: namespace = "bar" @@ -234,17 +244,18 @@ def setUp(self) -> None: ) super().setUp() + self.maxDiff = None def test_build_with_choice_field(self): - globalns = sys.modules[CompoundFieldExample.__module__].__dict__ - type_hints = get_type_hints(CompoundFieldExample) - class_field = fields(CompoundFieldExample)[0] + globalns = sys.modules[ChoiceType.__module__].__dict__ + type_hints = get_type_hints(ChoiceType) + class_field = fields(ChoiceType)[0] self.builder.parent_ns = "bar" actual = self.builder.build( 66, - "compound", - type_hints["compound"], + "choice", + type_hints["choice"], class_field.metadata, True, list, @@ -252,96 +263,110 @@ def test_build_with_choice_field(self): ) expected = XmlVarFactory.create( index=67, - xml_type=XmlType.ELEMENTS, - name="compound", - qname="compound", + name="choice", + types=(object,), list_element=True, any_type=True, default=list, + xml_type=XmlType.ELEMENTS, elements={ - "{foo}node": XmlVarFactory.create( + "{bar}a": XmlVarFactory.create( index=1, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{foo}node", + name="choice", + qname="{bar}a", + types=(TypeA,), + clazz=TypeA, list_element=True, - types=(CompoundFieldExample,), - namespaces=("foo",), - derived=False, + namespaces=("bar",), ), - "{bar}x": XmlVarFactory.create( + "{bar}b": XmlVarFactory.create( index=2, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}x", - tokens=True, + name="choice", + qname="{bar}b", + types=(TypeB,), + clazz=TypeB, list_element=True, - types=(str,), namespaces=("bar",), - derived=False, - default=return_true, - format="Nope", ), - "{bar}y": XmlVarFactory.create( + "{bar}int": XmlVarFactory.create( index=3, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}y", - nillable=True, - list_element=True, + name="choice", + qname="{bar}int", types=(int,), + list_element=True, namespaces=("bar",), - derived=False, ), - "{bar}z": XmlVarFactory.create( + "{bar}int2": XmlVarFactory.create( index=4, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}z", - nillable=False, - list_element=True, + name="choice", + qname="{bar}int2", types=(int,), - namespaces=("bar",), derived=True, + nillable=True, + list_element=True, + namespaces=("bar",), ), - "{bar}o": XmlVarFactory.create( + "{bar}float": XmlVarFactory.create( index=5, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}o", - nillable=False, + name="choice", + qname="{bar}float", + types=(float,), list_element=True, - types=(object,), namespaces=("bar",), + ), + "{bar}qname": XmlVarFactory.create( + index=6, + name="choice", + qname="{bar}qname", + types=(QName,), + list_element=True, + namespaces=("bar",), + ), + "{bar}tokens": XmlVarFactory.create( + index=7, + name="choice", + qname="{bar}tokens", + types=(int,), + tokens=True, derived=True, - any_type=True, + list_element=True, + default=return_true, + namespaces=("bar",), + ), + "{foo}union": XmlVarFactory.create( + index=8, + name="choice", + qname="{foo}union", + types=(UnionType,), + clazz=UnionType, + list_element=True, + namespaces=("foo",), ), "{bar}p": XmlVarFactory.create( - index=6, - xml_type=XmlType.ELEMENT, - name="compound", + index=9, + name="choice", qname="{bar}p", types=(float,), + derived=True, list_element=True, - namespaces=("bar",), default=1.1, + namespaces=("bar",), ), }, wildcards=[ XmlVarFactory.create( - index=7, + index=10, + name="choice", xml_type=XmlType.WILDCARD, - name="compound", qname="{http://www.w3.org/1999/xhtml}any", types=(object,), - namespaces=("http://www.w3.org/1999/xhtml",), - derived=True, - any_type=False, list_element=True, - ) + default=None, + namespaces=("http://www.w3.org/1999/xhtml",), + ), ], - types=(object,), ) + self.assertEqual(expected, actual) def test_build_validates_result(self): @@ -455,37 +480,3 @@ def test_is_valid(self): XmlType.TEXT, None, None, (int, uuid.UUID), False, False ) ) - - -@dataclass -class CompoundFieldExample: - - compound: List[object] = field( - default_factory=list, - metadata={ - "type": "Elements", - "choices": ( - { - "name": "node", - "type": Type["CompoundFieldExample"], - "namespace": "foo", - }, - { - "name": "x", - "type": List[str], - "tokens": True, - "default_factory": return_true, - "format": "Nope", - }, - {"name": "y", "type": List[int], "nillable": True}, - {"name": "z", "type": List[int]}, - {"name": "o", "type": object}, - {"name": "p", "type": float, "fixed": True, "default": 1.1}, - { - "wildcard": True, - "type": object, - "namespace": "http://www.w3.org/1999/xhtml", - }, - ), - }, - ) diff --git a/tests/formats/dataclass/test_context.py b/tests/formats/dataclass/test_context.py index 460ef0805..5b8ff04f2 100644 --- a/tests/formats/dataclass/test_context.py +++ b/tests/formats/dataclass/test_context.py @@ -6,8 +6,10 @@ from tests.fixtures.artists import BeginArea from tests.fixtures.books import BookForm from tests.fixtures.books import BooksForm +from tests.fixtures.models import BaseType from tests.fixtures.models import ChoiceType from tests.fixtures.models import TypeA +from tests.fixtures.models import TypeC from tests.fixtures.models import UnionType from xsdata.formats.dataclass.context import XmlContext from xsdata.models.enums import DataType @@ -107,10 +109,10 @@ def test_is_derived(self): def test_build_recursive(self): self.ctx.build_recursive(ChoiceType) - self.assertEqual(3, len(self.ctx.cache)) + self.assertEqual(6, len(self.ctx.cache)) - self.ctx.build_recursive(TypeA) - self.assertEqual(3, len(self.ctx.cache)) + self.ctx.build_recursive(BaseType) + self.assertEqual(8, len(self.ctx.cache)) self.ctx.build_recursive(UnionType) - self.assertEqual(6, len(self.ctx.cache)) + self.assertEqual(8, len(self.ctx.cache)) diff --git a/tests/formats/dataclass/test_elements.py b/tests/formats/dataclass/test_elements.py index 040d50c31..063cf22bf 100644 --- a/tests/formats/dataclass/test_elements.py +++ b/tests/formats/dataclass/test_elements.py @@ -43,7 +43,9 @@ def test_property_is_clazz_union(self): def test_property_element_types(self): meta = self.context.build(ChoiceType) var = meta.choices[0] - self.assertEqual({TypeA, TypeB, int, float, QName}, var.element_types) + self.assertEqual( + {TypeA, TypeB, int, float, QName, UnionType}, var.element_types + ) def test_find_choice(self): var = XmlVarFactory.create( diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index 1b44b476c..4c14d41af 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -114,7 +114,6 @@ def build_vars( ): """Build the binding metadata for the given dataclass fields.""" type_hints = get_type_hints(clazz) - globalns = sys.modules[clazz.__module__].__dict__ builder = XmlVarBuilder( class_type=self.class_type, parent_ns=parent_ns, @@ -123,19 +122,28 @@ def build_vars( attribute_name_generator=attribute_name_generator, ) - for index, _field in enumerate(self.class_type.get_fields(clazz)): + for index, field in enumerate(self.class_type.get_fields(clazz)): var = builder.build( index, - _field.name, - type_hints[_field.name], - _field.metadata, - _field.init, - self.class_type.default_value(_field), - globalns, + field.name, + type_hints[field.name], + field.metadata, + field.init, + self.class_type.default_value(field), + self.find_globalns(clazz, field.name), ) if var is not None: yield var + @classmethod + def find_globalns(cls, clazz: Type, name: str) -> Optional[Dict]: + for base in clazz.__mro__: + ann = base.__dict__.get("__annotations__") + if ann and name in ann: + return sys.modules[base.__module__].__dict__ + + return None + @classmethod def build_target_qname(cls, clazz: Type, element_name_generator: Callable) -> str: """Build the source qualified name of a model based on the module From fb8ae74e69ac4b9821da333192b921528afa38f5 Mon Sep 17 00:00:00 2001 From: Chris Date: Sun, 20 Jun 2021 19:28:18 +0300 Subject: [PATCH 4/7] Compound fields rename override conflicts --- .../test_attribute_compound_choice.py | 40 +++++++++--------- xsdata/codegen/container.py | 2 +- .../handlers/attribute_compound_choice.py | 28 +++++++++---- .../handlers/attribute_name_conflict.py | 5 ++- .../codegen/handlers/attribute_overrides.py | 41 ++++++++----------- xsdata/codegen/mixins.py | 13 ++++++ xsdata/codegen/models.py | 5 +++ xsdata/codegen/utils.py | 19 ++++++--- 8 files changed, 92 insertions(+), 61 deletions(-) diff --git a/tests/codegen/handlers/test_attribute_compound_choice.py b/tests/codegen/handlers/test_attribute_compound_choice.py index 80fff45bb..10a5335d4 100644 --- a/tests/codegen/handlers/test_attribute_compound_choice.py +++ b/tests/codegen/handlers/test_attribute_compound_choice.py @@ -1,11 +1,14 @@ from unittest import mock +from xsdata.codegen.container import ClassContainer from xsdata.codegen.handlers import AttributeCompoundChoiceHandler from xsdata.codegen.models import Restrictions +from xsdata.models.config import GeneratorConfig from xsdata.models.enums import DataType from xsdata.utils.testing import AttrFactory from xsdata.utils.testing import AttrTypeFactory from xsdata.utils.testing import ClassFactory +from xsdata.utils.testing import ExtensionFactory from xsdata.utils.testing import FactoryTestCase @@ -13,7 +16,8 @@ class AttributeCompoundChoiceHandlerTests(FactoryTestCase): def setUp(self): super().setUp() - self.processor = AttributeCompoundChoiceHandler() + self.container = ClassContainer(config=GeneratorConfig()) + self.processor = AttributeCompoundChoiceHandler(container=self.container) @mock.patch.object(AttributeCompoundChoiceHandler, "group_fields") def test_process(self, mock_group_fields): @@ -87,29 +91,27 @@ def test_group_fields_with_effective_choices_sums_occurs(self): self.assertEqual(1, len(target.attrs)) self.assertEqual(expected_res, target.attrs[0].restrictions) - def test_group_fields_limit_name(self): - target = ClassFactory.create(attrs=AttrFactory.list(3)) - for attr in target.attrs: - attr.restrictions.choice = "1" + def test_choose_name(self): + target = ClassFactory.create() - self.processor.group_fields(target, list(target.attrs)) + actual = self.processor.choose_name(target, ["a", "b", "c"]) + self.assertEqual("a_Or_b_Or_c", actual) - self.assertEqual(1, len(target.attrs)) - self.assertEqual("attr_B_Or_attr_C_Or_attr_D", target.attrs[0].name) + actual = self.processor.choose_name(target, ["a", "b", "c", "d"]) + self.assertEqual("choice", actual) - target = ClassFactory.create(attrs=AttrFactory.list(4)) - for attr in target.attrs: - attr.restrictions.choice = "1" + target.attrs.append(AttrFactory.create(name="choice")) + actual = self.processor.choose_name(target, ["a", "b", "c", "d"]) + self.assertEqual("choice_1", actual) - self.processor.group_fields(target, list(target.attrs)) - self.assertEqual("choice", target.attrs[0].name) + base = ClassFactory.create() + base.attrs.append(AttrFactory.create(name="Choice!")) + target.extensions.append(ExtensionFactory.reference(base.qname)) + self.container.extend((target, base)) - target = ClassFactory.create() - attr = AttrFactory.element(restrictions=Restrictions(choice="1")) - target.attrs.append(attr) - target.attrs.append(attr.clone()) - self.processor.group_fields(target, list(target.attrs)) - self.assertEqual("choice", target.attrs[0].name) + target.attrs.clear() + actual = self.processor.choose_name(target, ["a", "b", "c", "d"]) + self.assertEqual("choice_1", actual) def test_build_attr_choice(self): attr = AttrFactory.create( diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index 5c417bb20..854fea4e8 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -70,7 +70,7 @@ def __init__(self, config: GeneratorConfig): ] if self.config.output.compound_fields: - self.post_processors.insert(0, AttributeCompoundChoiceHandler()) + self.post_processors.insert(0, AttributeCompoundChoiceHandler(self)) def __iter__(self) -> Iterator[Class]: """Create an iterator for the class map values.""" diff --git a/xsdata/codegen/handlers/attribute_compound_choice.py b/xsdata/codegen/handlers/attribute_compound_choice.py index b0cf3023b..1a99bde64 100644 --- a/xsdata/codegen/handlers/attribute_compound_choice.py +++ b/xsdata/codegen/handlers/attribute_compound_choice.py @@ -1,17 +1,18 @@ from operator import attrgetter from typing import List -from xsdata.codegen.mixins import HandlerInterface +from xsdata.codegen.mixins import RelativeHandlerInterface from xsdata.codegen.models import Attr from xsdata.codegen.models import AttrType from xsdata.codegen.models import Class from xsdata.codegen.models import Restrictions +from xsdata.codegen.utils import ClassUtils from xsdata.models.enums import DataType from xsdata.models.enums import Tag from xsdata.utils.collections import group_by -class AttributeCompoundChoiceHandler(HandlerInterface): +class AttributeCompoundChoiceHandler(RelativeHandlerInterface): """Group attributes that belong in the same choice and replace them by compound fields.""" @@ -23,13 +24,13 @@ def process(self, target: Class): if choice and len(attrs) > 1 and any(attr.is_list for attr in attrs): self.group_fields(target, attrs) - @classmethod - def group_fields(cls, target: Class, attrs: List[Attr]): + def group_fields(self, target: Class, attrs: List[Attr]): """Group attributes into a new compound field.""" pos = target.attrs.index(attrs[0]) choice = attrs[0].restrictions.choice sum_occurs = choice and choice.startswith("effective_") + names = [] choices = [] min_occurs = [] @@ -39,12 +40,9 @@ def group_fields(cls, target: Class, attrs: List[Attr]): names.append(attr.local_name) min_occurs.append(attr.restrictions.min_occurs or 0) max_occurs.append(attr.restrictions.max_occurs or 0) - choices.append(cls.build_attr_choice(attr)) + choices.append(self.build_attr_choice(attr)) - if len(names) > 3 or len(names) != len(set(names)): - name = "choice" - else: - name = "_Or_".join(names) + name = self.choose_name(target, names) target.attrs.insert( pos, @@ -61,6 +59,18 @@ def group_fields(cls, target: Class, attrs: List[Attr]): ), ) + def choose_name(self, target: Class, names: List[str]) -> str: + slug_getter = attrgetter("slug") + reserved = set(map(slug_getter, self.base_attrs(target))) + reserved.update(map(slug_getter, target.attrs)) + + if len(names) > 3 or len(names) != len(set(names)): + name = "choice" + else: + name = "_Or_".join(names) + + return ClassUtils.unique_name(name, reserved) + @classmethod def build_attr_choice(cls, attr: Attr) -> Attr: """ diff --git a/xsdata/codegen/handlers/attribute_name_conflict.py b/xsdata/codegen/handlers/attribute_name_conflict.py index 6d9ce4823..018b6b1fc 100644 --- a/xsdata/codegen/handlers/attribute_name_conflict.py +++ b/xsdata/codegen/handlers/attribute_name_conflict.py @@ -1,7 +1,8 @@ +from operator import attrgetter + from xsdata.codegen.mixins import HandlerInterface from xsdata.codegen.models import Class from xsdata.codegen.utils import ClassUtils -from xsdata.utils import text from xsdata.utils.collections import group_by @@ -13,7 +14,7 @@ class AttributeNameConflictHandler(HandlerInterface): def process(self, target: Class): """Sanitize duplicate attribute names that might exist by applying rename strategies.""" - grouped = group_by(target.attrs, lambda attr: text.alnum(attr.name)) + grouped = group_by(target.attrs, key=attrgetter("slug")) for items in grouped.values(): total = len(items) if total == 2 and not items[0].is_enumeration: diff --git a/xsdata/codegen/handlers/attribute_overrides.py b/xsdata/codegen/handlers/attribute_overrides.py index 30a612438..521129a29 100644 --- a/xsdata/codegen/handlers/attribute_overrides.py +++ b/xsdata/codegen/handlers/attribute_overrides.py @@ -1,12 +1,13 @@ import sys +from operator import attrgetter +from typing import Dict +from typing import List from xsdata.codegen.mixins import RelativeHandlerInterface from xsdata.codegen.models import Attr from xsdata.codegen.models import Class -from xsdata.codegen.models import Extension from xsdata.codegen.utils import ClassUtils from xsdata.utils import collections -from xsdata.utils.text import alnum class AttributeOverridesHandler(RelativeHandlerInterface): @@ -22,31 +23,21 @@ class AttributeOverridesHandler(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): - for extension in target.extensions: - self.process_extension(target, extension) - - def process_extension(self, target: Class, extension: Extension): - source = self.container.find(extension.type.qname) - assert source is not None + base_attrs_map = self.base_attrs_map(target) for attr in list(target.attrs): - search = alnum(attr.name) - source_attr = collections.first( - source_attr - for source_attr in source.attrs - if alnum(source_attr.name) == search - ) - - if not source_attr: - continue - - if attr.tag == source_attr.tag: - self.validate_override(target, attr, source_attr) - else: - self.resolve_conflict(attr, source_attr) - - for extension in source.extensions: - self.process_extension(target, extension) + base_attrs = base_attrs_map.get(attr.slug) + + if base_attrs: + base_attr = base_attrs[0] + if attr.tag == base_attr.tag: + self.validate_override(target, attr, base_attr) + else: + self.resolve_conflict(attr, base_attr) + + def base_attrs_map(self, target: Class) -> Dict[str, List[Attr]]: + base_attrs = self.base_attrs(target) + return collections.group_by(base_attrs, key=attrgetter("slug")) @classmethod def validate_override(cls, target: Class, attr: Attr, source_attr: Attr): diff --git a/xsdata/codegen/mixins.py b/xsdata/codegen/mixins.py index 541cd66e8..79073ff56 100644 --- a/xsdata/codegen/mixins.py +++ b/xsdata/codegen/mixins.py @@ -5,6 +5,7 @@ from typing import List from typing import Optional +from xsdata.codegen.models import Attr from xsdata.codegen.models import Class from xsdata.models.config import GeneratorConfig from xsdata.utils.constants import return_true @@ -64,6 +65,18 @@ class RelativeHandlerInterface(HandlerInterface, metaclass=ABCMeta): def __init__(self, container: ContainerInterface): self.container = container + def base_attrs(self, target: Class) -> List[Attr]: + attrs: List[Attr] = [] + for extension in target.extensions: + base = self.container.find(extension.type.qname) + + assert base is not None + + attrs.extend(base.attrs) + attrs.extend(self.base_attrs(base)) + + return attrs + class ContainerHandlerInterface(abc.ABC): """Class container.""" diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 844790d92..162f8e671 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -19,6 +19,7 @@ from xsdata.models.enums import Tag from xsdata.models.mixins import ElementBase from xsdata.utils import namespaces +from xsdata.utils import text xml_type_map = { Tag.ANY: XmlType.WILDCARD, @@ -343,6 +344,10 @@ def native_types(self) -> List[Type]: return list(result) + @property + def slug(self) -> str: + return text.alnum(self.name) + @property def xml_type(self) -> Optional[str]: """Return the xml node type this attribute is mapped to.""" diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index 7c112094b..a6280171d 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -1,7 +1,9 @@ import sys +from operator import attrgetter from typing import Iterator from typing import List from typing import Optional +from typing import Set from xsdata.codegen.models import Attr from xsdata.codegen.models import AttrType @@ -239,12 +241,19 @@ def rename_attribute_by_preference(cls, a: Attr, b: Attr): def rename_attributes_by_index(cls, attrs: List[Attr], rename: List[Attr]): """Append the next available index number to all the rename attributes names.""" + slug_getter = attrgetter("slug") for index in range(1, len(rename)): - num = 1 + reserved = set(map(slug_getter, attrs)) name = rename[index].name + rename[index].name = cls.unique_name(name, reserved) - reserved = {text.alnum(attr.name) for attr in attrs} - while text.alnum(f"{name}_{num}") in reserved: - num += 1 + @classmethod + def unique_name(cls, name: str, reserved: Set[str]) -> str: + if text.alnum(name) in reserved: + index = 1 + while text.alnum(f"{name}_{index}") in reserved: + index += 1 + + return f"{name}_{index}" - rename[index].name = f"{name}_{num}" + return name From 09ff172759a59f3ac7335fc1f7e7791741d018de Mon Sep 17 00:00:00 2001 From: Chris Date: Mon, 21 Jun 2021 19:14:14 +0300 Subject: [PATCH 5/7] Split class analyzer handlers to more steps Notes: Some handlers require the one of the previous ones to run fully for all the classes to be fully effective. --- .../test_attribute_compound_choice.py | 41 +++++- .../codegen/handlers/test_attribute_group.py | 2 +- .../handlers/test_attribute_overrides.py | 6 +- .../handlers/test_attribute_restrictions.py | 37 ------ tests/codegen/handlers/test_attribute_type.py | 6 +- .../codegen/handlers/test_class_extension.py | 2 +- tests/codegen/mappers/test_definitions.py | 4 +- tests/codegen/test_container.py | 102 +++++++-------- xsdata/codegen/container.py | 120 +++++++++--------- .../handlers/attribute_compound_choice.py | 41 +++++- .../handlers/attribute_restrictions.py | 25 +--- xsdata/codegen/handlers/attribute_type.py | 2 +- xsdata/codegen/mappers/definitions.py | 4 +- xsdata/codegen/models.py | 11 +- 14 files changed, 212 insertions(+), 191 deletions(-) diff --git a/tests/codegen/handlers/test_attribute_compound_choice.py b/tests/codegen/handlers/test_attribute_compound_choice.py index 10a5335d4..48fb1f0c7 100644 --- a/tests/codegen/handlers/test_attribute_compound_choice.py +++ b/tests/codegen/handlers/test_attribute_compound_choice.py @@ -2,6 +2,7 @@ from xsdata.codegen.container import ClassContainer from xsdata.codegen.handlers import AttributeCompoundChoiceHandler +from xsdata.codegen.models import Class from xsdata.codegen.models import Restrictions from xsdata.models.config import GeneratorConfig from xsdata.models.enums import DataType @@ -16,7 +17,9 @@ class AttributeCompoundChoiceHandlerTests(FactoryTestCase): def setUp(self): super().setUp() - self.container = ClassContainer(config=GeneratorConfig()) + self.config = GeneratorConfig() + self.config.output.compound_fields = True + self.container = ClassContainer(config=self.config) self.processor = AttributeCompoundChoiceHandler(container=self.container) @mock.patch.object(AttributeCompoundChoiceHandler, "group_fields") @@ -154,3 +157,39 @@ def test_build_attr_choice(self): self.assertEqual(expected_res, actual.restrictions) self.assertEqual(attr.help, actual.help) self.assertFalse(actual.fixed) + + def test_reset_sequential(self): + def len_sequential(target: Class): + return len([attr for attr in target.attrs if attr.restrictions.sequential]) + + restrictions = Restrictions(max_occurs=2, sequential=True) + target = ClassFactory.create( + attrs=[ + AttrFactory.create(restrictions=restrictions.clone()), + AttrFactory.create(restrictions=restrictions.clone()), + ] + ) + + attrs_clone = [attr.clone() for attr in target.attrs] + + self.processor.compound_fields = False + self.processor.reset_sequential(target, 0) + self.assertEqual(2, len_sequential(target)) + + target.attrs[0].restrictions.sequential = False + self.processor.reset_sequential(target, 0) + self.assertEqual(1, len_sequential(target)) + + self.processor.reset_sequential(target, 1) + self.assertEqual(0, len_sequential(target)) + + target.attrs = attrs_clone + target.attrs[1].restrictions.sequential = False + self.processor.reset_sequential(target, 0) + self.assertEqual(0, len_sequential(target)) + + target.attrs[0].restrictions.sequential = True + target.attrs[0].restrictions.max_occurs = 0 + target.attrs[1].restrictions.sequential = True + self.processor.reset_sequential(target, 0) + self.assertEqual(1, len_sequential(target)) diff --git a/tests/codegen/handlers/test_attribute_group.py b/tests/codegen/handlers/test_attribute_group.py index 9711c741a..2efa7c2b9 100644 --- a/tests/codegen/handlers/test_attribute_group.py +++ b/tests/codegen/handlers/test_attribute_group.py @@ -83,7 +83,7 @@ def test_process_attribute_with_circular_reference(self): target = ClassFactory.create(qname="bar", tag=Tag.GROUP) target.attrs.append(group_attr) - target.status = Status.PROCESSING + target.status = Status.FLATTENING self.processor.container.add(target) self.processor.process_attribute(target, group_attr) diff --git a/tests/codegen/handlers/test_attribute_overrides.py b/tests/codegen/handlers/test_attribute_overrides.py index 3847150a6..2aa440075 100644 --- a/tests/codegen/handlers/test_attribute_overrides.py +++ b/tests/codegen/handlers/test_attribute_overrides.py @@ -23,15 +23,15 @@ def setUp(self): @mock.patch.object(AttributeOverridesHandler, "validate_override") def test_process(self, mock_validate_override, mock_resolve_conflict): class_a = ClassFactory.create( - status=Status.PROCESSING, + status=Status.FLATTENING, attrs=[ AttrFactory.create(name="el", tag=Tag.ELEMENT), AttrFactory.create(name="at", tag=Tag.ATTRIBUTE), ], ) - class_b = ClassFactory.elements(2, status=Status.PROCESSED) - class_c = ClassFactory.create(status=Status.PROCESSED) + class_b = ClassFactory.elements(2, status=Status.FLATTENED) + class_c = ClassFactory.create(status=Status.FLATTENED) class_b.extensions.append(ExtensionFactory.reference(class_c.qname)) class_a.extensions.append(ExtensionFactory.reference(class_b.qname)) diff --git a/tests/codegen/handlers/test_attribute_restrictions.py b/tests/codegen/handlers/test_attribute_restrictions.py index ea201de31..b38fc00ea 100644 --- a/tests/codegen/handlers/test_attribute_restrictions.py +++ b/tests/codegen/handlers/test_attribute_restrictions.py @@ -1,9 +1,7 @@ from xsdata.codegen.handlers import AttributeRestrictionsHandler -from xsdata.codegen.models import Class from xsdata.codegen.models import Restrictions from xsdata.models.enums import Tag from xsdata.utils.testing import AttrFactory -from xsdata.utils.testing import ClassFactory from xsdata.utils.testing import FactoryTestCase @@ -75,38 +73,3 @@ def test_reset_occurrences(self): attr.restrictions.nillable = True self.processor.reset_occurrences(attr) self.assertIsNone(attr.restrictions.required) - - def test_reset_sequential(self): - def len_sequential(target: Class): - return len([attr for attr in target.attrs if attr.restrictions.sequential]) - - restrictions = Restrictions(max_occurs=2, sequential=True) - target = ClassFactory.create( - attrs=[ - AttrFactory.create(restrictions=restrictions.clone()), - AttrFactory.create(restrictions=restrictions.clone()), - ] - ) - - attrs_clone = [attr.clone() for attr in target.attrs] - - self.processor.reset_sequential(target, 0) - self.assertEqual(2, len_sequential(target)) - - target.attrs[0].restrictions.sequential = False - self.processor.reset_sequential(target, 0) - self.assertEqual(1, len_sequential(target)) - - self.processor.reset_sequential(target, 1) - self.assertEqual(0, len_sequential(target)) - - target.attrs = attrs_clone - target.attrs[1].restrictions.sequential = False - self.processor.reset_sequential(target, 0) - self.assertEqual(0, len_sequential(target)) - - target.attrs[0].restrictions.sequential = True - target.attrs[0].restrictions.max_occurs = 0 - target.attrs[1].restrictions.sequential = True - self.processor.reset_sequential(target, 0) - self.assertEqual(1, len_sequential(target)) diff --git a/tests/codegen/handlers/test_attribute_type.py b/tests/codegen/handlers/test_attribute_type.py index 836a1f566..e0bf6fef7 100644 --- a/tests/codegen/handlers/test_attribute_type.py +++ b/tests/codegen/handlers/test_attribute_type.py @@ -182,7 +182,7 @@ def test_process_inner_type_with_simple_type( self, mock_copy_attribute_properties, mock_update_restrictions ): attr = AttrFactory.create(types=[AttrTypeFactory.create(qname="{bar}a")]) - inner = ClassFactory.simple_type(qname="{bar}a", status=Status.PROCESSED) + inner = ClassFactory.simple_type(qname="{bar}a", status=Status.FLATTENED) target = ClassFactory.create(inner=[inner]) self.processor.process_inner_type(target, attr, attr.types[0]) @@ -199,7 +199,7 @@ def test_process_inner_type_with_complex_type( self, mock_copy_attribute_properties, mock_update_restrictions ): target = ClassFactory.create() - inner = ClassFactory.elements(2, qname="a", status=Status.PROCESSED) + inner = ClassFactory.elements(2, qname="a", status=Status.FLATTENED) attr = AttrFactory.create(types=[AttrTypeFactory.create(qname="a")]) target.inner.append(inner) @@ -281,7 +281,7 @@ def test_is_circular_dependency(self, mock_dependencies, mock_container_find): source = ClassFactory.create() target = ClassFactory.create() another = ClassFactory.create() - processing = ClassFactory.create(status=Status.PROCESSING) + processing = ClassFactory.create(status=Status.FLATTENING) find_classes = {"a": another, "b": target} diff --git a/tests/codegen/handlers/test_class_extension.py b/tests/codegen/handlers/test_class_extension.py index f50948e21..be49ed2d6 100644 --- a/tests/codegen/handlers/test_class_extension.py +++ b/tests/codegen/handlers/test_class_extension.py @@ -150,7 +150,7 @@ def test_process_enum_extension_with_complex_source(self): AttrFactory.create(tag=Tag.RESTRICTION), ], extensions=ExtensionFactory.list(2), - status=Status.PROCESSED, + status=Status.FLATTENED, ) target = ClassFactory.enumeration(1) target.attrs[0].default = "Yes" diff --git a/tests/codegen/mappers/test_definitions.py b/tests/codegen/mappers/test_definitions.py index f0c9bd788..9d2bdcd41 100644 --- a/tests/codegen/mappers/test_definitions.py +++ b/tests/codegen/mappers/test_definitions.py @@ -170,7 +170,7 @@ def test_map_binding_operation( other = ClassFactory.create() service = ClassFactory.create( qname=build_qname("xsdata", "Calc_Add"), - status=Status.PROCESSED, + status=Status.FLATTENED, tag=Tag.BINDING_OPERATION, location="foo.wsdl", module=None, @@ -696,7 +696,7 @@ def test_build_message_class(self, mock_create_message_attributes): actual = DefinitionsMapper.build_message_class(definitions, port_type_message) expected = Class( qname=build_qname("xsdata", "bar"), - status=Status.PROCESSED, + status=Status.FLATTENED, tag=Tag.ELEMENT, location="foo.wsdl", ns_map=message.ns_map, diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index 4f6878773..4c4a248a1 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -1,10 +1,12 @@ from unittest import mock from xsdata.codegen.container import ClassContainer +from xsdata.codegen.container import Steps from xsdata.codegen.models import Class from xsdata.codegen.models import Status from xsdata.models.config import GeneratorConfig from xsdata.models.enums import Tag +from xsdata.utils.testing import AttrFactory from xsdata.utils.testing import ClassFactory from xsdata.utils.testing import FactoryTestCase @@ -21,7 +23,8 @@ def test_initialize(self): ClassFactory.create(qname="{xsdata}foo", tag=Tag.COMPLEX_TYPE), ClassFactory.create(qname="{xsdata}foobar", tag=Tag.COMPLEX_TYPE), ] - container = ClassContainer(config=GeneratorConfig()) + config = GeneratorConfig() + container = ClassContainer(config) container.extend(classes) expected = { @@ -33,13 +36,13 @@ def test_initialize(self): self.assertEqual(3, len(list(container))) self.assertEqual(classes, list(container)) - self.assertEqual( - ["ClassNameConflictHandler", "ClassDesignateHandler"], - [x.__class__.__name__ for x in container.collection_processors], - ) + actual = { + step: [processor.__class__.__name__ for processor in processors] + for step, processors in container.processors.items() + } - self.assertEqual( - [ + expected = { + 10: [ "AttributeGroupHandler", "ClassExtensionHandler", "ClassEnumerationHandler", @@ -48,46 +51,30 @@ def test_initialize(self): "AttributeMergeHandler", "AttributeMixedContentHandler", "AttributeDefaultValidateHandler", - "AttributeOverridesHandler", - "AttributeEffectiveChoiceHandler", ], - [x.__class__.__name__ for x in container.pre_processors], - ) - - self.assertEqual( - [ - "AttributeDefaultValueHandler", + 20: [ + "AttributeEffectiveChoiceHandler", "AttributeRestrictionsHandler", - "AttributeNameConflictHandler", - "ClassInnersHandler", - ], - [x.__class__.__name__ for x in container.post_processors], - ) - - config = GeneratorConfig() - config.output.compound_fields = True - container = ClassContainer(config=config) - - self.assertEqual( - [ - "AttributeCompoundChoiceHandler", "AttributeDefaultValueHandler", - "AttributeRestrictionsHandler", + ], + 30: [ + "AttributeOverridesHandler", "AttributeNameConflictHandler", - "ClassInnersHandler", ], - [x.__class__.__name__ for x in container.post_processors], - ) + 40: ["ClassInnersHandler", "AttributeCompoundChoiceHandler"], + } + + self.assertEqual(expected, actual) - @mock.patch.object(ClassContainer, "pre_process_class") - def test_find(self, mock_pre_process_class): - def pre_process_class(x: Class): - x.status = Status.PROCESSED + @mock.patch.object(ClassContainer, "process_class") + def test_find(self, mock_process_class): + def process_class(x: Class, step: int): + x.status = Status.FLATTENED class_a = ClassFactory.create(qname="a") - class_b = ClassFactory.create(qname="b", status=Status.PROCESSED) - class_c = ClassFactory.enumeration(2, qname="b", status=Status.PROCESSING) - mock_pre_process_class.side_effect = pre_process_class + class_b = ClassFactory.create(qname="b", status=Status.FLATTENED) + class_c = ClassFactory.enumeration(2, qname="b", status=Status.FLATTENING) + mock_process_class.side_effect = process_class self.container.extend([class_a, class_b, class_c]) self.assertIsNone(self.container.find("nope")) @@ -96,34 +83,47 @@ def pre_process_class(x: Class): self.assertEqual( class_c, self.container.find(class_b.qname, lambda x: x.is_enumeration) ) - mock_pre_process_class.assert_called_once_with(class_a) + mock_process_class.assert_called_once_with(class_a, Steps.FLATTEN) - @mock.patch.object(ClassContainer, "pre_process_class") - def test_find_inner(self, mock_pre_process_class): + @mock.patch.object(ClassContainer, "process_class") + def test_find_inner(self, mock_process_class): obj = ClassFactory.create() first = ClassFactory.create(qname="{a}a") - second = ClassFactory.create(qname="{a}b", status=Status.PROCESSED) + second = ClassFactory.create(qname="{a}b", status=Status.FLATTENED) obj.inner.extend((first, second)) - def pre_process_class(x: Class): - x.status = Status.PROCESSED + def process_class(x: Class, step: int): + x.status = Status.FLATTENED - mock_pre_process_class.side_effect = pre_process_class + mock_process_class.side_effect = process_class self.assertEqual(first, self.container.find_inner(obj, "{a}a")) self.assertEqual(second, self.container.find_inner(obj, "{a}b")) - mock_pre_process_class.assert_called_once_with(first) + mock_process_class.assert_called_once_with(first, Steps.FLATTEN) - def test_pre_process_class(self): + def test_process_class(self): target = ClassFactory.create( inner=[ClassFactory.elements(2), ClassFactory.elements(1)] ) self.container.add(target) self.container.process() - self.assertEqual(Status.PROCESSED, target.status) - self.assertEqual(Status.PROCESSED, target.inner[0].status) - self.assertEqual(Status.PROCESSED, target.inner[1].status) + self.assertEqual(Status.FINALIZED, target.status) + self.assertEqual(Status.FINALIZED, target.inner[0].status) + self.assertEqual(Status.FINALIZED, target.inner[1].status) + + def test_process_classes(self): + target = ClassFactory.create( + attrs=[AttrFactory.reference("enumeration")], + inner=[ClassFactory.enumeration(2, qname="enumeration")], + ) + + self.container.add(target) + self.container.process_classes(Steps.FLATTEN) + self.assertEqual(2, len(list(self.container))) + + for obj in self.container: + self.assertEqual(Status.FLATTENED, obj.status) @mock.patch.object(Class, "should_generate", new_callable=mock.PropertyMock) def test_filter_classes(self, mock_class_should_generate): diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index 854fea4e8..d869d62ec 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -22,9 +22,7 @@ from xsdata.codegen.handlers import ClassExtensionHandler from xsdata.codegen.handlers import ClassInnersHandler from xsdata.codegen.handlers import ClassNameConflictHandler -from xsdata.codegen.mixins import ContainerHandlerInterface from xsdata.codegen.mixins import ContainerInterface -from xsdata.codegen.mixins import HandlerInterface from xsdata.codegen.models import Class from xsdata.codegen.models import Status from xsdata.codegen.utils import ClassUtils @@ -34,9 +32,16 @@ from xsdata.utils.constants import return_true +class Steps: + FLATTEN = 10 + SANITIZE = 20 + RESOLVE = 30 + FINALIZE = 40 + + class ClassContainer(ContainerInterface): - __slots__ = ("data", "pre_processors", "post_processors", "collection_processors") + __slots__ = ("data", "processors") def __init__(self, config: GeneratorConfig): """Initialize a class container instance with its processors based on @@ -44,33 +49,32 @@ def __init__(self, config: GeneratorConfig): super().__init__(config) self.data: Dict = {} - self.pre_processors: List[HandlerInterface] = [ - AttributeGroupHandler(self), - ClassExtensionHandler(self), - ClassEnumerationHandler(self), - AttributeSubstitutionHandler(self), - AttributeTypeHandler(self), - AttributeMergeHandler(), - AttributeMixedContentHandler(), - AttributeDefaultValidateHandler(), - AttributeOverridesHandler(self), - AttributeEffectiveChoiceHandler(), - ] - - self.post_processors: List[HandlerInterface] = [ - AttributeDefaultValueHandler(self), - AttributeRestrictionsHandler(), - AttributeNameConflictHandler(), - ClassInnersHandler(), - ] - - self.collection_processors: List[ContainerHandlerInterface] = [ - ClassNameConflictHandler(self), - ClassDesignateHandler(self), - ] - if self.config.output.compound_fields: - self.post_processors.insert(0, AttributeCompoundChoiceHandler(self)) + self.processors = { + Steps.FLATTEN: [ + AttributeGroupHandler(self), + ClassExtensionHandler(self), + ClassEnumerationHandler(self), + AttributeSubstitutionHandler(self), + AttributeTypeHandler(self), + AttributeMergeHandler(), + AttributeMixedContentHandler(), + AttributeDefaultValidateHandler(), + ], + Steps.SANITIZE: [ + AttributeEffectiveChoiceHandler(), + AttributeRestrictionsHandler(), + AttributeDefaultValueHandler(self), + ], + Steps.RESOLVE: [ + AttributeOverridesHandler(self), + AttributeNameConflictHandler(), + ], + Steps.FINALIZE: [ + ClassInnersHandler(), + AttributeCompoundChoiceHandler(self), + ], + } def __iter__(self) -> Iterator[Class]: """Create an iterator for the class map values.""" @@ -83,7 +87,7 @@ def find(self, qname: str, condition: Callable = return_true) -> Optional[Class] for row in self.data.get(qname, []): if condition(row): if row.status == Status.RAW: - self.pre_process_class(row) + self.process_class(row, Steps.FLATTEN) return self.find(qname, condition) return row @@ -92,7 +96,7 @@ def find(self, qname: str, condition: Callable = return_true) -> Optional[Class] def find_inner(self, source: Class, qname: str) -> Class: inner = ClassUtils.find_inner(source, qname) if inner.status == Status.RAW: - self.pre_process_class(inner) + self.process_class(inner, Steps.FLATTEN) return inner @@ -101,46 +105,48 @@ def process(self): Run all the process handlers. Steps - 1. Run all pre-selection handlers + 1. Flatten extensions, attribute types 2. Filter classes to be actually generated - 3. Run all post-selection handlers - 4. Resolve any naming conflicts - 5. Assign packages and modules + 3. Sanitize attributes and extensions + 4. Resolve attributes conflicts + 5. Replace repeatable elements with compound fields + 6. Designate packages and modules """ - for obj in self: - if obj.status == Status.RAW: - self.pre_process_class(obj) + self.process_classes(Steps.FLATTEN) self.filter_classes() + self.process_classes(Steps.SANITIZE) + self.process_classes(Steps.RESOLVE) + self.process_classes(Steps.FINALIZE) + self.designate_classes() + def process_classes(self, step: int) -> None: for obj in self: - self.post_process_class(obj) + if obj.status < step: + self.process_class(obj, step) - for handler in self.collection_processors: - handler.run() + if any(obj.status < step for obj in self): + return self.process_classes(step) - def pre_process_class(self, target: Class): - """Run the pre process handlers for the target class.""" - target.status = Status.PROCESSING - - for processor in self.pre_processors: + def process_class(self, target: Class, step: int): + target.status = Status(step) + for processor in self.processors.get(step, []): processor.process(target) - # We go top to bottom because it's easier to handle circular - # references. for inner in target.inner: - if inner.status == Status.RAW: - self.pre_process_class(inner) + if inner.status < step: + self.process_class(inner, step) - target.status = Status.PROCESSED + target.status = Status(step + 1) - def post_process_class(self, target: Class): - """Run the post process handlers for the target class.""" - for inner in target.inner: - self.post_process_class(inner) + def designate_classes(self): + designators = [ + ClassNameConflictHandler(self), + ClassDesignateHandler(self), + ] - for processor in self.post_processors: - processor.process(target) + for designator in designators: + designator.run() def filter_classes(self): """If there is any class derived from complexType or element then diff --git a/xsdata/codegen/handlers/attribute_compound_choice.py b/xsdata/codegen/handlers/attribute_compound_choice.py index 1a99bde64..b32bacec4 100644 --- a/xsdata/codegen/handlers/attribute_compound_choice.py +++ b/xsdata/codegen/handlers/attribute_compound_choice.py @@ -1,6 +1,7 @@ from operator import attrgetter from typing import List +from xsdata.codegen.mixins import ContainerInterface from xsdata.codegen.mixins import RelativeHandlerInterface from xsdata.codegen.models import Attr from xsdata.codegen.models import AttrType @@ -16,13 +17,22 @@ class AttributeCompoundChoiceHandler(RelativeHandlerInterface): """Group attributes that belong in the same choice and replace them by compound fields.""" - __slots__ = () + __slots__ = "compound_fields" + + def __init__(self, container: ContainerInterface): + super().__init__(container) + + self.compound_fields = container.config.output.compound_fields def process(self, target: Class): - groups = group_by(target.attrs, attrgetter("restrictions.choice")) - for choice, attrs in groups.items(): - if choice and len(attrs) > 1 and any(attr.is_list for attr in attrs): - self.group_fields(target, attrs) + if self.compound_fields: + groups = group_by(target.attrs, attrgetter("restrictions.choice")) + for choice, attrs in groups.items(): + if choice and len(attrs) > 1 and any(attr.is_list for attr in attrs): + self.group_fields(target, attrs) + + for index in range(len(target.attrs)): + self.reset_sequential(target, index) def group_fields(self, target: Class, attrs: List[Attr]): """Group attributes into a new compound field.""" @@ -93,3 +103,24 @@ def build_attr_choice(cls, attr: Attr) -> Attr: help=attr.help, restrictions=restrictions, ) + + @classmethod + def reset_sequential(cls, target: Class, index: int): + """Reset the attribute at the given index if it has no siblings with + the sequential restriction.""" + + attr = target.attrs[index] + before = target.attrs[index - 1] if index - 1 >= 0 else None + after = target.attrs[index + 1] if index + 1 < len(target.attrs) else None + + if not attr.is_list: + attr.restrictions.sequential = False + + if ( + not attr.restrictions.sequential + or (before and before.restrictions.sequential) + or (after and after.restrictions.sequential and after.is_list) + ): + return + + attr.restrictions.sequential = False diff --git a/xsdata/codegen/handlers/attribute_restrictions.py b/xsdata/codegen/handlers/attribute_restrictions.py index cf27e1d00..f4dfe125a 100644 --- a/xsdata/codegen/handlers/attribute_restrictions.py +++ b/xsdata/codegen/handlers/attribute_restrictions.py @@ -9,10 +9,8 @@ class AttributeRestrictionsHandler(HandlerInterface): __slots__ = () def process(self, target: Class): - - for index, attr in enumerate(target.attrs): + for attr in target.attrs: self.reset_occurrences(attr) - self.reset_sequential(target, index) @classmethod def reset_occurrences(cls, attr: Attr): @@ -47,24 +45,3 @@ def reset_occurrences(cls, attr: Attr): if attr.default or attr.fixed or attr.restrictions.nillable: restrictions.required = None - - @classmethod - def reset_sequential(cls, target: Class, index: int): - """Reset the attribute at the given index if it has no siblings with - the sequential restriction.""" - - attr = target.attrs[index] - before = target.attrs[index - 1] if index - 1 >= 0 else None - after = target.attrs[index + 1] if index + 1 < len(target.attrs) else None - - if not attr.is_list: - attr.restrictions.sequential = False - - if ( - not attr.restrictions.sequential - or (before and before.restrictions.sequential) - or (after and after.restrictions.sequential and after.is_list) - ): - return - - attr.restrictions.sequential = False diff --git a/xsdata/codegen/handlers/attribute_type.py b/xsdata/codegen/handlers/attribute_type.py index 37fb39273..5774ac18d 100644 --- a/xsdata/codegen/handlers/attribute_type.py +++ b/xsdata/codegen/handlers/attribute_type.py @@ -171,7 +171,7 @@ def is_circular_dependency(self, source: Class, target: Class, seen: Set) -> boo """Check if any source dependencies recursively match the target class.""" - if source is target or source.status == Status.PROCESSING: + if source is target or source.status == Status.FLATTENING: return True for qname in self.cached_dependencies(source): diff --git a/xsdata/codegen/mappers/definitions.py b/xsdata/codegen/mappers/definitions.py index 2ce2f949e..1ff60212d 100644 --- a/xsdata/codegen/mappers/definitions.py +++ b/xsdata/codegen/mappers/definitions.py @@ -116,7 +116,7 @@ def map_binding_operation( yield Class( qname=build_qname(definitions.target_namespace, name), - status=Status.PROCESSED, + status=Status.FLATTENED, tag=type(binding_operation).__name__, location=definitions.location, ns_map=binding_operation.ns_map, @@ -250,7 +250,7 @@ def build_message_class( return Class( qname=build_qname(definitions.target_namespace, message_name), - status=Status.PROCESSED, + status=Status.FLATTENED, tag=Tag.ELEMENT, location=definitions.location, ns_map=ns_map, diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 162f8e671..2bf71201f 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -385,9 +385,14 @@ def clone(self) -> "Extension": class Status(IntEnum): RAW = 0 - PROCESSING = 1 - PROCESSED = 2 - SANITIZED = 3 + FLATTENING = 10 + FLATTENED = 11 + SANITIZING = 20 + SANITIZED = 21 + RESOLVING = 30 + RESOLVED = 31 + FINALIZING = 40 + FINALIZED = 41 @dataclass From 2813b67df59145fc6f63cfd5d3fed7edafc29df7 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 22 Jun 2021 00:28:52 +0300 Subject: [PATCH 6/7] Unify required/prohibited restrictions with min/max occurs --- .../test_attribute_compound_choice.py | 2 - .../handlers/test_attribute_default_value.py | 8 ++ .../handlers/test_attribute_overrides.py | 8 +- .../handlers/test_attribute_restrictions.py | 75 ------------------- tests/codegen/handlers/test_attribute_type.py | 14 ++++ .../codegen/handlers/test_class_extension.py | 12 +-- tests/codegen/mappers/test_schema.py | 6 +- tests/codegen/models/test_restrictions.py | 36 ++------- tests/codegen/test_container.py | 1 - tests/fixtures/artists/metadata.py | 6 -- tests/formats/dataclass/test_filters.py | 20 ++++- tests/models/xsd/test_attribute.py | 13 +++- xsdata/codegen/container.py | 2 - xsdata/codegen/handlers/__init__.py | 2 - .../handlers/attribute_default_value.py | 2 +- .../codegen/handlers/attribute_overrides.py | 12 ++- .../handlers/attribute_restrictions.py | 47 ------------ xsdata/codegen/handlers/attribute_type.py | 10 ++- xsdata/codegen/models.py | 35 +++------ xsdata/formats/dataclass/filters.py | 4 + xsdata/models/xsd.py | 10 ++- 21 files changed, 113 insertions(+), 212 deletions(-) delete mode 100644 tests/codegen/handlers/test_attribute_restrictions.py delete mode 100644 xsdata/codegen/handlers/attribute_restrictions.py diff --git a/tests/codegen/handlers/test_attribute_compound_choice.py b/tests/codegen/handlers/test_attribute_compound_choice.py index 48fb1f0c7..87cc40106 100644 --- a/tests/codegen/handlers/test_attribute_compound_choice.py +++ b/tests/codegen/handlers/test_attribute_compound_choice.py @@ -122,8 +122,6 @@ def test_build_attr_choice(self): ) attr.local_name = "aaa" attr.restrictions = Restrictions( - required=True, - prohibited=None, min_occurs=1, max_occurs=1, min_exclusive="1.1", diff --git a/tests/codegen/handlers/test_attribute_default_value.py b/tests/codegen/handlers/test_attribute_default_value.py index 1fcff4afb..3cf145cb1 100644 --- a/tests/codegen/handlers/test_attribute_default_value.py +++ b/tests/codegen/handlers/test_attribute_default_value.py @@ -63,6 +63,14 @@ def test_process_attribute_with_optional_field(self): self.assertFalse(attr.fixed) self.assertIsNone(attr.default) + def test_process_attribute_with_list_field(self): + target = ClassFactory.create() + attr = AttrFactory.create(fixed=True, default=2) + attr.restrictions.max_occurs = 5 + self.processor.process_attribute(target, attr) + self.assertFalse(attr.fixed) + self.assertIsNone(attr.default) + def test_process_attribute_with_xsi_type(self): target = ClassFactory.create() attr = AttrFactory.create( diff --git a/tests/codegen/handlers/test_attribute_overrides.py b/tests/codegen/handlers/test_attribute_overrides.py index 2aa440075..4dce90318 100644 --- a/tests/codegen/handlers/test_attribute_overrides.py +++ b/tests/codegen/handlers/test_attribute_overrides.py @@ -78,14 +78,16 @@ def test_validate_override(self): self.processor.validate_override(target, attr_a, attr_b) self.assertEqual(1, len(target.attrs)) - # restrictions except choice, min/max occurs don't match + # Restrictions don't match attr_b.fixed = attr_a.fixed - attr_a.restrictions.required = not attr_b.restrictions.required + attr_a.restrictions.tokens = not attr_b.restrictions.tokens + attr_a.restrictions.nillable = not attr_b.restrictions.nillable self.processor.validate_override(target, attr_a, attr_b) self.assertEqual(1, len(target.attrs)) # Restrictions are compatible again - attr_a.restrictions.required = attr_b.restrictions.required + attr_a.restrictions.tokens = attr_b.restrictions.tokens + attr_a.restrictions.nillable = attr_b.restrictions.nillable self.processor.validate_override(target, attr_a, attr_b) self.assertEqual(0, len(target.attrs)) diff --git a/tests/codegen/handlers/test_attribute_restrictions.py b/tests/codegen/handlers/test_attribute_restrictions.py deleted file mode 100644 index b38fc00ea..000000000 --- a/tests/codegen/handlers/test_attribute_restrictions.py +++ /dev/null @@ -1,75 +0,0 @@ -from xsdata.codegen.handlers import AttributeRestrictionsHandler -from xsdata.codegen.models import Restrictions -from xsdata.models.enums import Tag -from xsdata.utils.testing import AttrFactory -from xsdata.utils.testing import FactoryTestCase - - -class AttributeRestrictionsHandlerTests(FactoryTestCase): - def setUp(self): - super().setUp() - - self.processor = AttributeRestrictionsHandler() - - def test_reset_occurrences(self): - required = Restrictions(min_occurs=1, max_occurs=1) - attr = AttrFactory.attribute(restrictions=required.clone()) - self.processor.reset_occurrences(attr) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - tokens = Restrictions(required=True, tokens=True, min_occurs=1, max_occurs=1) - attr = AttrFactory.element(restrictions=tokens.clone()) - self.processor.reset_occurrences(attr) - self.assertFalse(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - attr = AttrFactory.element(restrictions=tokens.clone()) - attr.restrictions.max_occurs = 2 - self.processor.reset_occurrences(attr) - self.assertFalse(attr.restrictions.required) - self.assertIsNotNone(attr.restrictions.min_occurs) - self.assertIsNotNone(attr.restrictions.max_occurs) - - multiple = Restrictions(min_occurs=0, max_occurs=2) - attr = AttrFactory.create(tag=Tag.EXTENSION, restrictions=multiple) - self.processor.reset_occurrences(attr) - self.assertTrue(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - multiple = Restrictions(max_occurs=2, required=True) - attr = AttrFactory.element(restrictions=multiple, fixed=True) - self.processor.reset_occurrences(attr) - self.assertIsNone(attr.restrictions.required) - self.assertEqual(0, attr.restrictions.min_occurs) - self.assertFalse(attr.fixed) - - attr = AttrFactory.element(restrictions=required.clone()) - self.processor.reset_occurrences(attr) - self.assertTrue(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - restrictions = Restrictions(required=True, min_occurs=0, max_occurs=1) - attr = AttrFactory.element(restrictions=restrictions, default="A", fixed=True) - self.processor.reset_occurrences(attr) - self.assertIsNone(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - self.assertIsNone(attr.default) - self.assertFalse(attr.fixed) - - attr = AttrFactory.element(restrictions=required.clone(), default="A") - self.processor.reset_occurrences(attr) - self.assertIsNone(attr.restrictions.required) - - attr = AttrFactory.element(restrictions=required.clone(), fixed=True) - self.processor.reset_occurrences(attr) - self.assertIsNone(attr.restrictions.required) - - attr = AttrFactory.element(restrictions=required.clone()) - attr.restrictions.nillable = True - self.processor.reset_occurrences(attr) - self.assertIsNone(attr.restrictions.required) diff --git a/tests/codegen/handlers/test_attribute_type.py b/tests/codegen/handlers/test_attribute_type.py index e0bf6fef7..1081e91a8 100644 --- a/tests/codegen/handlers/test_attribute_type.py +++ b/tests/codegen/handlers/test_attribute_type.py @@ -260,6 +260,20 @@ def test_copy_attribute_properties_from_nillable_source(self): self.processor.copy_attribute_properties(source, target, attr, attr.types[0]) self.assertTrue(attr.restrictions.nillable) + def test_copy_attribute_properties_to_attribute_target(self): + source = ClassFactory.elements(1, nillable=True) + target = ClassFactory.create(attrs=AttrFactory.list(1, tag=Tag.ATTRIBUTE)) + attr = target.attrs[0] + attr.restrictions.min_occurs = 1 + attr.restrictions.max_occurs = 1 + + source.attrs[0].restrictions.min_occurs = 0 + source.attrs[0].restrictions.max_occurs = 1 + + self.assertFalse(attr.is_optional) + self.processor.copy_attribute_properties(source, target, attr, attr.types[0]) + self.assertFalse(attr.is_optional) + @mock.patch.object(AttributeTypeHandler, "is_circular_dependency") def test_set_circular_flag(self, mock_is_circular_dependency): source = ClassFactory.create() diff --git a/tests/codegen/handlers/test_class_extension.py b/tests/codegen/handlers/test_class_extension.py index be49ed2d6..2f58a2a52 100644 --- a/tests/codegen/handlers/test_class_extension.py +++ b/tests/codegen/handlers/test_class_extension.py @@ -431,7 +431,9 @@ def test_replace_attributes_type(self): def test_add_default_attribute(self): xs_string = AttrTypeFactory.native(DataType.STRING) - extension = ExtensionFactory.create(xs_string, Restrictions(required=True)) + extension = ExtensionFactory.create( + xs_string, Restrictions(min_occurs=1, max_occurs=1) + ) item = ClassFactory.elements(1, extensions=[extension]) ClassExtensionHandler.add_default_attribute(item, extension) @@ -449,9 +451,7 @@ def test_add_default_attribute(self): ClassExtensionHandler.add_default_attribute(item, extension) expected.types.append(xs_int) - expected_restrictions = Restrictions( - tokens=True, required=True, min_occurs=1, max_occurs=1 - ) + expected_restrictions = Restrictions(tokens=True, min_occurs=1, max_occurs=1) self.assertEqual(2, len(item.attrs)) self.assertEqual(0, len(item.extensions)) @@ -461,7 +461,7 @@ def test_add_default_attribute(self): def test_add_default_attribute_with_any_type(self): extension = ExtensionFactory.create( AttrTypeFactory.native(DataType.ANY_TYPE), - Restrictions(min_occurs=1, max_occurs=1, required=True), + Restrictions(min_occurs=1, max_occurs=1), ) item = ClassFactory.create(extensions=[extension]) @@ -472,7 +472,7 @@ def test_add_default_attribute_with_any_type(self): types=[extension.type.clone()], tag=Tag.ANY, namespace="##any", - restrictions=Restrictions(min_occurs=1, max_occurs=1, required=True), + restrictions=Restrictions(min_occurs=1, max_occurs=1), ) self.assertEqual(1, len(item.attrs)) diff --git a/tests/codegen/mappers/test_schema.py b/tests/codegen/mappers/test_schema.py index cd045ddcf..2e8514248 100644 --- a/tests/codegen/mappers/test_schema.py +++ b/tests/codegen/mappers/test_schema.py @@ -284,7 +284,7 @@ def test_build_class_attribute( mock_default_value.return_value = "default" mock_is_fixed.return_value = True mock_element_namespace.return_value = "http://something/common" - mock_get_restrictions.return_value = {"required": True} + mock_get_restrictions.return_value = {"min_occurs": 1, "max_occurs": 1} attribute = Attribute(default="false") attribute.index = 66 @@ -300,7 +300,7 @@ def test_build_class_attribute( default=mock_default_value.return_value, fixed=mock_is_fixed.return_value, index=66, - restrictions=Restrictions(required=True), + restrictions=Restrictions(min_occurs=1, max_occurs=1), ) self.assertEqual(expected, item.attrs[0]) self.assertEqual({"bar": "foo", "foo": "bar"}, item.ns_map) @@ -309,7 +309,7 @@ def test_build_class_attribute( def test_build_class_attribute_skip_prohibited(self): item = ClassFactory.create(ns_map={"bar": "foo"}) - attribute = Attribute(default="false", use=UseType.PROHIBITED) + attribute = Attribute(use=UseType.PROHIBITED) SchemaMapper.build_class_attribute(item, attribute, Restrictions()) self.assertEqual(0, len(item.attrs)) diff --git a/tests/codegen/models/test_restrictions.py b/tests/codegen/models/test_restrictions.py index b67613a2f..b820ac03e 100644 --- a/tests/codegen/models/test_restrictions.py +++ b/tests/codegen/models/test_restrictions.py @@ -7,8 +7,6 @@ class RestrictionsTests(TestCase): def setUp(self) -> None: self.restrictions = Restrictions( - required=True, - prohibited=None, min_occurs=1, max_occurs=1, min_exclusive="1.1", @@ -38,7 +36,6 @@ def test_property_is_list(self): def test_property_is_prohibited(self): self.assertFalse(Restrictions().is_prohibited) - self.assertTrue(Restrictions(prohibited=True).is_prohibited) self.assertTrue(Restrictions(max_occurs=0).is_prohibited) def test_merge(self): @@ -71,19 +68,22 @@ def test_asdict(self): "max_exclusive": "1", "max_inclusive": "1.1", "max_length": 10, - "max_occurs": 1, "min_exclusive": "1.1", "min_inclusive": "1", "min_length": 1, - "min_occurs": 1, "nillable": True, "pattern": "[A-Z]", - "required": True, "total_digits": 333, "white_space": "collapse", } self.assertEqual(expected, self.restrictions.asdict()) + self.restrictions.nillable = None + + del expected["nillable"] + expected["required"] = True + self.assertEqual(expected, self.restrictions.asdict()) + def test_asdict_with_types(self): expected = { "explicit_timezone": "+1", @@ -92,14 +92,11 @@ def test_asdict_with_types(self): "max_exclusive": 1.0, # str -> float "max_inclusive": 1.1, # str -> float "max_length": 10, - "max_occurs": 1, "min_exclusive": 1.1, # str -> float "min_inclusive": 1.0, # str -> float "min_length": 1, - "min_occurs": 1, "nillable": True, "pattern": "[A-Z]", - "required": True, "total_digits": 333, "white_space": "collapse", } @@ -121,24 +118,3 @@ def test_clone(self): self.assertEqual(clone, restrictions) self.assertIsNot(clone, restrictions) - - def test_is_compatible(self): - clone = self.restrictions.clone() - self.assertTrue(self.restrictions.is_compatible(clone)) - - clone.max_length = 10 - self.assertTrue(self.restrictions.is_compatible(clone)) - - clone.required = not self.restrictions.required - clone.nillable = not self.restrictions.nillable - clone.tokens = not self.restrictions.tokens - self.assertFalse(self.restrictions.is_compatible(clone)) - - clone.required = not clone.required - self.assertFalse(self.restrictions.is_compatible(clone)) - - clone.nillable = not clone.nillable - self.assertFalse(self.restrictions.is_compatible(clone)) - - clone.tokens = not clone.tokens - self.assertTrue(self.restrictions.is_compatible(clone)) diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index 4c4a248a1..b364b26ad 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -54,7 +54,6 @@ def test_initialize(self): ], 20: [ "AttributeEffectiveChoiceHandler", - "AttributeRestrictionsHandler", "AttributeDefaultValueHandler", ], 30: [ diff --git a/tests/fixtures/artists/metadata.py b/tests/fixtures/artists/metadata.py index 26e249d25..8f8035ed2 100644 --- a/tests/fixtures/artists/metadata.py +++ b/tests/fixtures/artists/metadata.py @@ -45,9 +45,6 @@ class Meta: ) value: Optional[str] = field( default=None, - metadata={ - "required": True, - } ) @@ -92,9 +89,6 @@ class Meta: ) value: Optional[str] = field( default=None, - metadata={ - "required": True, - } ) diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index 227a13abe..7e69492d5 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -267,11 +267,29 @@ def test_field_metadata_restrictions(self): attr.restrictions.min_occurs = 1 attr.restrictions.max_occurs = 2 attr.restrictions.max_inclusive = "2" - attr.restrictions.required = False expected = {"min_occurs": 1, "max_occurs": 2, "max_inclusive": 2} self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + attr.restrictions.min_occurs = 1 + attr.restrictions.max_occurs = 1 + expected = {"required": True, "max_inclusive": 2} + self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + + attr.restrictions.nillable = True + expected = {"nillable": True, "max_inclusive": 2} + self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + + attr.default = "foo" + attr.restrictions.nillable = False + expected = {"max_inclusive": 2} + self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + + attr.default = None + attr.restrictions.tokens = True + expected = {"max_inclusive": 2, "tokens": True} + self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + def test_field_metadata_mixed(self): attr = AttrFactory.element(mixed=True) expected = {"mixed": True, "name": "attr_B", "type": "Element"} diff --git a/tests/models/xsd/test_attribute.py b/tests/models/xsd/test_attribute.py index e39534091..44260a49e 100644 --- a/tests/models/xsd/test_attribute.py +++ b/tests/models/xsd/test_attribute.py @@ -41,14 +41,21 @@ def test_property_real_name(self): def test_get_restrictions(self): obj = Attribute() - self.assertEqual({}, obj.get_restrictions()) + self.assertEqual({"max_occurs": 1, "min_occurs": 0}, obj.get_restrictions()) + obj.default = "foo" + self.assertEqual({"max_occurs": 1, "min_occurs": 1}, obj.get_restrictions()) + obj.default = None + obj.fixed = "foo" + self.assertEqual({"max_occurs": 1, "min_occurs": 1}, obj.get_restrictions()) + + obj.fixed = None obj.use = UseType.REQUIRED - expected = {"required": True} + expected = {"max_occurs": 1, "min_occurs": 1} self.assertEqual(expected, obj.get_restrictions()) obj.use = UseType.PROHIBITED - expected = {"prohibited": True} + expected = {"max_occurs": 0, "min_occurs": 0} self.assertEqual(expected, obj.get_restrictions()) obj.simple_type = SimpleType(restriction=Restriction(length=Length(value=1))) diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index d869d62ec..6e0aabf24 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -14,7 +14,6 @@ from xsdata.codegen.handlers import AttributeMixedContentHandler from xsdata.codegen.handlers import AttributeNameConflictHandler from xsdata.codegen.handlers import AttributeOverridesHandler -from xsdata.codegen.handlers import AttributeRestrictionsHandler from xsdata.codegen.handlers import AttributeSubstitutionHandler from xsdata.codegen.handlers import AttributeTypeHandler from xsdata.codegen.handlers import ClassDesignateHandler @@ -63,7 +62,6 @@ def __init__(self, config: GeneratorConfig): ], Steps.SANITIZE: [ AttributeEffectiveChoiceHandler(), - AttributeRestrictionsHandler(), AttributeDefaultValueHandler(self), ], Steps.RESOLVE: [ diff --git a/xsdata/codegen/handlers/__init__.py b/xsdata/codegen/handlers/__init__.py index 97670584a..9a63de8a0 100644 --- a/xsdata/codegen/handlers/__init__.py +++ b/xsdata/codegen/handlers/__init__.py @@ -7,7 +7,6 @@ from .attribute_mixed_content import AttributeMixedContentHandler from .attribute_name_conflict import AttributeNameConflictHandler from .attribute_overrides import AttributeOverridesHandler -from .attribute_restrictions import AttributeRestrictionsHandler from .attribute_substitution import AttributeSubstitutionHandler from .attribute_type import AttributeTypeHandler from .class_designate import ClassDesignateHandler @@ -26,7 +25,6 @@ "AttributeMixedContentHandler", "AttributeNameConflictHandler", "AttributeOverridesHandler", - "AttributeRestrictionsHandler", "AttributeSubstitutionHandler", "AttributeTypeHandler", "ClassDesignateHandler", diff --git a/xsdata/codegen/handlers/attribute_default_value.py b/xsdata/codegen/handlers/attribute_default_value.py index 19053297f..c42ff57b4 100644 --- a/xsdata/codegen/handlers/attribute_default_value.py +++ b/xsdata/codegen/handlers/attribute_default_value.py @@ -33,7 +33,7 @@ def process_attribute(self, target: Class, attr: Attr): if attr.is_enumeration: return - if attr.is_optional or attr.is_xsi_type: + if attr.is_optional or attr.is_xsi_type or attr.is_list: attr.fixed = False attr.default = None diff --git a/xsdata/codegen/handlers/attribute_overrides.py b/xsdata/codegen/handlers/attribute_overrides.py index 521129a29..83293b715 100644 --- a/xsdata/codegen/handlers/attribute_overrides.py +++ b/xsdata/codegen/handlers/attribute_overrides.py @@ -2,6 +2,7 @@ from operator import attrgetter from typing import Dict from typing import List +from typing import Optional from xsdata.codegen.mixins import RelativeHandlerInterface from xsdata.codegen.models import Attr @@ -47,12 +48,17 @@ def validate_override(cls, target: Class, attr: Attr, source_attr: Attr): if ( attr.default == source_attr.default - and attr.fixed == source_attr.fixed - and attr.mixed == source_attr.mixed - and attr.restrictions.is_compatible(source_attr.restrictions) + and bool_eq(attr.fixed, source_attr.fixed) + and bool_eq(attr.mixed, source_attr.mixed) + and bool_eq(attr.restrictions.tokens, source_attr.restrictions.tokens) + and bool_eq(attr.restrictions.nillable, source_attr.restrictions.nillable) ): ClassUtils.remove_attribute(target, attr) @classmethod def resolve_conflict(cls, attr: Attr, source_attr: Attr): ClassUtils.rename_attribute_by_preference(attr, source_attr) + + +def bool_eq(a: Optional[bool], b: Optional[bool]) -> bool: + return bool(a) is bool(b) diff --git a/xsdata/codegen/handlers/attribute_restrictions.py b/xsdata/codegen/handlers/attribute_restrictions.py deleted file mode 100644 index f4dfe125a..000000000 --- a/xsdata/codegen/handlers/attribute_restrictions.py +++ /dev/null @@ -1,47 +0,0 @@ -from xsdata.codegen.mixins import HandlerInterface -from xsdata.codegen.models import Attr -from xsdata.codegen.models import Class - - -class AttributeRestrictionsHandler(HandlerInterface): - """Sanitize attributes restrictions.""" - - __slots__ = () - - def process(self, target: Class): - for attr in target.attrs: - self.reset_occurrences(attr) - - @classmethod - def reset_occurrences(cls, attr: Attr): - """Sanitize attribute required flag by comparing the min/max - occurrences restrictions.""" - restrictions = attr.restrictions - min_occurs = restrictions.min_occurs or 0 - max_occurs = restrictions.max_occurs or 0 - - if attr.is_attribute: - restrictions.min_occurs = None - restrictions.max_occurs = None - elif attr.is_tokens: - restrictions.required = None - if max_occurs <= 1: - restrictions.min_occurs = None - restrictions.max_occurs = None - elif attr.xml_type is None or min_occurs == max_occurs == 1: - restrictions.required = True - restrictions.min_occurs = None - restrictions.max_occurs = None - elif min_occurs == 0 and max_occurs < 2: - restrictions.required = None - restrictions.min_occurs = None - restrictions.max_occurs = None - attr.default = None - attr.fixed = False - else: # max_occurs > 1 - restrictions.min_occurs = min_occurs - restrictions.required = None - attr.fixed = False - - if attr.default or attr.fixed or attr.restrictions.nillable: - restrictions.required = None diff --git a/xsdata/codegen/handlers/attribute_type.py b/xsdata/codegen/handlers/attribute_type.py index 5774ac18d..cf1047c5c 100644 --- a/xsdata/codegen/handlers/attribute_type.py +++ b/xsdata/codegen/handlers/attribute_type.py @@ -156,12 +156,18 @@ def copy_attribute_properties( restrictions = source_attr.restrictions.clone() restrictions.merge(attr.restrictions) - attr.restrictions = restrictions - attr.help = attr.help or source_attr.help + + if attr.is_attribute: + # Attributes maintain their occurrences no matter what! + restrictions.min_occurs = attr.restrictions.min_occurs + restrictions.max_occurs = attr.restrictions.max_occurs if source.nillable: restrictions.nillable = True + attr.restrictions = restrictions + attr.help = attr.help or source_attr.help + def set_circular_flag(self, source: Class, target: Class, attr_type: AttrType): """Update circular reference flag.""" attr_type.reference = id(source) diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 2bf71201f..c3adefab0 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -37,8 +37,6 @@ class Restrictions: """ Model representation of a dataclass field validation and type metadata. - :param required: - :param prohibited: :param min_occurs: :param max_occurs: :param min_exclusive: @@ -60,8 +58,6 @@ class Restrictions: :param choice: """ - required: Optional[bool] = field(default=None) - prohibited: Optional[bool] = field(default=None) min_occurs: Optional[int] = field(default=None) max_occurs: Optional[int] = field(default=None) min_exclusive: Optional[str] = field(default=None) @@ -94,17 +90,7 @@ def is_optional(self) -> bool: @property def is_prohibited(self) -> bool: - return self.prohibited or self.max_occurs == 0 - - def is_compatible(self, other: "Restrictions") -> bool: - def bool_eq(a: Optional[bool], b: Optional[bool]) -> bool: - return bool(a) is bool(b) - - return ( - bool_eq(self.required, other.required) - and bool_eq(self.nillable, other.nillable) - and bool_eq(self.tokens, other.tokens) - ) + return self.max_occurs == 0 def merge(self, source: "Restrictions"): """Update properties from another instance.""" @@ -134,8 +120,6 @@ def merge(self, source: "Restrictions"): def update(self, source: "Restrictions"): keys = ( - "required", - "prohibited", "min_exclusive", "min_inclusive", "min_length", @@ -166,14 +150,19 @@ def asdict(self, types: Optional[List[Type]] = None) -> Dict: result = {} sorted_types = converter.sort_types(types) if types else [] + if self.is_list: + if self.min_occurs is not None and self.min_occurs > 0: + result["min_occurs"] = self.min_occurs + if self.max_occurs is not None and self.max_occurs < sys.maxsize: + result["max_occurs"] = self.max_occurs + elif self.min_occurs == self.max_occurs == 1 and not self.nillable: + result["required"] = True + for key, value in asdict(self).items(): - if value is None or key == "choice": - continue - elif key == "max_occurs" and value >= sys.maxsize: + if value is None or key in ("choice", "min_occurs", "max_occurs"): continue - elif key == "min_occurs" and value == 0: - continue - elif key.endswith("clusive") and types: + + if key.endswith("clusive") and types: value = converter.deserialize(value, sorted_types) result[key] = value diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index f1df93baa..ca3952e95 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -218,6 +218,10 @@ def field_metadata( namespace = attr.namespace restrictions = attr.restrictions.asdict(attr.native_types) + + if attr.default or attr.is_factory: + restrictions.pop("required", None) + metadata = { "name": name, "type": attr.xml_type, diff --git a/xsdata/models/xsd.py b/xsdata/models/xsd.py index a535f38e9..297786604 100644 --- a/xsdata/models/xsd.py +++ b/xsdata/models/xsd.py @@ -331,10 +331,16 @@ def attr_types(self) -> Iterator[str]: def get_restrictions(self) -> Dict[str, Anything]: restrictions = {} + + if self.default or self.fixed: + self.use = UseType.REQUIRED + if self.use == UseType.REQUIRED: - restrictions.update({"required": True}) + restrictions.update({"min_occurs": 1, "max_occurs": 1}) elif self.use == UseType.PROHIBITED: - restrictions.update({"prohibited": True}) + restrictions.update({"max_occurs": 0, "min_occurs": 0}) + else: + restrictions.update({"max_occurs": 1, "min_occurs": 0}) if self.simple_type: restrictions.update(self.simple_type.get_restrictions()) From 4a5245a1fad9da75b62c358083c26ec002e50bbe Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 22 Jun 2021 22:48:39 +0300 Subject: [PATCH 7/7] Protect against useless xsi:type --- tests/formats/dataclass/test_context.py | 6 +++--- xsdata/formats/dataclass/context.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/formats/dataclass/test_context.py b/tests/formats/dataclass/test_context.py index 5b8ff04f2..106b438b7 100644 --- a/tests/formats/dataclass/test_context.py +++ b/tests/formats/dataclass/test_context.py @@ -84,12 +84,12 @@ def test_find_subclass(self): a = make_dataclass("A", fields=[]) b = make_dataclass("B", fields=[], bases=(a,)) c = make_dataclass("C", fields=[], bases=(a,)) - other = make_dataclass("Other", fields=[]) + other = make_dataclass("Other", fields=[]) # Included in the locals self.assertEqual(b, self.ctx.find_subclass(a, "B")) self.assertEqual(b, self.ctx.find_subclass(c, "B")) - self.assertEqual(a, self.ctx.find_subclass(b, "A")) - self.assertEqual(a, self.ctx.find_subclass(c, "A")) + self.assertIsNone(self.ctx.find_subclass(b, "A")) + self.assertIsNone(self.ctx.find_subclass(c, "A")) self.assertIsNone(self.ctx.find_subclass(c, "Unknown")) self.assertIsNone(self.ctx.find_subclass(c, "Other")) diff --git a/xsdata/formats/dataclass/context.py b/xsdata/formats/dataclass/context.py index d1a8e0e6b..ac6aa9fab 100644 --- a/xsdata/formats/dataclass/context.py +++ b/xsdata/formats/dataclass/context.py @@ -148,6 +148,13 @@ def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]: types: List[Type] = self.find_types(qname) for tp in types: + + # Why would an xml node with have an xsi:type that points + # to parent class is beyond me but it happens, let's protect + # against that scenario + if issubclass(clazz, tp): + continue + for tp_mro in tp.__mro__: if tp_mro is not object and tp_mro in clazz.__mro__: return tp