Skip to content

Commit

Permalink
Generate complex types with name conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Feb 23, 2021
1 parent dde106c commit 0de6251
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 76 deletions.
1 change: 1 addition & 0 deletions tests/codegen/handlers/test_attribute_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def test_set_circular_flag(self, mock_is_circular_dependency):

self.processor.set_circular_flag(source, target, attr_type)
self.assertTrue(attr_type.circular)
self.assertEqual(id(source), attr_type.reference)

mock_is_circular_dependency.assert_called_once_with(source, target, set())

Expand Down
7 changes: 3 additions & 4 deletions tests/codegen/handlers/test_class_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def test_process_complex_extension_removes_extension(

mock_should_remove_extension.assert_called_once_with(source, target)
self.assertEqual(0, mock_copy_attributes.call_count)
self.assertEqual(0, extension.type.reference)

@mock.patch.object(ClassUtils, "copy_attributes")
@mock.patch.object(ClassExtensionHandler, "should_flatten_extension")
Expand All @@ -307,6 +308,7 @@ def test_process_complex_extension_copies_attributes(
source = ClassFactory.create()

self.processor.process_complex_extension(source, target, extension)
self.assertEqual(0, extension.type.reference)
mock_compare_attributes.assert_called_once_with(source, target)
mock_should_flatten_extension.assert_called_once_with(source, target, extension)

Expand All @@ -323,6 +325,7 @@ def test_process_complex_extension_ignores_extension(

self.processor.process_complex_extension(source, target, extension)
self.assertEqual(1, len(target.extensions))
self.assertEqual(id(source), extension.type.reference)

def test_find_dependency(self):
attr_type = AttrTypeFactory.create(qname="a")
Expand Down Expand Up @@ -364,10 +367,6 @@ def test_should_flatten_extension(self):

self.assertFalse(self.processor.should_flatten_extension(source, target))

# Forced flattened
source.strict_type = True
self.assertTrue(self.processor.should_flatten_extension(source, target))

# Source has suffix attr and target has its own attrs
source = ClassFactory.elements(1)
source.attrs[0].index = sys.maxsize
Expand Down
3 changes: 0 additions & 3 deletions tests/codegen/models/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,3 @@ def test_property_should_generate(self):

obj = ClassFactory.create(tag=Tag.SIMPLE_TYPE)
self.assertFalse(obj.should_generate)

obj = ClassFactory.create(tag=Tag.BINDING_MESSAGE, strict_type=True)
self.assertFalse(obj.should_generate)
21 changes: 11 additions & 10 deletions tests/codegen/test_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,15 @@ def test_rename_class(self, mock_rename_class_dependencies):
self.assertEqual("_a", target.meta_name)

mock_rename_class_dependencies.assert_has_calls(
mock.call(item, "{foo}_a", "{foo}_a_3")
mock.call(item, id(target), "{foo}_a_3")
for item in self.sanitizer.container.iterate()
)

self.assertEqual([target], self.container.data["{foo}_a_3"])
self.assertEqual([], self.container.data["{foo}_a"])

def test_rename_class_dependencies(self):
attr_type = AttrTypeFactory.create("{foo}bar")
attr_type = AttrTypeFactory.create(qname="{foo}bar", reference=1)

target = ClassFactory.create(
extensions=[
Expand All @@ -428,29 +428,30 @@ def test_rename_class_dependencies(self):
],
)

self.sanitizer.rename_class_dependencies(target, "{foo}bar", "thug")
self.sanitizer.rename_class_dependencies(target, 1, "thug")
dependencies = set(target.dependencies())
self.assertNotIn("{foo}bar", dependencies)
self.assertIn("thug", dependencies)

def test_rename_attr_dependencies_with_default_enum(self):
attr_type = AttrTypeFactory.create("{foo}bar")
attr_type = AttrTypeFactory.create(qname="{foo}bar", reference=1)
target = ClassFactory.create(
attrs=[
AttrFactory.create(
types=[attr_type], default=f"@enum@{attr_type.qname}::member"
types=[attr_type],
default=f"@enum@{attr_type.qname}::member",
),
]
)

self.sanitizer.rename_class_dependencies(target, "{foo}bar", "thug")
self.sanitizer.rename_class_dependencies(target, 1, "thug")
dependencies = set(target.dependencies())
self.assertEqual("@enum@thug::member", target.attrs[0].default)
self.assertNotIn("{foo}bar", dependencies)
self.assertIn("thug", dependencies)

def test_rename_attr_dependencies_with_choices(self):
attr_type = AttrTypeFactory.create("{foo}bar")
attr_type = AttrTypeFactory.create(qname="foo", reference=1)
target = ClassFactory.create(
attrs=[
AttrFactory.create(
Expand All @@ -461,10 +462,10 @@ def test_rename_attr_dependencies_with_choices(self):
]
)

self.sanitizer.rename_class_dependencies(target, "{foo}bar", "thug")
self.sanitizer.rename_class_dependencies(target, 1, "bar")
dependencies = set(target.dependencies())
self.assertNotIn("{foo}bar", dependencies)
self.assertIn("thug", dependencies)
self.assertNotIn("foo", dependencies)
self.assertIn("bar", dependencies)

@mock.patch.object(ClassSanitizer, "group_fields")
def test_group_compound_fields(self, mock_group_fields):
Expand Down
20 changes: 0 additions & 20 deletions tests/codegen/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ def setUp(self):
self.container = ClassContainer()
self.validator = ClassValidator(container=self.container)

@mock.patch.object(ClassValidator, "mark_strict_types")
@mock.patch.object(ClassValidator, "handle_duplicate_types")
@mock.patch.object(ClassValidator, "remove_invalid_classes")
def test_process(
self,
mock_remove_invalid_classes,
mock_handle_duplicate_types,
mock_mark_strict_types,
):
first = ClassFactory.create()
second = first.clone()
Expand All @@ -37,7 +35,6 @@ def test_process(

mock_remove_invalid_classes.assert_called_once_with([first, second])
mock_handle_duplicate_types.assert_called_once_with([first, second])
mock_mark_strict_types.assert_called_once_with([first, second])

def test_remove_invalid_classes(self):
first = ClassFactory.create(
Expand Down Expand Up @@ -98,23 +95,6 @@ def test_handle_duplicate_types_with_redefined_type(
[mock.call(two, one), mock.call(three, one)]
)

def test_mark_strict_types(self):
one = ClassFactory.create(qname="foo", tag=Tag.ELEMENT)
two = ClassFactory.create(qname="foo", tag=Tag.COMPLEX_TYPE)
three = ClassFactory.create(qname="foo", tag=Tag.SIMPLE_TYPE)

self.validator.mark_strict_types([one, two, three])

self.assertFalse(one.strict_type) # Is an element
self.assertTrue(two.strict_type) # Marked as abstract
self.assertFalse(three.strict_type) # Is common

four = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE)
five = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
self.validator.mark_strict_types([four, five])
self.assertFalse(four.strict_type) # No element in group
self.assertFalse(five.strict_type) # No element in group

@mock.patch.object(ClassUtils, "copy_extensions")
@mock.patch.object(ClassUtils, "copy_attributes")
def test_merge_redefined_type_with_circular_extension(
Expand Down
4 changes: 2 additions & 2 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def create(
abstract=False,
mixed=False,
nillable=False,
strict_type=False,
help=None,
extensions=None,
substitutions=None,
Expand All @@ -96,7 +95,6 @@ def create(
abstract=abstract,
mixed=mixed,
nillable=nillable,
strict_type=strict_type,
tag=tag or random.choice(cls.tags),
extensions=extensions or [],
substitutions=substitutions or [],
Expand Down Expand Up @@ -173,6 +171,7 @@ def create(
native=False,
forward=False,
circular=False,
reference=0,
):
if not qname:
qname = build_qname("xsdata", f"attr_{cls.next_letter()}")
Expand All @@ -183,6 +182,7 @@ def create(
native=native,
circular=circular,
forward=forward,
reference=reference,
)

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/hello/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from tests.fixtures.hello.hello import (
HelloByeError,
HelloByeError1,
HelloError,
HelloError1,
HelloGetHelloAsString,
HelloGetHelloAsStringInput,
HelloGetHelloAsStringOutput,
Expand Down
20 changes: 16 additions & 4 deletions tests/fixtures/hello/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


@dataclass
class HelloByeError:
class HelloByeError1:
class Meta:
namespace = "http://hello/"
name = "HelloByeError"

message: Optional[str] = field(
default=None,
Expand All @@ -19,9 +19,9 @@ class Meta:


@dataclass
class HelloError:
class HelloError1:
class Meta:
namespace = "http://hello/"
name = "HelloError"

message: Optional[str] = field(
default=None,
Expand Down Expand Up @@ -61,6 +61,18 @@ class Meta:
)


@dataclass
class HelloByeError(HelloByeError1):
class Meta:
namespace = "http://hello/"


@dataclass
class HelloError(HelloError1):
class Meta:
namespace = "http://hello/"


@dataclass
class HelloGetHelloAsStringInput:
class Meta:
Expand Down
2 changes: 2 additions & 0 deletions xsdata/codegen/handlers/attribute_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def process_dependency_type(self, target: Class, attr: Attr, attr_type: AttrType
attr.restrictions.format = collections.first(
x.restrictions.format for x in source.attrs if x.restrictions.format
)
attr_type.reference = id(source)
else:
self.set_circular_flag(source, target, attr_type)

Expand Down Expand Up @@ -161,6 +162,7 @@ def copy_attribute_properties(

def set_circular_flag(self, source: Class, target: Class, attr_type: AttrType):
"""Update circular reference flag."""
attr_type.reference = id(source)
attr_type.circular = self.is_circular_dependency(source, target, set())

def is_circular_dependency(self, source: Class, target: Class, seen: Set) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions xsdata/codegen/handlers/class_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class or leave the extension alone.
elif cls.should_flatten_extension(source, target):
ClassUtils.copy_attributes(source, target, ext)
else:
ext.type.reference = id(source)
logger.debug("Ignore extension: %s", ext.type.name)

def find_dependency(self, attr_type: AttrType) -> Optional[Class]:
Expand Down Expand Up @@ -192,8 +193,7 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool:
"""

if (
source.strict_type
or source.is_simple_type
source.is_simple_type
or target.has_suffix_attr
or (source.has_suffix_attr and target.attrs)
or not cls.validate_type_overrides(source, target)
Expand Down
7 changes: 2 additions & 5 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,15 @@ class AttrType:
:param qname:
:param alias:
:param reference:
:param native:
:param forward:
:param circular:
"""

qname: str
alias: Optional[str] = field(default=None, compare=False)
reference: int = field(default=0, compare=False)
native: bool = field(default=False)
forward: bool = field(default=False)
circular: bool = field(default=False)
Expand Down Expand Up @@ -392,7 +394,6 @@ class Class:
:param mixed:
:param abstract:
:param nillable:
:param strict_type:
:param status:
:param container:
:param package:
Expand All @@ -414,7 +415,6 @@ class Class:
mixed: bool = field(default=False)
abstract: bool = field(default=False)
nillable: bool = field(default=False)
strict_type: bool = field(default=False)
status: Status = field(default=Status.RAW)
container: Optional[str] = field(default=None)
package: Optional[str] = field(default=None)
Expand Down Expand Up @@ -487,9 +487,6 @@ def is_simple_type(self) -> bool:
@property
def should_generate(self) -> bool:
"""Return whether this instance should be generated."""
if self.strict_type:
return False

return (
self.tag
in (Tag.ELEMENT, Tag.BINDING_OPERATION, Tag.BINDING_MESSAGE, Tag.MESSAGE)
Expand Down
20 changes: 11 additions & 9 deletions xsdata/codegen/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xsdata.models.enums import DataType
from xsdata.models.enums import Tag
from xsdata.utils import collections
from xsdata.utils import text
from xsdata.utils.collections import first
from xsdata.utils.collections import group_by
from xsdata.utils.namespaces import build_qname
Expand Down Expand Up @@ -183,7 +184,7 @@ def rename_class(self, target: Class):
self.container.reset(target, qname)

for item in self.container.iterate():
self.rename_class_dependencies(item, qname, target.qname)
self.rename_class_dependencies(item, id(target), target.qname)

def next_qname(self, namespace: str, name: str) -> str:
"""Append the next available index number for the given namespace and
Expand All @@ -196,32 +197,33 @@ def next_qname(self, namespace: str, name: str) -> str:
if alnum(qname) not in reserved:
return qname

def rename_class_dependencies(self, target: Class, search: str, replace: str):
def rename_class_dependencies(self, target: Class, reference: int, replace: str):
"""Search and replace the old qualified attribute type name with the
new one if it exists in the target class attributes, extensions and
inner classes."""
for attr in target.attrs:
self.rename_attr_dependencies(attr, search, replace)
self.rename_attr_dependencies(attr, reference, replace)

for ext in target.extensions:
if ext.type.qname == search:
if ext.type.reference == reference:
ext.type.qname = replace

for inner in target.inner:
self.rename_class_dependencies(inner, search, replace)
self.rename_class_dependencies(inner, reference, replace)

def rename_attr_dependencies(self, attr: Attr, search: str, replace: str):
def rename_attr_dependencies(self, attr: Attr, reference: int, replace: str):
"""Search and replace the old qualified attribute type name with the
new one in the attr types, choices and default value."""
for attr_type in attr.types:
if attr_type.qname == search:
if attr_type.reference == reference:
attr_type.qname = replace

if isinstance(attr.default, str) and attr.default.startswith("@enum@"):
attr.default = attr.default.replace(search, replace)
member = text.suffix(attr.default, "::")
attr.default = f"@enum@{replace}::{member}"

for choice in attr.choices:
self.rename_attr_dependencies(choice, search, replace)
self.rename_attr_dependencies(choice, reference, replace)

def find_enum(self, attr_type: AttrType) -> Optional[Class]:
"""Find an enumeration class byte the attribute type."""
Expand Down
Loading

0 comments on commit 0de6251

Please sign in to comment.