From 542b0a82ccb31f4e56025cd767bfb2241ebd347d Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 19 Feb 2021 22:29:49 +0200 Subject: [PATCH 1/5] Root xs:element dont have occurrences --- tests/codegen/parsers/test_schema.py | 12 +++ xsdata/codegen/parsers/schema.py | 125 ++++++++++++++------------- xsdata/models/xsd.py | 4 +- 3 files changed, 80 insertions(+), 61 deletions(-) diff --git a/tests/codegen/parsers/test_schema.py b/tests/codegen/parsers/test_schema.py index ffaf4743a..46efceae2 100644 --- a/tests/codegen/parsers/test_schema.py +++ b/tests/codegen/parsers/test_schema.py @@ -403,8 +403,20 @@ def test_end_schema( mock_resolve_schemas_locations, ): schema = Schema() + schema.elements.append(Element()) + schema.elements.append(Element()) + schema.elements.append(Element()) + + for el in schema.elements: + self.assertEqual(1, el.min_occurs) + self.assertEqual(1, el.max_occurs) self.parser.end_schema(schema) + + for el in schema.elements: + self.assertIsNone(el.min_occurs) + self.assertIsNone(el.max_occurs) + self.parser.end_schema(ComplexType()) mock_set_schema_forms.assert_called_once_with(schema) diff --git a/xsdata/codegen/parsers/schema.py b/xsdata/codegen/parsers/schema.py index ba9b6c782..aa05b86b4 100644 --- a/xsdata/codegen/parsers/schema.py +++ b/xsdata/codegen/parsers/schema.py @@ -82,6 +82,66 @@ def start_schema(self, attrs: Dict): self.attribute_form = attrs.get("attributeFormDefault", None) self.default_attributes = attrs.get("defaultAttributes", None) + def end_schema(self, obj: T): + """Normalize various properties for the schema and it's children.""" + if isinstance(obj, xsd.Schema): + self.set_schema_forms(obj) + self.set_schema_namespaces(obj) + self.add_default_imports(obj) + self.resolve_schemas_locations(obj) + self.reset_element_occurs(obj) + + def end_attribute(self, obj: T): + """Assign the schema's default form for attributes if the given + attribute form is None.""" + if isinstance(obj, xsd.Attribute) and obj.form is None and self.attribute_form: + obj.form = FormType(self.attribute_form) + + def end_complex_type(self, obj: T): + """Prepend an attribute group reference when default attributes + apply.""" + if not isinstance(obj, xsd.ComplexType): + return + + if obj.default_attributes_apply and self.default_attributes: + attribute_group = xsd.AttributeGroup(ref=self.default_attributes) + obj.attribute_groups.insert(0, attribute_group) + + if not obj.open_content and not obj.complex_content: + obj.open_content = self.default_open_content + + def end_default_open_content(self, obj: T): + """Set the instance default open content to be used later as a property + for all extensions and restrictions.""" + if isinstance(obj, xsd.DefaultOpenContent): + if obj.any and obj.mode == Mode.SUFFIX: + obj.any.index = sys.maxsize + + self.default_open_content = obj + + def end_element(self, obj: T): + """Assign the schema's default form for elements if the given element + form is None.""" + if isinstance(obj, xsd.Element) and obj.form is None and self.element_form: + obj.form = FormType(self.element_form) + + def end_extension(self, obj: T): + """Set the open content if any to the given extension.""" + if isinstance(obj, xsd.Extension) and not obj.open_content: + obj.open_content = self.default_open_content + + @classmethod + def end_open_content(cls, obj: T): + """Adjust the index to trick later processors into putting attributes + derived from this open content last in classes.""" + if isinstance(obj, xsd.OpenContent) and obj.any and obj.mode == Mode.SUFFIX: + obj.any.index = sys.maxsize + + def end_restriction(self, obj: T): + """Set the open content if any to the given restriction.""" + if isinstance(obj, xsd.Restriction) and not obj.open_content: + obj.open_content = self.default_open_content + def set_schema_forms(self, obj: xsd.Schema): """ Set the default form type for elements and attributes. @@ -134,6 +194,12 @@ def add_default_imports(obj: xsd.Schema): if xsi_ns in obj.ns_map.values() and xsi_ns not in imp_namespaces: obj.imports.insert(0, xsd.Import(namespace=xsi_ns)) + @staticmethod + def reset_element_occurs(obj: xsd.Schema): + for element in obj.elements: + element.min_occurs = None + element.max_occurs = None + def resolve_schemas_locations(self, obj: xsd.Schema): """Resolve the locations of the schema overrides, redefines, includes and imports relatively to the schema location.""" @@ -172,62 +238,3 @@ def resolve_local_path( return local_path return self.resolve_path(location) - - def end_attribute(self, obj: T): - """Assign the schema's default form for attributes if the given - attribute form is None.""" - if isinstance(obj, xsd.Attribute) and obj.form is None and self.attribute_form: - obj.form = FormType(self.attribute_form) - - def end_complex_type(self, obj: T): - """Prepend an attribute group reference when default attributes - apply.""" - if not isinstance(obj, xsd.ComplexType): - return - - if obj.default_attributes_apply and self.default_attributes: - attribute_group = xsd.AttributeGroup(ref=self.default_attributes) - obj.attribute_groups.insert(0, attribute_group) - - if not obj.open_content and not obj.complex_content: - obj.open_content = self.default_open_content - - def end_default_open_content(self, obj: T): - """Set the instance default open content to be used later as a property - for all extensions and restrictions.""" - if isinstance(obj, xsd.DefaultOpenContent): - if obj.any and obj.mode == Mode.SUFFIX: - obj.any.index = sys.maxsize - - self.default_open_content = obj - - def end_element(self, obj: T): - """Assign the schema's default form for elements if the given element - form is None.""" - if isinstance(obj, xsd.Element) and obj.form is None and self.element_form: - obj.form = FormType(self.element_form) - - def end_extension(self, obj: T): - """Set the open content if any to the given extension.""" - if isinstance(obj, xsd.Extension) and not obj.open_content: - obj.open_content = self.default_open_content - - @classmethod - def end_open_content(cls, obj: T): - """Adjust the index to trick later processors into putting attributes - derived from this open content last in classes.""" - if isinstance(obj, xsd.OpenContent) and obj.any and obj.mode == Mode.SUFFIX: - obj.any.index = sys.maxsize - - def end_restriction(self, obj: T): - """Set the open content if any to the given restriction.""" - if isinstance(obj, xsd.Restriction) and not obj.open_content: - obj.open_content = self.default_open_content - - def end_schema(self, obj: T): - """Normalize various properties for the schema and it's children.""" - if isinstance(obj, xsd.Schema): - self.set_schema_forms(obj) - self.set_schema_namespaces(obj) - self.add_default_imports(obj) - self.resolve_schemas_locations(obj) diff --git a/xsdata/models/xsd.py b/xsdata/models/xsd.py index 95377e6ad..f038c106f 100644 --- a/xsdata/models/xsd.py +++ b/xsdata/models/xsd.py @@ -1083,8 +1083,8 @@ class Element(AnnotationBase): uniques: Array[Unique] = array_element(name="unique") keys: Array[Key] = array_element(name="key") keyrefs: Array[Keyref] = array_element(name="keyref") - min_occurs: int = attribute(default=1, name="minOccurs") - max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs") + min_occurs: Optional[int] = attribute(default=1, name="minOccurs") + max_occurs: UnionType[None, int, str] = attribute(default=1, name="maxOccurs") nillable: bool = attribute(default=False) abstract: bool = attribute(default=False) From f106cb9262590ca301d0dcb807a86812f07d4c9a Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 19 Feb 2021 22:32:31 +0200 Subject: [PATCH 2/5] Default value fields are required --- tests/codegen/handlers/test_class_extension.py | 4 +++- xsdata/codegen/handlers/class_extension.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/codegen/handlers/test_class_extension.py b/tests/codegen/handlers/test_class_extension.py index 8d65b26a6..21732feb6 100644 --- a/tests/codegen/handlers/test_class_extension.py +++ b/tests/codegen/handlers/test_class_extension.py @@ -447,7 +447,9 @@ def test_add_default_attribute(self): ClassExtensionHandler.add_default_attribute(item, extension) expected.types.append(xs_int) - expected_restrictions = Restrictions(tokens=True, required=True) + expected_restrictions = Restrictions( + tokens=True, required=True, min_occurs=1, max_occurs=1 + ) self.assertEqual(2, len(item.attrs)) self.assertEqual(0, len(item.extensions)) diff --git a/xsdata/codegen/handlers/class_extension.py b/xsdata/codegen/handlers/class_extension.py index f57c1a400..e61c0a7df 100644 --- a/xsdata/codegen/handlers/class_extension.py +++ b/xsdata/codegen/handlers/class_extension.py @@ -271,5 +271,7 @@ def get_or_create_attribute(cls, target: Class, name: str, tag: str) -> Attr: return attr attr = Attr(name=name, tag=tag) + attr.restrictions.min_occurs = 1 + attr.restrictions.max_occurs = 1 target.attrs.insert(0, attr) return attr From 99fd2c359894e6811200ccc6dd6d3d07c6b36bb0 Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 20 Feb 2021 21:03:21 +0200 Subject: [PATCH 3/5] Add more cases for restrictions sanitize --- tests/codegen/test_sanitizer.py | 85 ++++++++++++++------ tests/fixtures/defxmlschema/chapter03prod.py | 6 ++ tests/fixtures/defxmlschema/chapter04prod.py | 3 + tests/fixtures/defxmlschema/chapter05prod.py | 4 +- tests/fixtures/defxmlschema/chapter12.py | 3 + tests/fixtures/defxmlschema/chapter13.py | 3 + tests/fixtures/defxmlschema/chapter15.py | 3 + tests/fixtures/defxmlschema/chapter16.py | 6 ++ tests/fixtures/defxmlschema/chapter17.py | 3 + tests/models/xsd/test_attribute.py | 2 +- xsdata/codegen/mappers/definitions.py | 3 + xsdata/codegen/sanitizer.py | 27 +++++-- xsdata/models/xsd.py | 2 +- 13 files changed, 114 insertions(+), 36 deletions(-) diff --git a/tests/codegen/test_sanitizer.py b/tests/codegen/test_sanitizer.py index b22dd308a..c485734bd 100644 --- a/tests/codegen/test_sanitizer.py +++ b/tests/codegen/test_sanitizer.py @@ -93,13 +93,6 @@ def test_process_attribute_default_with_enumeration(self): self.sanitizer.process_attribute_default(target, attr) self.assertTrue(attr.fixed) - def test_process_attribute_default_with_list_field(self): - target = ClassFactory.create() - attr = AttrFactory.create(fixed=True) - attr.restrictions.max_occurs = 2 - self.sanitizer.process_attribute_default(target, attr) - self.assertFalse(attr.fixed) - def test_process_attribute_default_with_optional_field(self): target = ClassFactory.create() attr = AttrFactory.create(fixed=True, default=2) @@ -208,25 +201,67 @@ def test_find_enum(self): self.assertIsNone(actual) def test_process_attribute_restrictions(self): - restrictions = [ - Restrictions(min_occurs=0, max_occurs=0, required=True), - Restrictions(min_occurs=0, max_occurs=1, required=True), - Restrictions(min_occurs=1, max_occurs=1, required=False), - Restrictions(max_occurs=2, required=True), - Restrictions(min_occurs=2, max_occurs=2, required=True), - ] - expected = [ - {}, - {}, - {"required": True}, - {"max_occurs": 2}, - {"max_occurs": 2, "min_occurs": 2}, - ] + required = Restrictions(min_occurs=1, max_occurs=1) + attr = AttrFactory.attribute(restrictions=required.clone()) + self.sanitizer.process_attribute_restrictions(attr) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + tokens = Restrictions(required=True, tokens=True, min_occurs=1, max_occurs=1) + attr = AttrFactory.element(restrictions=tokens.clone()) + self.sanitizer.process_attribute_restrictions(attr) + self.assertFalse(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + attr = AttrFactory.element(restrictions=tokens.clone()) + attr.restrictions.max_occurs = 2 + self.sanitizer.process_attribute_restrictions(attr) + self.assertFalse(attr.restrictions.required) + self.assertIsNotNone(attr.restrictions.min_occurs) + self.assertIsNotNone(attr.restrictions.max_occurs) + + multiple = Restrictions(min_occurs=0, max_occurs=2) + attr = AttrFactory.create(tag=Tag.EXTENSION, restrictions=multiple) + self.sanitizer.process_attribute_restrictions(attr) + self.assertTrue(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + multiple = Restrictions(max_occurs=2, required=True) + attr = AttrFactory.element(restrictions=multiple, fixed=True) + self.sanitizer.process_attribute_restrictions(attr) + self.assertIsNone(attr.restrictions.required) + self.assertEqual(0, attr.restrictions.min_occurs) + self.assertFalse(attr.fixed) + + attr = AttrFactory.element(restrictions=required.clone()) + self.sanitizer.process_attribute_restrictions(attr) + self.assertTrue(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + restrictions = Restrictions(required=True, min_occurs=0, max_occurs=1) + attr = AttrFactory.element(restrictions=restrictions, default="A", fixed=True) + self.sanitizer.process_attribute_restrictions(attr) + self.assertIsNone(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + self.assertIsNone(attr.default) + self.assertFalse(attr.fixed) + + attr = AttrFactory.element(restrictions=required.clone(), default="A") + self.sanitizer.process_attribute_restrictions(attr) + self.assertIsNone(attr.restrictions.required) + + attr = AttrFactory.element(restrictions=required.clone(), fixed=True) + self.sanitizer.process_attribute_restrictions(attr) + self.assertIsNone(attr.restrictions.required) - for idx, res in enumerate(restrictions): - attr = AttrFactory.create(restrictions=res) - self.sanitizer.process_attribute_restrictions(attr) - self.assertEqual(expected[idx], res.asdict()) + attr = AttrFactory.element(restrictions=required.clone()) + attr.restrictions.nillable = True + self.sanitizer.process_attribute_restrictions(attr) + self.assertIsNone(attr.restrictions.required) def test_sanitize_duplicate_attribute_names(self): attrs = [ diff --git a/tests/fixtures/defxmlschema/chapter03prod.py b/tests/fixtures/defxmlschema/chapter03prod.py index 5db296f8c..51fd2e7d2 100644 --- a/tests/fixtures/defxmlschema/chapter03prod.py +++ b/tests/fixtures/defxmlschema/chapter03prod.py @@ -9,6 +9,9 @@ class ProdNumType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) id: Optional[str] = field( default=None, @@ -24,6 +27,9 @@ class ProdNumType: class SizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, diff --git a/tests/fixtures/defxmlschema/chapter04prod.py b/tests/fixtures/defxmlschema/chapter04prod.py index 0d1d6edac..3c091b645 100644 --- a/tests/fixtures/defxmlschema/chapter04prod.py +++ b/tests/fixtures/defxmlschema/chapter04prod.py @@ -18,6 +18,9 @@ class ColorType: class SizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, diff --git a/tests/fixtures/defxmlschema/chapter05prod.py b/tests/fixtures/defxmlschema/chapter05prod.py index a8e4b2ed7..155691fc3 100644 --- a/tests/fixtures/defxmlschema/chapter05prod.py +++ b/tests/fixtures/defxmlschema/chapter05prod.py @@ -8,6 +8,9 @@ class SizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, @@ -40,7 +43,6 @@ class ProductType: metadata={ "type": "Element", "namespace": "", - "required": True, "nillable": True, } ) diff --git a/tests/fixtures/defxmlschema/chapter12.py b/tests/fixtures/defxmlschema/chapter12.py index f1af31ab8..974fbee14 100644 --- a/tests/fixtures/defxmlschema/chapter12.py +++ b/tests/fixtures/defxmlschema/chapter12.py @@ -29,6 +29,9 @@ class DescriptionType: class SizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, diff --git a/tests/fixtures/defxmlschema/chapter13.py b/tests/fixtures/defxmlschema/chapter13.py index 73035ed4b..43e1b0241 100644 --- a/tests/fixtures/defxmlschema/chapter13.py +++ b/tests/fixtures/defxmlschema/chapter13.py @@ -76,6 +76,9 @@ class RestrictedProductType: class SizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, diff --git a/tests/fixtures/defxmlschema/chapter15.py b/tests/fixtures/defxmlschema/chapter15.py index b6d9e2553..e68682ac6 100644 --- a/tests/fixtures/defxmlschema/chapter15.py +++ b/tests/fixtures/defxmlschema/chapter15.py @@ -8,6 +8,9 @@ class SizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, diff --git a/tests/fixtures/defxmlschema/chapter16.py b/tests/fixtures/defxmlschema/chapter16.py index 2fc0f94f0..196f7498b 100644 --- a/tests/fixtures/defxmlschema/chapter16.py +++ b/tests/fixtures/defxmlschema/chapter16.py @@ -16,6 +16,9 @@ class ColorType: class HatSizeType: value: Optional[str] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, @@ -49,6 +52,9 @@ class ProductType: class ShirtSizeType: value: Optional[int] = field( default=None, + metadata={ + "required": True, + } ) system: Optional[str] = field( default=None, diff --git a/tests/fixtures/defxmlschema/chapter17.py b/tests/fixtures/defxmlschema/chapter17.py index aa1c278c5..ec6f6a182 100644 --- a/tests/fixtures/defxmlschema/chapter17.py +++ b/tests/fixtures/defxmlschema/chapter17.py @@ -17,6 +17,9 @@ class ColorType: class PriceType: value: Optional[Decimal] = field( default=None, + metadata={ + "required": True, + } ) currency: Optional[str] = field( default=None, diff --git a/tests/models/xsd/test_attribute.py b/tests/models/xsd/test_attribute.py index e78652856..e39534091 100644 --- a/tests/models/xsd/test_attribute.py +++ b/tests/models/xsd/test_attribute.py @@ -44,7 +44,7 @@ def test_get_restrictions(self): self.assertEqual({}, obj.get_restrictions()) obj.use = UseType.REQUIRED - expected = {"max_occurs": 1, "min_occurs": 1, "required": True} + expected = {"required": True} self.assertEqual(expected, obj.get_restrictions()) obj.use = UseType.PROHIBITED diff --git a/xsdata/codegen/mappers/definitions.py b/xsdata/codegen/mappers/definitions.py index ae5a0f2b3..cc1be6568 100644 --- a/xsdata/codegen/mappers/definitions.py +++ b/xsdata/codegen/mappers/definitions.py @@ -7,6 +7,7 @@ from xsdata.codegen.models import Attr from xsdata.codegen.models import AttrType from xsdata.codegen.models import Class +from xsdata.codegen.models import Restrictions from xsdata.codegen.models import Status from xsdata.formats.dataclass.models.generics import AnyElement from xsdata.logger import logger @@ -368,10 +369,12 @@ def build_attr( default: Optional[str] = None, ) -> Attr: """Builder method for attributes.""" + occurs = 1 if default is not None else None return Attr( tag=Tag.ELEMENT, name=name, namespace=namespace, default=default, types=[AttrType(qname=qname, forward=forward, native=native)], + restrictions=Restrictions(min_occurs=occurs, max_occurs=occurs), ) diff --git a/xsdata/codegen/sanitizer.py b/xsdata/codegen/sanitizer.py index 4551ceacb..4941289cc 100644 --- a/xsdata/codegen/sanitizer.py +++ b/xsdata/codegen/sanitizer.py @@ -53,8 +53,8 @@ def process_class(self, target: Class): self.group_compound_fields(target) for attr in target.attrs: - self.process_attribute_default(target, attr) self.process_attribute_restrictions(attr) + self.process_attribute_default(target, attr) self.process_attribute_sequence(target, attr) self.process_duplicate_attribute_names(target.attrs) @@ -112,9 +112,6 @@ def process_attribute_default(self, target: Class, attr: Attr): if attr.is_enumeration: return - if attr.is_list: - attr.fixed = False - if attr.is_optional or attr.is_xsi_type: attr.fixed = False attr.default = None @@ -243,17 +240,31 @@ def process_attribute_restrictions(cls, attr: Attr): min_occurs = restrictions.min_occurs or 0 max_occurs = restrictions.max_occurs or 0 - if min_occurs == 0 and max_occurs <= 1: - restrictions.required = None + if attr.is_attribute: restrictions.min_occurs = None restrictions.max_occurs = None - if min_occurs == 1 and max_occurs == 1: + elif attr.is_tokens: + restrictions.required = None + if max_occurs <= 1: + restrictions.min_occurs = None + restrictions.max_occurs = None + elif attr.xml_type is None or min_occurs == max_occurs == 1: restrictions.required = True restrictions.min_occurs = None restrictions.max_occurs = None - elif restrictions.max_occurs and max_occurs > 1: + elif min_occurs == 0 and max_occurs < 2: + restrictions.required = None + restrictions.min_occurs = None + restrictions.max_occurs = None + attr.default = None + attr.fixed = False + else: # max_occurs > 1 restrictions.min_occurs = min_occurs restrictions.required = None + attr.fixed = False + + if attr.default or attr.fixed or attr.restrictions.nillable: + restrictions.required = None @classmethod def process_attribute_sequence(cls, target: Class, attr: Attr): diff --git a/xsdata/models/xsd.py b/xsdata/models/xsd.py index f038c106f..0d5c835e2 100644 --- a/xsdata/models/xsd.py +++ b/xsdata/models/xsd.py @@ -326,7 +326,7 @@ def attr_types(self) -> Iterator[str]: def get_restrictions(self) -> Dict[str, Anything]: restrictions = {} if self.use == UseType.REQUIRED: - restrictions.update({"min_occurs": 1, "max_occurs": 1, "required": True}) + restrictions.update({"required": True}) elif self.use == UseType.PROHIBITED: restrictions.update({"prohibited": True}) From 3f111e22ad1f0d0c4aef99962a2e0f250f96d005 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 16 Feb 2021 17:12:47 +0200 Subject: [PATCH 4/5] Stricter duplicate class/attr name comparison --- tests/codegen/test_resolver.py | 4 +- tests/utils/test_text.py | 8 +++ xsdata/codegen/mappers/definitions.py | 12 ++--- xsdata/codegen/models.py | 9 ++-- xsdata/codegen/resolver.py | 17 ++++--- xsdata/codegen/sanitizer.py | 19 +++---- xsdata/formats/dataclass/parsers/xml.py | 4 +- xsdata/utils/namespaces.py | 14 ++++-- xsdata/utils/text.py | 67 ++++++++++++++----------- 9 files changed, 90 insertions(+), 64 deletions(-) diff --git a/tests/codegen/test_resolver.py b/tests/codegen/test_resolver.py index 794fae427..35a3b93bd 100644 --- a/tests/codegen/test_resolver.py +++ b/tests/codegen/test_resolver.py @@ -135,9 +135,9 @@ def test_resolve_imports( ): class_life = ClassFactory.create(qname="life") import_names = [ - "foo", # cool + "foo_1", # cool "bar", # cool - "{another}foo", # another foo + "{another}foo1", # another foo "{thug}life", # life class exists add alias "{common}type", # type class doesn't exist add just the name ] diff --git a/tests/utils/test_text.py b/tests/utils/test_text.py index c8c825279..fa0ea58f8 100644 --- a/tests/utils/test_text.py +++ b/tests/utils/test_text.py @@ -1,5 +1,6 @@ from unittest import TestCase +from xsdata.utils.text import alnum from xsdata.utils.text import camel_case from xsdata.utils.text import capitalize from xsdata.utils.text import mixed_case @@ -99,3 +100,10 @@ def test_split_words(self): self.assertEqual(["user"], split_words("__user")) self.assertEqual(["TMessage", "DB"], split_words("TMessageDB")) self.assertEqual(["GLOBAL", "REF"], split_words("GLOBAL-REF")) + + def test_alnum(self): + self.assertEqual("foo1", alnum("foo 1")) + self.assertEqual("foo1", alnum(" foo_1 ")) + self.assertEqual("foo1", alnum("\tfoo*1")) + self.assertEqual("foo1", alnum(" foo*1")) + self.assertEqual("βιβλίο1", alnum(" βιβλίο*1")) diff --git a/xsdata/codegen/mappers/definitions.py b/xsdata/codegen/mappers/definitions.py index cc1be6568..6e2d7e051 100644 --- a/xsdata/codegen/mappers/definitions.py +++ b/xsdata/codegen/mappers/definitions.py @@ -27,7 +27,7 @@ from xsdata.utils import text from xsdata.utils.collections import first from xsdata.utils.namespaces import build_qname -from xsdata.utils.namespaces import split_qname +from xsdata.utils.namespaces import local_name class DefinitionsMapper: @@ -216,10 +216,10 @@ def build_envelope_class( for ext in binding_message.extended_elements: assert ext.qname is not None - local_name = split_qname(ext.qname)[1].title() - inner = cls.build_inner_class(target, local_name) + class_name = local_name(ext.qname).title() + inner = cls.build_inner_class(target, class_name) - if style == "rpc" and local_name == "Body": + if style == "rpc" and class_name == "Body": namespace = ext.attributes.get("namespace") attrs = cls.map_port_type_message(port_type_message, namespace) else: @@ -300,7 +300,7 @@ def map_binding_message_parts( parts.extend(extended.attributes["parts"].split()) if "message" in extended.attributes: - message_name = split_qname(extended.attributes["message"])[1] + message_name = local_name(extended.attributes["message"]) else: message_name = text.suffix(message) @@ -352,7 +352,7 @@ def operation_namespace(cls, config: Dict) -> Optional[str]: def attributes(cls, elements: Iterator[AnyElement]) -> Dict: """Return all attributes from all extended elements as a dictionary.""" return { - split_qname(qname)[1]: value + local_name(qname): value for element in elements if isinstance(element, AnyElement) for qname, value in element.attributes.items() diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index f38fa3e6f..fa222aa5a 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -19,7 +19,8 @@ from xsdata.models.enums import Tag from xsdata.models.mixins import ElementBase from xsdata.utils.namespaces import build_qname -from xsdata.utils.namespaces import split_qname +from xsdata.utils.namespaces import local_name +from xsdata.utils.namespaces import target_uri xml_type_map = { Tag.ANY: XmlType.WILDCARD, @@ -209,7 +210,7 @@ class AttrType: @property def name(self) -> str: """Shortcut for qname local name.""" - return split_qname(self.qname)[1] + return local_name(self.qname) @property def is_dependency(self) -> bool: @@ -431,11 +432,11 @@ class Class: @property def name(self) -> str: """Shortcut for qname local name.""" - return split_qname(self.qname)[1] + return local_name(self.qname) @property def target_namespace(self) -> Optional[str]: - return split_qname(self.qname)[0] + return target_uri(self.qname) @property def has_suffix_attr(self) -> bool: diff --git a/xsdata/codegen/resolver.py b/xsdata/codegen/resolver.py index 1bacf44ea..52f9867f6 100644 --- a/xsdata/codegen/resolver.py +++ b/xsdata/codegen/resolver.py @@ -10,7 +10,8 @@ from xsdata.codegen.models import Import from xsdata.exceptions import ResolverValueError from xsdata.utils import collections -from xsdata.utils.namespaces import split_qname +from xsdata.utils.namespaces import local_name +from xsdata.utils.text import alnum logger = logging.getLogger(__name__) @@ -73,25 +74,25 @@ def apply_aliases(self, target: Class): def resolve_imports(self): """Walk the import qualified names, check for naming collisions and add the necessary code generator import instance.""" - local_names = {split_qname(qname)[1] for qname in self.class_map.keys()} + existing = {alnum(local_name(qname)) for qname in self.class_map.keys()} for qname in self.import_classes(): package = self.find_package(qname) - local_name = split_qname(qname)[1] - exists = local_name in local_names - local_names.add(local_name) + name = alnum(local_name(qname)) + exists = name in existing + existing.add(name) self.add_import(qname=qname, package=package, exists=exists) def add_import(self, qname: str, package: str, exists: bool = False): """Append an import package to the list of imports with any if necessary aliases if the import name exists in the local module.""" alias = None - local_name = split_qname(qname)[1] + name = local_name(qname) if exists: module = package.split(".")[-1] - alias = f"{module}:{local_name}" + alias = f"{module}:{name}" self.aliases[qname] = alias - self.imports.append(Import(name=local_name, source=package, alias=alias)) + self.imports.append(Import(name=name, source=package, alias=alias)) def find_package(self, qname: str) -> str: """ diff --git a/xsdata/codegen/sanitizer.py b/xsdata/codegen/sanitizer.py index 4941289cc..b76aa0eec 100644 --- a/xsdata/codegen/sanitizer.py +++ b/xsdata/codegen/sanitizer.py @@ -11,12 +11,12 @@ from xsdata.models.enums import DataType from xsdata.models.enums import Tag from xsdata.utils import collections -from xsdata.utils import text from xsdata.utils.collections import first from xsdata.utils.collections import group_by from xsdata.utils.namespaces import build_qname from xsdata.utils.namespaces import clean_uri from xsdata.utils.namespaces import split_qname +from xsdata.utils.text import alnum @dataclass @@ -154,7 +154,7 @@ def process_attribute_default_enum(self, target: Class, attr: Attr): def resolve_conflicts(self): """Find classes with the same case insensitive qualified name and rename them.""" - groups = group_by(self.container.iterate(), lambda x: text.snake_case(x.qname)) + groups = group_by(self.container.iterate(), lambda x: alnum(x.qname)) for classes in groups.values(): if len(classes) > 1: self.rename_classes(classes) @@ -189,11 +189,11 @@ def next_qname(self, namespace: str, name: str) -> str: """Append the next available index number for the given namespace and local name.""" index = 0 - reserved = map(text.snake_case, self.container.data.keys()) + reserved = set(map(alnum, self.container.data.keys())) while True: index += 1 qname = build_qname(namespace, f"{name}_{index}") - if text.snake_case(qname) not in reserved: + if alnum(qname) not in reserved: return qname def rename_class_dependencies(self, target: Class, search: str, replace: str): @@ -291,7 +291,7 @@ def process_attribute_sequence(cls, target: Class, attr: Attr): def process_duplicate_attribute_names(cls, attrs: List[Attr]) -> None: """Sanitize duplicate attribute names that might exist by applying rename strategies.""" - grouped = group_by(attrs, lambda attr: text.snake_case(attr.name)) + grouped = group_by(attrs, lambda attr: alnum(attr.name)) for items in grouped.values(): total = len(items) if total == 2 and not items[0].is_enumeration: @@ -305,15 +305,12 @@ def rename_attributes_with_index(cls, attrs: List[Attr], rename: List[Attr]): names.""" for index in range(1, len(rename)): num = 1 - name = text.snake_case(rename[index].name) + name = rename[index].name - while any( - text.snake_case(attr.name) == text.snake_case(f"{name}_{num}") - for attr in attrs - ): + while any(alnum(attr.name) == alnum(f"{name}_{num}") for attr in attrs): num += 1 - rename[index].name = f"{rename[index].name}_{num}" + rename[index].name = f"{name}_{num}" @classmethod def rename_attribute_by_preference(cls, a: Attr, b: Attr): diff --git a/xsdata/formats/dataclass/parsers/xml.py b/xsdata/formats/dataclass/parsers/xml.py index 28c5ca5f9..41428f02b 100644 --- a/xsdata/formats/dataclass/parsers/xml.py +++ b/xsdata/formats/dataclass/parsers/xml.py @@ -12,7 +12,7 @@ from xsdata.formats.dataclass.parsers.nodes import NodeParser from xsdata.formats.dataclass.parsers.nodes import Parsed from xsdata.models.enums import EventType -from xsdata.utils.namespaces import split_qname +from xsdata.utils.namespaces import local_name from xsdata.utils.text import snake_case @@ -101,7 +101,7 @@ def emit_event(self, event: str, name: str, **kwargs: Any): key = (event, name) if key not in self.emit_cache: - method_name = f"{event}_{snake_case(split_qname(name)[1])}" + method_name = f"{event}_{snake_case(local_name(name))}" self.emit_cache[key] = getattr(self, method_name, None) method = self.emit_cache[key] diff --git a/xsdata/utils/namespaces.py b/xsdata/utils/namespaces.py index f6feb4513..a10159e28 100644 --- a/xsdata/utils/namespaces.py +++ b/xsdata/utils/namespaces.py @@ -4,7 +4,7 @@ from typing import Tuple from xsdata.models.enums import Namespace -from xsdata.utils.text import split +from xsdata.utils import text __uri_ignore__ = ("www", "xsd", "wsdl") @@ -61,7 +61,7 @@ def clean_uri(namespace: str) -> str: if namespace[:2] == "##": namespace = namespace[2:] - left, right = split(namespace) + left, right = text.split(namespace) if left == "urn": namespace = right @@ -88,8 +88,16 @@ def build_qname(tag_or_uri: Optional[str], tag: Optional[str] = None) -> str: def split_qname(tag: str) -> Tuple: """Split namespace qualified strings.""" if tag[0] == "{": - left, right = split(tag[1:], "}") + left, right = text.split(tag[1:], "}") if left: return left, right return None, tag + + +def target_uri(tag: str) -> Optional[str]: + return split_qname(tag)[0] + + +def local_name(tag: str) -> str: + return split_qname(tag)[1] diff --git a/xsdata/utils/text.py b/xsdata/utils/text.py index 88f4892f7..a221ae4cd 100644 --- a/xsdata/utils/text.py +++ b/xsdata/utils/text.py @@ -1,78 +1,79 @@ import re +import string from typing import Any from typing import List from typing import Match from typing import Tuple -def prefix(string: str, sep: str = ":") -> str: +def prefix(value: str, sep: str = ":") -> str: """Return the first part of the string before the separator.""" - return split(string, sep)[0] + return split(value, sep)[0] -def suffix(string: str, sep: str = ":") -> str: +def suffix(value: str, sep: str = ":") -> str: """Return the last part of the string after the separator.""" - return split(string, sep)[1] + return split(value, sep)[1] -def split(string: str, sep: str = ":") -> Tuple: +def split(value: str, sep: str = ":") -> Tuple: """ Separate the given string with the given separator and return a tuple of the prefix and suffix. If the separator isn't present in the string return None as prefix. """ - left, _, right = string.partition(sep) + left, _, right = value.partition(sep) return (left, right) if right else (None, left) -def capitalize(string: str, **kwargs: Any) -> str: +def capitalize(value: str, **kwargs: Any) -> str: """Capitalize the given string.""" - return string[0].upper() + string[1:] + return value[0].upper() + value[1:] -def pascal_case(string: str, **kwargs: Any) -> str: +def pascal_case(value: str, **kwargs: Any) -> str: """Convert the given string to pascal case.""" - return "".join(map(str.title, split_words(string))) + return "".join(map(str.title, split_words(value))) -def camel_case(string: str, **kwargs: Any) -> str: +def camel_case(value: str, **kwargs: Any) -> str: """Convert the given string to camel case.""" - result = "".join(map(str.title, split_words(string))) + result = "".join(map(str.title, split_words(value))) return result[0].lower() + result[1:] -def mixed_case(string: str, **kwargs: Any) -> str: +def mixed_case(value: str, **kwargs: Any) -> str: """Convert the given string to mixed case.""" - return "".join(split_words(string)) + return "".join(split_words(value)) -def mixed_pascal_case(string: str, **kwargs: Any) -> str: +def mixed_pascal_case(value: str, **kwargs: Any) -> str: """Convert the given string to mixed pascal case.""" - return capitalize(mixed_case(string)) + return capitalize(mixed_case(value)) -def mixed_snake_case(string: str, **kwargs: Any) -> str: +def mixed_snake_case(value: str, **kwargs: Any) -> str: """Convert the given string to mixed snake case.""" - return "_".join(split_words(string)) + return "_".join(split_words(value)) -def snake_case(string: str, **kwargs: Any) -> str: +def snake_case(value: str, **kwargs: Any) -> str: """Convert the given string to snake case.""" - return "_".join(map(str.lower, split_words(string))) + return "_".join(map(str.lower, split_words(value))) -def screaming_snake_case(string: str, **kwargs: Any) -> str: +def screaming_snake_case(value: str, **kwargs: Any) -> str: """Convert the given string to screaming snake case.""" - return snake_case(string, **kwargs).upper() + return snake_case(value, **kwargs).upper() -def kebab_case(string: str, **kwargs: Any) -> str: +def kebab_case(value: str, **kwargs: Any) -> str: """Convert the given string to kebab case.""" - return "-".join(split_words(string)) + return "-".join(split_words(value)) -def split_words(string: str) -> List[str]: +def split_words(value: str) -> List[str]: """Split a string on new capital letters and not alphanumeric characters.""" words: List[str] = [] @@ -84,7 +85,7 @@ def flush(): words.append("".join(buffer)) buffer.clear() - for char in string: + for char in value: tp = classify(char) if tp == StringType.OTHER: flush() @@ -135,7 +136,7 @@ def classify(character: str) -> int: ESCAPE_DCT.setdefault(chr(i), f"\\u{i:04x}") -def escape_string(string: str) -> str: +def escape_string(value: str) -> str: """ Escape a string for code generation. @@ -145,4 +146,14 @@ def escape_string(string: str) -> str: def replace(match: Match) -> str: return ESCAPE_DCT[match.group(0)] - return ESCAPE.sub(replace, string) + return ESCAPE.sub(replace, value) + + +_punctuation = set(string.punctuation + string.whitespace) + + +def alnum(value: str) -> str: + for remove in set(value).intersection(_punctuation): + value = value.replace(remove, "") + + return value.lower() From ad4cbf2d2f193a6dfc226fb02f5c15b8a4bfe014 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 23 Feb 2021 22:39:40 +0200 Subject: [PATCH 5/5] Update pre-commit hooks --- .pre-commit-config.yaml | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c0f54f6fb..73a81067e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: - id: docformatter args: ["--in-place", "--pre-summary-newline"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.800 + rev: v0.812 hooks: - id: mypy additional_dependencies: [tokenize-rt] diff --git a/setup.cfg b/setup.cfg index 82f9a5b41..985c5f75a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,7 @@ dev = pytest-cov tox docs = - sphinx==3.4.3 + sphinx sphinx-autobuild sphinx-autodoc-typehints sphinx-copybutton