diff --git a/tests/codegen/handlers/test_class_extension.py b/tests/codegen/handlers/test_class_extension.py index 21732feb6..d01c293c6 100644 --- a/tests/codegen/handlers/test_class_extension.py +++ b/tests/codegen/handlers/test_class_extension.py @@ -364,10 +364,6 @@ def test_should_flatten_extension(self): self.assertFalse(self.processor.should_flatten_extension(source, target)) - # Forced flattened - source.strict_type = True - self.assertTrue(self.processor.should_flatten_extension(source, target)) - # Source has suffix attr and target has its own attrs source = ClassFactory.elements(1) source.attrs[0].index = sys.maxsize diff --git a/tests/codegen/models/test_class.py b/tests/codegen/models/test_class.py index 55bdbb8c8..80fffa1c1 100644 --- a/tests/codegen/models/test_class.py +++ b/tests/codegen/models/test_class.py @@ -152,6 +152,3 @@ def test_property_should_generate(self): obj = ClassFactory.create(tag=Tag.SIMPLE_TYPE) self.assertFalse(obj.should_generate) - - obj = ClassFactory.create(tag=Tag.BINDING_MESSAGE, strict_type=True) - self.assertFalse(obj.should_generate) diff --git a/tests/codegen/test_validator.py b/tests/codegen/test_validator.py index 2d9e6cf84..28a8ebf95 100644 --- a/tests/codegen/test_validator.py +++ b/tests/codegen/test_validator.py @@ -19,14 +19,14 @@ def setUp(self): self.container = ClassContainer() self.validator = ClassValidator(container=self.container) - @mock.patch.object(ClassValidator, "mark_strict_types") + @mock.patch.object(ClassValidator, "merge_global_types") @mock.patch.object(ClassValidator, "handle_duplicate_types") @mock.patch.object(ClassValidator, "remove_invalid_classes") def test_process( self, mock_remove_invalid_classes, mock_handle_duplicate_types, - mock_mark_strict_types, + mock_merge_global_types, ): first = ClassFactory.create() second = first.clone() @@ -37,7 +37,7 @@ def test_process( mock_remove_invalid_classes.assert_called_once_with([first, second]) mock_handle_duplicate_types.assert_called_once_with([first, second]) - mock_mark_strict_types.assert_called_once_with([first, second]) + mock_merge_global_types.assert_called_once_with([first, second]) def test_remove_invalid_classes(self): first = ClassFactory.create( @@ -98,22 +98,42 @@ def test_handle_duplicate_types_with_redefined_type( [mock.call(two, one), mock.call(three, one)] ) - def test_mark_strict_types(self): - one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT) + def test_merge_global_types(self): + one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT, namespace="a", help="b") two = ClassFactory.create(qname="foo", tag=Tag.COMPLEX_TYPE) three = ClassFactory.create(qname="foo", tag=Tag.SIMPLE_TYPE) - self.validator.mark_strict_types([one, two, three]) - - self.assertFalse(one.strict_type) # Is an element - self.assertTrue(two.strict_type) # Marked as abstract - self.assertFalse(three.strict_type) # Is common - - four = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE) - five = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) - self.validator.mark_strict_types([four, five]) - self.assertFalse(four.strict_type) # No element in group - self.assertFalse(five.strict_type) # No element in group + classes = [one, two, three] + self.validator.merge_global_types(classes) + self.assertEqual(3, len(classes)) + + classes = [one, three] + self.validator.merge_global_types(classes) + self.assertEqual(2, len(classes)) + + classes = [two, three] + self.validator.merge_global_types(classes) + self.assertEqual(2, len(classes)) + + classes = [one, two, three] + one.attrs.append(AttrFactory.create) + one.extensions.append(ExtensionFactory.reference(two.qname)) + + self.validator.merge_global_types(classes) + self.assertEqual(3, len(classes)) + + one.attrs.clear() + one.extensions.append(ExtensionFactory.reference("foo")) + self.validator.merge_global_types(classes) + self.assertEqual(3, len(classes)) + + one.extensions.pop() + self.validator.merge_global_types(classes) + self.assertEqual(2, len(classes)) + self.assertIn(two, classes) + self.assertIn(three, classes) + self.assertEqual(one.namespace, two.namespace) + self.assertEqual(one.help, two.help) @mock.patch.object(ClassUtils, "copy_extensions") @mock.patch.object(ClassUtils, "copy_attributes") diff --git a/tests/factories.py b/tests/factories.py index 527403ecb..c1fc3c094 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -72,7 +72,6 @@ def create( abstract=False, mixed=False, nillable=False, - strict_type=False, help=None, extensions=None, substitutions=None, @@ -96,7 +95,6 @@ def create( abstract=abstract, mixed=mixed, nillable=nillable, - strict_type=strict_type, tag=tag or random.choice(cls.tags), extensions=extensions or [], substitutions=substitutions or [], diff --git a/xsdata/codegen/handlers/class_extension.py b/xsdata/codegen/handlers/class_extension.py index e61c0a7df..aaf228d4c 100644 --- a/xsdata/codegen/handlers/class_extension.py +++ b/xsdata/codegen/handlers/class_extension.py @@ -183,17 +183,15 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool: Return whether the extension should be flattened because of rules. Rules: - 1. Source class is marked as a strict type - 2. Source class is a simple type - 3. Source class has a suffix attr and target has its own attrs - 4. Target class has a suffix attr - 5. Target restrictions parent attrs in different sequential order - 6. Target restricts parent attr with a not matching type. + 1. Source class is a simple type + 2. Source class has a suffix attr and target has its own attrs + 3. Target class has a suffix attr + 4. Target restrictions parent attrs in different sequential order + 5. Target restricts parent attr with a not matching type. """ if ( - source.strict_type - or source.is_simple_type + source.is_simple_type or target.has_suffix_attr or (source.has_suffix_attr and target.attrs) or not cls.validate_type_overrides(source, target) diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index fa222aa5a..94e60c89e 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -392,7 +392,6 @@ class Class: :param mixed: :param abstract: :param nillable: - :param strict_type: :param status: :param container: :param package: @@ -414,7 +413,6 @@ class Class: mixed: bool = field(default=False) abstract: bool = field(default=False) nillable: bool = field(default=False) - strict_type: bool = field(default=False) status: Status = field(default=Status.RAW) container: Optional[str] = field(default=None) package: Optional[str] = field(default=None) @@ -487,9 +485,6 @@ def is_simple_type(self) -> bool: @property def should_generate(self) -> bool: """Return whether this instance should be generated.""" - if self.strict_type: - return False - return ( self.tag in (Tag.ELEMENT, Tag.BINDING_OPERATION, Tag.BINDING_MESSAGE, Tag.MESSAGE) diff --git a/xsdata/codegen/validator.py b/xsdata/codegen/validator.py index 0b6fa68b7..fb9e4b110 100644 --- a/xsdata/codegen/validator.py +++ b/xsdata/codegen/validator.py @@ -8,6 +8,7 @@ from xsdata.codegen.models import Extension from xsdata.codegen.utils import ClassUtils from xsdata.models.enums import Tag +from xsdata.utils import collections from xsdata.utils.collections import group_by @@ -23,9 +24,9 @@ def process(self): Remove if possible classes with the same qualified name. Steps: - 1. Remove classes with missing extension type. - 2. Handle duplicate types. - 3. Mark strict types. + 1. Remove invalid classes + 2. Handle duplicate types + 3. Merge dummy types """ for classes in self.container.data.values(): @@ -36,7 +37,7 @@ def process(self): self.handle_duplicate_types(classes) if len(classes) > 1: - self.mark_strict_types(classes) + self.merge_global_types(classes) def remove_invalid_classes(self, classes: List[Class]): """Remove from the given class list any class with missing extension @@ -124,14 +125,30 @@ def find_circular_group(cls, target: Class) -> Optional[Attr]: return None @classmethod - def mark_strict_types(cls, classes: List[Class]): - """If there is a class derived from xs:element update all - xs:complexTypes derived classes as strict types.""" - - try: - element = next(obj for obj in classes if obj.is_element) - for obj in classes: - if obj is not element and obj.is_complex: - obj.strict_type = True - except StopIteration: - pass + def merge_global_types(cls, classes: List[Class]): + """ + Merge parent-child global types. + + Conditions + 1. One of them is derived from xs:element + 2. One of them is derived from xs:complexType + 3. The xs:element is a subclass of the xs:complexType + 4. The xs:element has no attributes (This can't happen in a valid schema) + """ + + el = collections.first(x for x in classes if x.tag == Tag.ELEMENT) + ct = collections.first(x for x in classes if x.tag == Tag.COMPLEX_TYPE) + + if ( + el is None + or ct is None + or el is ct + or el.attrs + or len(el.extensions) != 1 + or el.extensions[0].type.qname != el.qname + ): + return + + ct.namespace = el.namespace or ct.namespace + ct.help = el.help or ct.help + classes.remove(el)