Skip to content

Commit

Permalink
Merge duplicate global types earlier
Browse files Browse the repository at this point in the history
Note:
This also reverses the previous behavior
tha favored the element and flattened the
complex type. The reason being complex
types can be used as base classes.
  • Loading branch information
tefra committed Feb 25, 2021
1 parent 09984ea commit 4fc1ed8
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 53 deletions.
4 changes: 0 additions & 4 deletions tests/codegen/handlers/test_class_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,6 @@ def test_should_flatten_extension(self):

self.assertFalse(self.processor.should_flatten_extension(source, target))

# Forced flattened
source.strict_type = True
self.assertTrue(self.processor.should_flatten_extension(source, target))

# Source has suffix attr and target has its own attrs
source = ClassFactory.elements(1)
source.attrs[0].index = sys.maxsize
Expand Down
3 changes: 0 additions & 3 deletions tests/codegen/models/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,3 @@ def test_property_should_generate(self):

obj = ClassFactory.create(tag=Tag.SIMPLE_TYPE)
self.assertFalse(obj.should_generate)

obj = ClassFactory.create(tag=Tag.BINDING_MESSAGE, strict_type=True)
self.assertFalse(obj.should_generate)
52 changes: 36 additions & 16 deletions tests/codegen/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def setUp(self):
self.container = ClassContainer()
self.validator = ClassValidator(container=self.container)

@mock.patch.object(ClassValidator, "mark_strict_types")
@mock.patch.object(ClassValidator, "merge_global_types")
@mock.patch.object(ClassValidator, "handle_duplicate_types")
@mock.patch.object(ClassValidator, "remove_invalid_classes")
def test_process(
self,
mock_remove_invalid_classes,
mock_handle_duplicate_types,
mock_mark_strict_types,
mock_merge_global_types,
):
first = ClassFactory.create()
second = first.clone()
Expand All @@ -37,7 +37,7 @@ def test_process(

mock_remove_invalid_classes.assert_called_once_with([first, second])
mock_handle_duplicate_types.assert_called_once_with([first, second])
mock_mark_strict_types.assert_called_once_with([first, second])
mock_merge_global_types.assert_called_once_with([first, second])

def test_remove_invalid_classes(self):
first = ClassFactory.create(
Expand Down Expand Up @@ -98,22 +98,42 @@ def test_handle_duplicate_types_with_redefined_type(
[mock.call(two, one), mock.call(three, one)]
)

def test_mark_strict_types(self):
one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT)
def test_merge_global_types(self):
one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT, namespace="a", help="b")
two = ClassFactory.create(qname="foo", tag=Tag.COMPLEX_TYPE)
three = ClassFactory.create(qname="foo", tag=Tag.SIMPLE_TYPE)

self.validator.mark_strict_types([one, two, three])

self.assertFalse(one.strict_type) # Is an element
self.assertTrue(two.strict_type) # Marked as abstract
self.assertFalse(three.strict_type) # Is common

four = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE)
five = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
self.validator.mark_strict_types([four, five])
self.assertFalse(four.strict_type) # No element in group
self.assertFalse(five.strict_type) # No element in group
classes = [one, two, three]
self.validator.merge_global_types(classes)
self.assertEqual(3, len(classes))

classes = [one, three]
self.validator.merge_global_types(classes)
self.assertEqual(2, len(classes))

classes = [two, three]
self.validator.merge_global_types(classes)
self.assertEqual(2, len(classes))

classes = [one, two, three]
one.attrs.append(AttrFactory.create)
one.extensions.append(ExtensionFactory.reference(two.qname))

self.validator.merge_global_types(classes)
self.assertEqual(3, len(classes))

one.attrs.clear()
one.extensions.append(ExtensionFactory.reference("foo"))
self.validator.merge_global_types(classes)
self.assertEqual(3, len(classes))

one.extensions.pop()
self.validator.merge_global_types(classes)
self.assertEqual(2, len(classes))
self.assertIn(two, classes)
self.assertIn(three, classes)
self.assertEqual(one.namespace, two.namespace)
self.assertEqual(one.help, two.help)

