Skip to content

Commit

Permalink
Stricter duplicate class/attr name comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Feb 21, 2021
1 parent 99fd2c3 commit 7153e6e
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 64 deletions.
4 changes: 2 additions & 2 deletions tests/codegen/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_resolve_imports(
):
class_life = ClassFactory.create(qname="life")
import_names = [
"foo", # cool
"foo_1", # cool
"bar", # cool
"{another}foo", # another foo
"{another}foo1", # another foo
"{thug}life", # life class exists add alias
"{common}type", # type class doesn't exist add just the name
]
Expand Down
8 changes: 8 additions & 0 deletions tests/utils/test_text.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import TestCase

from xsdata.utils.text import alnum
from xsdata.utils.text import camel_case
from xsdata.utils.text import capitalize
from xsdata.utils.text import mixed_case
Expand Down Expand Up @@ -99,3 +100,10 @@ def test_split_words(self):
self.assertEqual(["user"], split_words("__user"))
self.assertEqual(["TMessage", "DB"], split_words("TMessageDB"))
self.assertEqual(["GLOBAL", "REF"], split_words("GLOBAL-REF"))

def test_alnum(self):
self.assertEqual("foo1", alnum("foo 1"))
self.assertEqual("foo1", alnum(" foo_1 "))
self.assertEqual("foo1", alnum("\tfoo*1"))
self.assertEqual("foo1", alnum(" foo*1"))
self.assertEqual("βιβλίο1", alnum(" βιβλίο*1"))
12 changes: 6 additions & 6 deletions xsdata/codegen/mappers/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from xsdata.utils import text
from xsdata.utils.collections import first
from xsdata.utils.namespaces import build_qname
from xsdata.utils.namespaces import split_qname
from xsdata.utils.namespaces import local_name


