Skip to content

Commit

Permalink
Merge duplicate global types earlier
Browse files Browse the repository at this point in the history
Note:
This also reverses the previous behavior
tha favored the element and flattened the
complex type. The reason being complex
types can be used as base classes.
  • Loading branch information
tefra committed Feb 25, 2021
1 parent 09984ea commit 6fe4c5f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 31 deletions.
52 changes: 36 additions & 16 deletions tests/codegen/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def setUp(self):
self.container = ClassContainer()
self.validator = ClassValidator(container=self.container)

@mock.patch.object(ClassValidator, "mark_strict_types")
@mock.patch.object(ClassValidator, "merge_dummy_types")
@mock.patch.object(ClassValidator, "handle_duplicate_types")
@mock.patch.object(ClassValidator, "remove_invalid_classes")
def test_process(
self,
mock_remove_invalid_classes,
mock_handle_duplicate_types,
mock_mark_strict_types,
mock_merge_dummy_types,
):
first = ClassFactory.create()
second = first.clone()
Expand All @@ -37,7 +37,7 @@ def test_process(

mock_remove_invalid_classes.assert_called_once_with([first, second])
mock_handle_duplicate_types.assert_called_once_with([first, second])
mock_mark_strict_types.assert_called_once_with([first, second])
mock_merge_dummy_types.assert_called_once_with([first, second])

def test_remove_invalid_classes(self):
first = ClassFactory.create(
Expand Down Expand Up @@ -98,22 +98,42 @@ def test_handle_duplicate_types_with_redefined_type(
[mock.call(two, one), mock.call(three, one)]
)

def test_mark_strict_types(self):
one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT)
def test_merge_dummy_types(self):
one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT, namespace="a", help="b")
two = ClassFactory.create(qname="foo", tag=Tag.COMPLEX_TYPE)
three = ClassFactory.create(qname="foo", tag=Tag.SIMPLE_TYPE)

self.validator.mark_strict_types([one, two, three])

self.assertFalse(one.strict_type) # Is an element
self.assertTrue(two.strict_type) # Marked as abstract
self.assertFalse(three.strict_type) # Is common

four = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE)
five = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
self.validator.mark_strict_types([four, five])
self.assertFalse(four.strict_type) # No element in group
self.assertFalse(five.strict_type) # No element in group
classes = [one, two, three]
self.validator.merge_dummy_types(classes)
self.assertEqual(3, len(classes))

classes = [one, three]
self.validator.merge_dummy_types(classes)
self.assertEqual(2, len(classes))

classes = [two, three]
self.validator.merge_dummy_types(classes)
self.assertEqual(2, len(classes))

classes = [one, two, three]
one.attrs.append(AttrFactory.create)
one.extensions.append(ExtensionFactory.reference(two.qname))

self.validator.merge_dummy_types(classes)
self.assertEqual(3, len(classes))

one.attrs.clear()
one.extensions.append(ExtensionFactory.reference("foo"))
self.validator.merge_dummy_types(classes)
self.assertEqual(3, len(classes))

one.extensions.pop()
self.validator.merge_dummy_types(classes)
self.assertEqual(2, len(classes))
self.assertIn(two, classes)
self.assertIn(three, classes)
self.assertEqual(one.namespace, two.namespace)
self.assertEqual(one.help, two.help)

@mock.patch.object(ClassUtils, "copy_extensions")
@mock.patch.object(ClassUtils, "copy_attributes")
Expand Down
47 changes: 32 additions & 15 deletions xsdata/codegen/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xsdata.codegen.models import Extension
from xsdata.codegen.utils import ClassUtils
from xsdata.models.enums import Tag
from xsdata.utils import collections
from xsdata.utils.collections import group_by


Expand All @@ -23,9 +24,9 @@ def process(self):
Remove if possible classes with the same qualified name.
Steps:
1. Remove classes with missing extension type.
2. Handle duplicate types.
3. Mark strict types.
1. Remove invalid classes
2. Handle duplicate types
3. Merge dummy types
"""
for classes in self.container.data.values():

Expand All @@ -36,7 +37,7 @@ def process(self):
self.handle_duplicate_types(classes)

if len(classes) > 1:
self.mark_strict_types(classes)
self.merge_dummy_types(classes)

def remove_invalid_classes(self, classes: List[Class]):
"""Remove from the given class list any class with missing extension
Expand Down Expand Up @@ -124,14 +125,30 @@ def find_circular_group(cls, target: Class) -> Optional[Attr]:
return None

@classmethod
def mark_strict_types(cls, classes: List[Class]):
"""If there is a class derived from xs:element update all
xs:complexTypes derived classes as strict types."""

try:
element = next(obj for obj in classes if obj.is_element)
for obj in classes:
if obj is not element and obj.is_complex:
obj.strict_type = True
except StopIteration:
pass
def merge_dummy_types(cls, classes: List[Class]):
"""
Merge parent-child global types.
Conditions
1. One of them is derived from xs:element
2. One of them is derived from xs:complexType
3. The xs:element is a subclass of the xs:complexType
4. The xs:element has no attributes (This can't happen in a valid schema)
"""

el = collections.first(x for x in classes if x.tag == Tag.ELEMENT)
ct = collections.first(x for x in classes if x.tag == Tag.COMPLEX_TYPE)

if (
el is None
or ct is None
or el is ct
or el.attrs
or len(el.extensions) != 1
or el.extensions[0].type.qname != el.qname
):
return

ct.namespace = el.namespace or ct.namespace
ct.help = el.help or ct.help
classes.remove(el)

0 comments on commit 6fe4c5f

Please sign in to comment.