@mock.patch.object(ClassUtils, "copy_extensions")
@mock.patch.object(ClassUtils, "copy_attributes")
Expand Down
2 changes: 0 additions & 2 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def create(
abstract=False,
mixed=False,
nillable=False,
strict_type=False,
help=None,
extensions=None,
substitutions=None,
Expand All @@ -96,7 +95,6 @@ def create(
abstract=abstract,
mixed=mixed,
nillable=nillable,
strict_type=strict_type,
tag=tag or random.choice(cls.tags),
extensions=extensions or [],
substitutions=substitutions or [],
Expand Down
14 changes: 6 additions & 8 deletions xsdata/codegen/handlers/class_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,15 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool:
Return whether the extension should be flattened because of rules.
Rules:
1. Source class is marked as a strict type
2. Source class is a simple type
3. Source class has a suffix attr and target has its own attrs
4. Target class has a suffix attr
5. Target restrictions parent attrs in different sequential order
6. Target restricts parent attr with a not matching type.
1. Source class is a simple type
2. Source class has a suffix attr and target has its own attrs
3. Target class has a suffix attr
4. Target restrictions parent attrs in different sequential order
5. Target restricts parent attr with a not matching type.
"""

if (
source.strict_type
or source.is_simple_type
source.is_simple_type
or target.has_suffix_attr
or (source.has_suffix_attr and target.attrs)
or not cls.validate_type_overrides(source, target)
Expand Down
5 changes: 0 additions & 5 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,6 @@ class Class:
:param mixed:
:param abstract:
:param nillable:
:param strict_type:
:param status:
:param container:
:param package:
Expand All @@ -414,7 +413,6 @@ class Class:
mixed: bool = field(default=False)
abstract: bool = field(default=False)
nillable: bool = field(default=False)
strict_type: bool = field(default=False)
status: Status = field(default=Status.RAW)
container: Optional[str] = field(default=None)
package: Optional[str] = field(default=None)
Expand Down Expand Up @@ -487,9 +485,6 @@ def is_simple_type(self) -> bool:
@property
def should_generate(self) -> bool:
"""Return whether this instance should be generated."""
if self.strict_type:
return False

return (
self.tag
in (Tag.ELEMENT, Tag.BINDING_OPERATION, Tag.BINDING_MESSAGE, Tag.MESSAGE)
Expand Down
47 changes: 32 additions & 15 deletions xsdata/codegen/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xsdata.codegen.models import Extension
from xsdata.codegen.utils import ClassUtils
from xsdata.models.enums import Tag
from xsdata.utils import collections
from xsdata.utils.collections import group_by


Expand All @@ -23,9 +24,9 @@ def process(self):
Remove if possible classes with the same qualified name.
Steps:
1. Remove classes with missing extension type.
2. Handle duplicate types.
3. Mark strict types.
1. Remove invalid classes
2. Handle duplicate types
3. Merge dummy types
"""
for classes in self.container.data.values():

Expand All @@ -36,7 +37,7 @@ def process(self):
self.handle_duplicate_types(classes)

if len(classes) > 1:
self.mark_strict_types(classes)
self.merge_global_types(classes)

def remove_invalid_classes(self, classes: List[Class]):
"""Remove from the given class list any class with missing extension
Expand Down Expand Up @@ -124,14 +125,30 @@ def find_circular_group(cls, target: Class) -> Optional[Attr]:
return None

@classmethod
def mark_strict_types(cls, classes: List[Class]):
"""If there is a class derived from xs:element update all
xs:complexTypes derived classes as strict types."""

try:
element = next(obj for obj in classes if obj.is_element)
for obj in classes:
if obj is not element and obj.is_complex:
obj.strict_type = True
except StopIteration:
pass
def merge_global_types(cls, classes: List[Class]):
"""
Merge parent-child global types.
Conditions
1. One of them is derived from xs:element
2. One of them is derived from xs:complexType
3. The xs:element is a subclass of the xs:complexType
4. The xs:element has no attributes (This can't happen in a valid schema)
"""

el = collections.first(x for x in classes if x.tag == Tag.ELEMENT)
ct = collections.first(x for x in classes if x.tag == Tag.COMPLEX_TYPE)

if (
el is None
or ct is None
or el is ct
or el.attrs
or len(el.extensions) != 1
or el.extensions[0].type.qname != el.qname
):
return

ct.namespace = el.namespace or ct.namespace
ct.help = el.help or ct.help
classes.remove(el)

0 comments on commit 4fc1ed8

Please sign in to comment.