class DefinitionsMapper:
Expand Down Expand Up @@ -216,10 +216,10 @@ def build_envelope_class(

for ext in binding_message.extended_elements:
assert ext.qname is not None
local_name = split_qname(ext.qname)[1].title()
inner = cls.build_inner_class(target, local_name)
class_name = local_name(ext.qname).title()
inner = cls.build_inner_class(target, class_name)

if style == "rpc" and local_name == "Body":
if style == "rpc" and class_name == "Body":
namespace = ext.attributes.get("namespace")
attrs = cls.map_port_type_message(port_type_message, namespace)
else:
Expand Down Expand Up @@ -300,7 +300,7 @@ def map_binding_message_parts(
parts.extend(extended.attributes["parts"].split())

if "message" in extended.attributes:
message_name = split_qname(extended.attributes["message"])[1]
message_name = local_name(extended.attributes["message"])
else:
message_name = text.suffix(message)

Expand Down Expand Up @@ -352,7 +352,7 @@ def operation_namespace(cls, config: Dict) -> Optional[str]:
def attributes(cls, elements: Iterator[AnyElement]) -> Dict:
"""Return all attributes from all extended elements as a dictionary."""
return {
split_qname(qname)[1]: value
local_name(qname): value
for element in elements
if isinstance(element, AnyElement)
for qname, value in element.attributes.items()
Expand Down
9 changes: 5 additions & 4 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from xsdata.models.enums import Tag
from xsdata.models.mixins import ElementBase
from xsdata.utils.namespaces import build_qname
from xsdata.utils.namespaces import split_qname
from xsdata.utils.namespaces import local_name
from xsdata.utils.namespaces import target_uri

xml_type_map = {
Tag.ANY: XmlType.WILDCARD,
Expand Down Expand Up @@ -209,7 +210,7 @@ class AttrType:
@property
def name(self) -> str:
"""Shortcut for qname local name."""
return split_qname(self.qname)[1]
return local_name(self.qname)

@property
def is_dependency(self) -> bool:
Expand Down Expand Up @@ -431,11 +432,11 @@ class Class:
@property
def name(self) -> str:
"""Shortcut for qname local name."""
return split_qname(self.qname)[1]
return local_name(self.qname)

@property
def target_namespace(self) -> Optional[str]:
return split_qname(self.qname)[0]
return target_uri(self.qname)

@property
def has_suffix_attr(self) -> bool:
Expand Down
17 changes: 9 additions & 8 deletions xsdata/codegen/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from xsdata.codegen.models import Import
from xsdata.exceptions import ResolverValueError
from xsdata.utils import collections
from xsdata.utils.namespaces import split_qname
from xsdata.utils.namespaces import local_name
from xsdata.utils.text import alnum

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,25 +74,25 @@ def apply_aliases(self, target: Class):
def resolve_imports(self):
"""Walk the import qualified names, check for naming collisions and add
the necessary code generator import instance."""
local_names = {split_qname(qname)[1] for qname in self.class_map.keys()}
existing = {alnum(local_name(qname)) for qname in self.class_map.keys()}
for qname in self.import_classes():
package = self.find_package(qname)
local_name = split_qname(qname)[1]
exists = local_name in local_names
local_names.add(local_name)
name = alnum(local_name(qname))
exists = name in existing
existing.add(name)
self.add_import(qname=qname, package=package, exists=exists)

def add_import(self, qname: str, package: str, exists: bool = False):
"""Append an import package to the list of imports with any if
necessary aliases if the import name exists in the local module."""
alias = None
local_name = split_qname(qname)[1]
name = local_name(qname)
if exists:
module = package.split(".")[-1]
alias = f"{module}:{local_name}"
alias = f"{module}:{name}"
self.aliases[qname] = alias

self.imports.append(Import(name=local_name, source=package, alias=alias))
self.imports.append(Import(name=name, source=package, alias=alias))

def find_package(self, qname: str) -> str:
"""
Expand Down
19 changes: 8 additions & 11 deletions xsdata/codegen/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from xsdata.models.enums import DataType
from xsdata.models.enums import Tag
from xsdata.utils import collections
from xsdata.utils import text
from xsdata.utils.collections import first
from xsdata.utils.collections import group_by
from xsdata.utils.namespaces import build_qname
from xsdata.utils.namespaces import clean_uri
from xsdata.utils.namespaces import split_qname
from xsdata.utils.text import alnum


@dataclass
Expand Down Expand Up @@ -154,7 +154,7 @@ def process_attribute_default_enum(self, target: Class, attr: Attr):
def resolve_conflicts(self):
"""Find classes with the same case insensitive qualified name and
rename them."""
groups = group_by(self.container.iterate(), lambda x: text.snake_case(x.qname))
groups = group_by(self.container.iterate(), lambda x: alnum(x.qname))
for classes in groups.values():
if len(classes) > 1:
self.rename_classes(classes)
Expand Down Expand Up @@ -189,11 +189,11 @@ def next_qname(self, namespace: str, name: str) -> str:
"""Append the next available index number for the given namespace and
local name."""
index = 0
reserved = map(text.snake_case, self.container.data.keys())
reserved = set(map(alnum, self.container.data.keys()))
while True:
index += 1
qname = build_qname(namespace, f"{name}_{index}")
if text.snake_case(qname) not in reserved:
if alnum(qname) not in reserved:
return qname

def rename_class_dependencies(self, target: Class, search: str, replace: str):
Expand Down Expand Up @@ -291,7 +291,7 @@ def process_attribute_sequence(cls, target: Class, attr: Attr):
def process_duplicate_attribute_names(cls, attrs: List[Attr]) -> None:
"""Sanitize duplicate attribute names that might exist by applying
rename strategies."""
grouped = group_by(attrs, lambda attr: text.snake_case(attr.name))
grouped = group_by(attrs, lambda attr: alnum(attr.name))
for items in grouped.values():
total = len(items)
if total == 2 and not items[0].is_enumeration:
Expand All @@ -305,15 +305,12 @@ def rename_attributes_with_index(cls, attrs: List[Attr], rename: List[Attr]):
names."""
for index in range(1, len(rename)):
num = 1
name = text.snake_case(rename[index].name)
name = rename[index].name

while any(
text.snake_case(attr.name) == text.snake_case(f"{name}_{num}")
for attr in attrs
):
while any(alnum(attr.name) == alnum(f"{name}_{num}") for attr in attrs):
num += 1

rename[index].name = f"{rename[index].name}_{num}"
rename[index].name = f"{name}_{num}"

@classmethod
def rename_attribute_by_preference(cls, a: Attr, b: Attr):
Expand Down
4 changes: 2 additions & 2 deletions xsdata/formats/dataclass/parsers/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from xsdata.formats.dataclass.parsers.nodes import NodeParser
from xsdata.formats.dataclass.parsers.nodes import Parsed
from xsdata.models.enums import EventType
from xsdata.utils.namespaces import split_qname
from xsdata.utils.namespaces import local_name
from xsdata.utils.text import snake_case


Expand Down Expand Up @@ -101,7 +101,7 @@ def emit_event(self, event: str, name: str, **kwargs: Any):

key = (event, name)
if key not in self.emit_cache:
method_name = f"{event}_{snake_case(split_qname(name)[1])}"
method_name = f"{event}_{snake_case(local_name(name))}"
self.emit_cache[key] = getattr(self, method_name, None)

method = self.emit_cache[key]
Expand Down
14 changes: 11 additions & 3 deletions xsdata/utils/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Tuple

from xsdata.models.enums import Namespace
from xsdata.utils.text import split
from xsdata.utils import text


__uri_ignore__ = ("www", "xsd", "wsdl")
Expand Down Expand Up @@ -61,7 +61,7 @@ def clean_uri(namespace: str) -> str:
if namespace[:2] == "##":
namespace = namespace[2:]

left, right = split(namespace)
left, right = text.split(namespace)

if left == "urn":
namespace = right
Expand All @@ -88,8 +88,16 @@ def build_qname(tag_or_uri: Optional[str], tag: Optional[str] = None) -> str:
def split_qname(tag: str) -> Tuple:
"""Split namespace qualified strings."""
if tag[0] == "{":
left, right = split(tag[1:], "}")
left, right = text.split(tag[1:], "}")
if left:
return left, right

return None, tag


def target_uri(tag: str) -> Optional[str]:
return split_qname(tag)[0]


def local_name(tag: str) -> str:
return split_qname(tag)[1]
Loading

0 comments on commit 7153e6e

Please sign in to comment.