diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fec20d35f..c6522c77a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: - id: docformatter args: ["--in-place", "--pre-summary-newline"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.902 + rev: v0.910 hooks: - id: mypy additional_dependencies: [tokenize-rt, types-requests, types-Jinja2, types-click] diff --git a/docs/_static/config.sample.xml b/docs/_static/config.sample.xml index dc9fca739..66fc0e859 100644 --- a/docs/_static/config.sample.xml +++ b/docs/_static/config.sample.xml @@ -2,9 +2,10 @@ generated - dataclasses + dataclasses filenames reStructuredText + false false diff --git a/docs/api/codegen.rst b/docs/api/codegen.rst index a53d25ea8..b216332ad 100644 --- a/docs/api/codegen.rst +++ b/docs/api/codegen.rst @@ -18,6 +18,7 @@ like naming conventions and aliases. GeneratorConfig GeneratorOutput + OutputFormat GeneratorConventions GeneratorAliases StructureStyle diff --git a/docs/codegen.rst b/docs/codegen.rst index c9e29447a..ea3532493 100644 --- a/docs/codegen.rst +++ b/docs/codegen.rst @@ -23,6 +23,7 @@ Generate Code - :ref:`Compound fields ` - :ref:`Docstring styles` + - :ref:`Dataclasses Features` .. code-block:: console @@ -96,8 +97,9 @@ altogether. .. warning:: - Auto :ref:`locating types ` during parsing might not work since - all classes are bundled together under the same module namespace. + Auto :ref:`locating types ` during parsing + might not work since all classes are bundled together under the same module + namespace. Initialize Config diff --git a/docs/examples.rst b/docs/examples.rst index b01ab6595..7709e186c 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -12,6 +12,7 @@ Code Generation examples/xml-modeling examples/json-modeling examples/compound-fields + examples/dataclasses-features Advance Topics diff --git a/docs/examples/dataclasses-features.rst b/docs/examples/dataclasses-features.rst new file mode 100644 index 000000000..bf186d4ef --- /dev/null +++ b/docs/examples/dataclasses-features.rst @@ -0,0 +1,41 @@ +==================== +Dataclasses Features +==================== + +By default xsdata with generate +`dataclasses `_ with the default +features on but you can use a :ref:`generator config ` to toggle +almost all of them. + + +.. literalinclude:: /../tests/fixtures/stripe/.xsdata.xml + :language: xml + :lines: 2-6 + + +.. tab:: Frozen Model + + The code generator will use tuples instead of lists as well. + + .. literalinclude:: /../tests/fixtures/stripe/models/balance.py + :language: python + :lines: 93-128 + +.. tab:: Frozen Bindings + + .. testcode:: + + import pprint + from tests import fixtures_dir + from tests.fixtures.stripe.models import Balance + from xsdata.formats.dataclass.parsers import JsonParser + + xml_path = fixtures_dir.joinpath("stripe/samples/balance.json") + parser = JsonParser() + root = parser.from_path(xml_path, Balance) + pprint.pprint(root.pending) + + .. testoutput:: + + (Pending(amount=835408472, currency='usd', source_types=SourceTypes(bank_account=0, card=835408472)), + Pending(amount=-22251, currency='eur', source_types=SourceTypes(bank_account=0, card=-22251))) diff --git a/docs/examples/json-modeling.rst b/docs/examples/json-modeling.rst index bc7025480..8a802b5ea 100644 --- a/docs/examples/json-modeling.rst +++ b/docs/examples/json-modeling.rst @@ -9,17 +9,17 @@ duplicate classes and their fields and their field types. .. code-block:: console - $ xsdata --package tests.fixtures.series tests/fixtures/series + $ xsdata --package tests.fixtures.series tests/fixtures/series/samples .. tab:: Sample #1 - .. literalinclude:: /../tests/fixtures/series/show1.json + .. literalinclude:: /../tests/fixtures/series/samples/show1.json :language: json .. tab:: Sample #2 - .. literalinclude:: /../tests/fixtures/series/show2.json + .. literalinclude:: /../tests/fixtures/series/samples/show2.json :language: json diff --git a/docs/models.rst b/docs/models.rst index 286d8eb45..15ba02ba3 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -124,7 +124,7 @@ Simply follow the Python lib .. warning:: - Currently only List, Dict and Union annotations are supported. + Currently only List, Tuple, Dict and Union annotations are supported. Everything else will raise an exception as unsupported. diff --git a/tests/fixtures/datatypes.py b/tests/fixtures/datatypes.py index b3ca0b4c3..7d1e7b605 100644 --- a/tests/fixtures/datatypes.py +++ b/tests/fixtures/datatypes.py @@ -17,7 +17,6 @@ def deserialize(self, value: Any, **kwargs: Any) -> Any: raise ConverterError() - def serialize(self, value: Telephone, **kwargs: Any) -> str: return "-".join(map(str, value)) diff --git a/tests/fixtures/series/show1.json b/tests/fixtures/series/samples/show1.json similarity index 100% rename from tests/fixtures/series/show1.json rename to tests/fixtures/series/samples/show1.json diff --git a/tests/fixtures/series/show2.json b/tests/fixtures/series/samples/show2.json similarity index 100% rename from tests/fixtures/series/show2.json rename to tests/fixtures/series/samples/show2.json diff --git a/tests/fixtures/stripe/.xsdata.xml b/tests/fixtures/stripe/.xsdata.xml new file mode 100644 index 000000000..b68df305b --- /dev/null +++ b/tests/fixtures/stripe/.xsdata.xml @@ -0,0 +1,20 @@ + + + + tests.fixtures.stripe.models.balance + dataclasses + single-package + reStructuredText + true + false + + + + + + + + + + + diff --git a/tests/fixtures/stripe/__init__.py b/tests/fixtures/stripe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fixtures/stripe/models/__init__.py b/tests/fixtures/stripe/models/__init__.py new file mode 100644 index 000000000..6ef63e4a2 --- /dev/null +++ b/tests/fixtures/stripe/models/__init__.py @@ -0,0 +1,15 @@ +from .balance import ( + Available, + Balance, + ConnectReserved, + Pending, + SourceTypes, +) + +__all__ = [ + "Available", + "Balance", + "ConnectReserved", + "Pending", + "SourceTypes", +] diff --git a/tests/fixtures/stripe/models/balance.py b/tests/fixtures/stripe/models/balance.py new file mode 100644 index 000000000..6cb4e543b --- /dev/null +++ b/tests/fixtures/stripe/models/balance.py @@ -0,0 +1,128 @@ +from dataclasses import dataclass, field +from typing import Optional, Tuple + + +@dataclass(order=True, frozen=True) +class ConnectReserved: + class Meta: + name = "connect_reserved" + + amount: Optional[int] = field( + default=None, + metadata={ + "type": "Element", + } + ) + currency: Optional[str] = field( + default=None, + metadata={ + "type": "Element", + } + ) + + +@dataclass(order=True, frozen=True) +class SourceTypes: + class Meta: + name = "source_types" + + bank_account: Optional[int] = field( + default=None, + metadata={ + "type": "Element", + } + ) + card: Optional[int] = field( + default=None, + metadata={ + "type": "Element", + } + ) + + +@dataclass(order=True, frozen=True) +class Available: + class Meta: + name = "available" + + amount: Optional[int] = field( + default=None, + metadata={ + "type": "Element", + } + ) + currency: Optional[str] = field( + default=None, + metadata={ + "type": "Element", + } + ) + source_types: Optional[SourceTypes] = field( + default=None, + metadata={ + "type": "Element", + } + ) + + +@dataclass(order=True, frozen=True) +class Pending: + class Meta: + name = "pending" + + amount: Optional[int] = field( + default=None, + metadata={ + "type": "Element", + } + ) + currency: Optional[str] = field( + default=None, + metadata={ + "type": "Element", + } + ) + source_types: Optional[SourceTypes] = field( + default=None, + metadata={ + "type": "Element", + } + ) + + +@dataclass(order=True, frozen=True) +class Balance: + class Meta: + name = "balance" + + object_value: Optional[str] = field( + default=None, + metadata={ + "name": "object", + "type": "Element", + } + ) + available: Tuple[Available, ...] = field( + default_factory=tuple, + metadata={ + "type": "Element", + } + ) + connect_reserved: Tuple[ConnectReserved, ...] = field( + default_factory=tuple, + metadata={ + "type": "Element", + } + ) + livemode: Optional[bool] = field( + default=None, + metadata={ + "type": "Element", + } + ) + pending: Tuple[Pending, ...] = field( + default_factory=tuple, + metadata={ + "type": "Element", + } + ) diff --git a/tests/fixtures/stripe/samples/balance.json b/tests/fixtures/stripe/samples/balance.json new file mode 100644 index 000000000..dd7ef82a0 --- /dev/null +++ b/tests/fixtures/stripe/samples/balance.json @@ -0,0 +1,158 @@ +{ + "object": "balance", + "available": [ + { + "amount": 2685, + "currency": "nok", + "source_types": { + "bank_account": 0, + "card": 2685 + } + }, + { + "amount": 218420, + "currency": "nzd", + "source_types": { + "bank_account": 0, + "card": 218420 + } + }, + { + "amount": 779902, + "currency": "czk", + "source_types": { + "bank_account": 0, + "card": 779902 + } + }, + { + "amount": -1854, + "currency": "aud", + "source_types": { + "bank_account": 0, + "card": -1854 + } + }, + { + "amount": 278067892166, + "currency": "usd", + "source_types": { + "bank_account": 280201532, + "card": 277786138508 + } + }, + { + "amount": -54204, + "currency": "eur", + "source_types": { + "bank_account": 0, + "card": -54204 + } + }, + { + "amount": 2213741, + "currency": "cad", + "source_types": { + "bank_account": 0, + "card": 2213741 + } + }, + { + "amount": 7259805, + "currency": "gbp", + "source_types": { + "bank_account": 0, + "card": 7259805 + } + }, + { + "amount": -40320, + "currency": "jpy", + "source_types": { + "bank_account": 0, + "card": -40320 + } + }, + { + "amount": 12000, + "currency": "brl", + "source_types": { + "bank_account": 0, + "card": 12000 + } + }, + { + "amount": -412, + "currency": "sek", + "source_types": { + "bank_account": 0, + "card": -412 + } + } + ], + "connect_reserved": [ + { + "amount": 0, + "currency": "nok" + }, + { + "amount": 0, + "currency": "nzd" + }, + { + "amount": 0, + "currency": "czk" + }, + { + "amount": 0, + "currency": "aud" + }, + { + "amount": 55880, + "currency": "usd" + }, + { + "amount": 54584, + "currency": "eur" + }, + { + "amount": 0, + "currency": "cad" + }, + { + "amount": 0, + "currency": "gbp" + }, + { + "amount": 0, + "currency": "jpy" + }, + { + "amount": 0, + "currency": "brl" + }, + { + "amount": 0, + "currency": "sek" + } + ], + "livemode": false, + "pending": [ + { + "amount": 835408472, + "currency": "usd", + "source_types": { + "bank_account": 0, + "card": 835408472 + } + }, + { + "amount": -22251, + "currency": "eur", + "source_types": { + "bank_account": 0, + "card": -22251 + } + } + ] +} \ No newline at end of file diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index 58885c189..698c64492 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -252,6 +252,7 @@ def test_build_with_choice_field(self): class_field = fields(ChoiceType)[0] self.builder.parent_ns = "bar" + self.maxDiff = None actual = self.builder.build( 66, "choice", @@ -265,7 +266,7 @@ def test_build_with_choice_field(self): index=67, name="choice", types=(object,), - list_element=True, + factory=list, any_type=True, default=list, xml_type=XmlType.ELEMENTS, @@ -276,7 +277,7 @@ def test_build_with_choice_field(self): qname="{bar}a", types=(TypeA,), clazz=TypeA, - list_element=True, + factory=list, namespaces=("bar",), ), "{bar}b": XmlVarFactory.create( @@ -285,7 +286,7 @@ def test_build_with_choice_field(self): qname="{bar}b", types=(TypeB,), clazz=TypeB, - list_element=True, + factory=list, namespaces=("bar",), ), "{bar}int": XmlVarFactory.create( @@ -293,7 +294,7 @@ def test_build_with_choice_field(self): name="choice", qname="{bar}int", types=(int,), - list_element=True, + factory=list, namespaces=("bar",), ), "{bar}int2": XmlVarFactory.create( @@ -303,7 +304,7 @@ def test_build_with_choice_field(self): types=(int,), derived=True, nillable=True, - list_element=True, + factory=list, namespaces=("bar",), ), "{bar}float": XmlVarFactory.create( @@ -311,7 +312,7 @@ def test_build_with_choice_field(self): name="choice", qname="{bar}float", types=(float,), - list_element=True, + factory=list, namespaces=("bar",), ), "{bar}qname": XmlVarFactory.create( @@ -319,7 +320,7 @@ def test_build_with_choice_field(self): name="choice", qname="{bar}qname", types=(QName,), - list_element=True, + factory=list, namespaces=("bar",), ), "{bar}tokens": XmlVarFactory.create( @@ -327,9 +328,9 @@ def test_build_with_choice_field(self): name="choice", qname="{bar}tokens", types=(int,), - tokens=True, + tokens_factory=list, derived=True, - list_element=True, + factory=list, default=return_true, namespaces=("bar",), ), @@ -339,7 +340,7 @@ def test_build_with_choice_field(self): qname="{foo}union", types=(UnionType,), clazz=UnionType, - list_element=True, + factory=list, namespaces=("foo",), ), "{bar}p": XmlVarFactory.create( @@ -348,7 +349,7 @@ def test_build_with_choice_field(self): qname="{bar}p", types=(float,), derived=True, - list_element=True, + factory=list, default=1.1, namespaces=("bar",), ), @@ -360,7 +361,7 @@ def test_build_with_choice_field(self): xml_type=XmlType.WILDCARD, qname="{http://www.w3.org/1999/xhtml}any", types=(object,), - list_element=True, + factory=list, default=None, namespaces=("http://www.w3.org/1999/xhtml",), ), diff --git a/tests/formats/dataclass/parsers/nodes/test_primitive.py b/tests/formats/dataclass/parsers/nodes/test_primitive.py index a06c8d8b9..f45f7ea40 100644 --- a/tests/formats/dataclass/parsers/nodes/test_primitive.py +++ b/tests/formats/dataclass/parsers/nodes/test_primitive.py @@ -24,7 +24,12 @@ def test_bind(self, mock_parse_value): self.assertEqual(("foo", 13), objects[-1]) mock_parse_value.assert_called_once_with( - "13", var.types, var.default, ns_map, var.tokens, var.format + value="13", + types=var.types, + default=var.default, + ns_map=ns_map, + tokens_factory=var.tokens_factory, + format=var.format, ) def test_bind_derived_mode(self): diff --git a/tests/formats/dataclass/parsers/test_utils.py b/tests/formats/dataclass/parsers/test_utils.py index 74003203d..c35f4c595 100644 --- a/tests/formats/dataclass/parsers/test_utils.py +++ b/tests/formats/dataclass/parsers/test_utils.py @@ -33,20 +33,20 @@ def test_parse_value(self, mock_deserialize): mock_deserialize.assert_called_once_with("1", [int], ns_map=None, format=None) def test_parse_value_with_tokens_true(self): - actual = ParserUtils.parse_value(" 1 2 3", [int], list, None, True) + actual = ParserUtils.parse_value(" 1 2 3", [int], list, None, list) self.assertEqual([1, 2, 3], actual) - actual = ParserUtils.parse_value(["1", "2", "3"], [int], list, None, True) - self.assertEqual([1, 2, 3], actual) + actual = ParserUtils.parse_value(["1", "2", "3"], [int], list, None, tuple) + self.assertEqual((1, 2, 3), actual) - actual = ParserUtils.parse_value(None, [int], lambda: [1, 2, 3], None, True) + actual = ParserUtils.parse_value(None, [int], lambda: [1, 2, 3], None, list) self.assertEqual([1, 2, 3], actual) @mock.patch.object(ConverterFactory, "deserialize", return_value=2) def test_parse_value_with_ns_map(self, mock_to_python): ns_map = dict(a=1) - ParserUtils.parse_value(" 1 2 3", [int], list, ns_map, True) - ParserUtils.parse_value(" 1 2 3", [str], None, ns_map, False) + ParserUtils.parse_value(" 1 2 3", [int], list, ns_map, list) + ParserUtils.parse_value(" 1 2 3", [str], None, ns_map) self.assertEqual(4, mock_to_python.call_count) mock_to_python.assert_has_calls( @@ -60,7 +60,7 @@ def test_parse_value_with_ns_map(self, mock_to_python): @mock.patch.object(ConverterFactory, "deserialize", return_value=2) def test_parse_value_with_format(self, mock_to_python): - ParserUtils.parse_value(" 1 2 3", [str], list, _format="Nope") + ParserUtils.parse_value(" 1 2 3", [str], list, format="Nope") self.assertEqual(1, mock_to_python.call_count) mock_to_python.assert_called_once_with( " 1 2 3", [str], ns_map=None, format="Nope" diff --git a/tests/formats/dataclass/serializers/test_xml.py b/tests/formats/dataclass/serializers/test_xml.py index 5bd430eb4..d16302183 100644 --- a/tests/formats/dataclass/serializers/test_xml.py +++ b/tests/formats/dataclass/serializers/test_xml.py @@ -119,7 +119,9 @@ def test_write_data(self): self.assertEqual(expected, list(result)) def test_write_tokens(self): - var = XmlVarFactory.create(xml_type=XmlType.ELEMENT, qname="a", tokens=True) + var = XmlVarFactory.create( + xml_type=XmlType.ELEMENT, qname="a", tokens_factory=list + ) result = self.serializer.write_value([], var, "xsdata") self.assertIsInstance(result, Generator) @@ -438,7 +440,7 @@ def test_write_choice_with_raw_value(self): qname="b", name="b", types=(int,), - tokens=True, + tokens_factory=list, ), }, ) @@ -475,9 +477,7 @@ def test_write_choice_when_no_matching_choice_exists(self): self.assertEqual(msg, str(cm.exception)) def test_write_value_with_list_value(self): - var = XmlVarFactory.create( - xml_type=XmlType.ELEMENT, qname="a", list_element=True - ) + var = XmlVarFactory.create(xml_type=XmlType.ELEMENT, qname="a", factory=list) value = [True, False] expected = [ (XmlWriterEvent.START, "a"), diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index 7e69492d5..b144aadf6 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -126,6 +126,13 @@ def test_field_default_value_with_type_tokens(self): self.assertEqual(expected, self.filters.field_default_value(attr)) + expected = """lambda: ( + 1, + "bar", + )""" + self.filters.format.frozen = True + self.assertEqual(expected, self.filters.field_default_value(attr)) + attr.tag = Tag.ENUMERATION expected = """( 1, @@ -205,11 +212,14 @@ def test_field_default_value_with_any_attribute(self): attr = AttrFactory.any_attribute() self.assertEqual("dict", self.filters.field_default_value(attr)) - def test_field_default_value_with_type_list(self): + def test_field_default_value_with_array_type(self): attr = AttrFactory.create(types=[type_bool]) attr.restrictions.max_occurs = 2 self.assertEqual("list", self.filters.field_default_value(attr)) + self.filters.format.frozen = True + self.assertEqual("tuple", self.filters.field_default_value(attr)) + def test_field_default_value_with_multiple_types(self): attr = AttrFactory.create(types=[type_bool, type_int, type_float], default="2") self.assertEqual("2", self.filters.field_default_value(attr)) @@ -386,7 +396,7 @@ def test_field_type_with_forward_and_circular_reference(self): self.filters.field_type(attr, ["Parent", "Inner"]), ) - def test_field_type_with_list_type(self): + def test_field_type_with_array_type(self): attr = AttrFactory.create( types=AttrTypeFactory.list(1, qname="foo_bar", forward=True) ) @@ -396,6 +406,12 @@ def test_field_type_with_list_type(self): self.filters.field_type(attr, ["A", "Parent"]), ) + self.filters.format.frozen = True + self.assertEqual( + 'Tuple["A.Parent.FooBar", ...]', + self.filters.field_type(attr, ["A", "Parent"]), + ) + def test_field_type_with_token_attr(self): attr = AttrFactory.create( types=AttrTypeFactory.list(1, qname="foo_bar"), @@ -406,6 +422,10 @@ def test_field_type_with_token_attr(self): attr.restrictions.max_occurs = 2 self.assertEqual("List[List[FooBar]]", self.filters.field_type(attr, [])) + attr.restrictions.max_occurs = 1 + self.filters.format.frozen = True + self.assertEqual("Tuple[FooBar, ...]", self.filters.field_type(attr, [])) + def test_field_type_with_alias(self): attr = AttrFactory.create( types=AttrTypeFactory.list( @@ -482,6 +502,10 @@ def test_choice_type_with_restrictions_tokens_true(self): actual = self.filters.choice_type(choice, ["a", "b"]) self.assertEqual("Type[List[Union[str, bool]]]", actual) + self.filters.format.frozen = True + actual = self.filters.choice_type(choice, ["a", "b"]) + self.assertEqual("Type[Tuple[Union[str, bool], ...]]", actual) + def test_default_imports_with_decimal(self): expected = "from decimal import Decimal" @@ -544,6 +568,10 @@ def test_default_imports_with_typing(self): expected = "from typing import List" self.assertIn(expected, self.filters.default_imports(output)) + output = ": Tuple[" + expected = "from typing import Tuple" + self.assertIn(expected, self.filters.default_imports(output)) + output = "Optional[ " expected = "from typing import Optional" self.assertIn(expected, self.filters.default_imports(output)) @@ -635,6 +663,29 @@ def test_import_module(self): for case in cases: self.assertEqual(case.result, transform(case.module, case.from_module)) + def test_build_class_annotation(self): + config = GeneratorConfig() + format = config.output.format + + actual = self.filters.build_class_annotation(format) + self.assertEqual("@dataclass", actual) + + format.frozen = True + actual = self.filters.build_class_annotation(format) + self.assertEqual("@dataclass(frozen=True)", actual) + + format.repr = False + format.eq = False + format.order = True + format.unsafe_hash = True + actual = self.filters.build_class_annotation(format) + expected = ( + "@dataclass(repr=False, eq=False, order=True," + " unsafe_hash=True, frozen=True)" + ) + + self.assertEqual(expected, actual) + def test__init(self): config = GeneratorConfig() config.conventions.package_name.safe_prefix = "safe_package" @@ -653,6 +704,8 @@ def test__init(self): filters = Filters(config) + self.assertFalse(filters.relative_imports) + self.assertEqual("safe_class", filters.class_safe_prefix) self.assertEqual("safe_field", filters.field_safe_prefix) self.assertEqual("safe_package", filters.package_safe_prefix) diff --git a/tests/formats/dataclass/test_typing.py b/tests/formats/dataclass/test_typing.py index ffa1b6c66..e20eff5cb 100644 --- a/tests/formats/dataclass/test_typing.py +++ b/tests/formats/dataclass/test_typing.py @@ -33,6 +33,14 @@ def test_get_origin_list(self): with self.assertRaises(TypeError): get_origin(List) + def test_get_origin_tuple(self): + self.assertEqual(Tuple, get_origin(Tuple[int])) + self.assertEqual(Tuple, get_origin(Tuple[Union[int, str]])) + self.assertEqual(Tuple, get_origin(Tuple[int, ...])) + + with self.assertRaises(TypeError): + get_origin(Tuple) + def test_get_origin_dict(self): self.assertEqual(Dict, get_origin(Dict)) self.assertEqual(Dict, get_origin(Dict[int, str])) @@ -81,6 +89,10 @@ def test_get_origin_unsupported(self): def test_get_args(self): self.assertEqual((), get_args(int)) self.assertEqual((int,), get_args(List[int])) + self.assertEqual((int, Ellipsis), get_args(Tuple[int, ...])) + self.assertEqual((int,), get_args(Tuple[int])) + self.assertEqual((int, str, float), get_args(Tuple[int, str, float])) + self.assertEqual((Union[str, int],), get_args(Tuple[Union[str, int]])) self.assertEqual((List[int], type(None)), get_args(Optional[List[int]])) self.assertEqual((int, str), get_args(Union[int, str])) self.assertEqual((int, type(None)), get_args(Optional[int])) @@ -122,6 +134,7 @@ def test_evaluate_list(self): self.assertEqual((list, int), evaluate(List[int])) self.assertEqual((list, float, str), evaluate(List[Union[float, str]])) self.assertEqual((list, int), evaluate(List[Optional[int]])) + self.assertEqual((list, tuple, int), evaluate(List[Tuple[int]])) self.assertEqual( (list, list, bool, str), evaluate(List[List[Union[bool, str]]]) ) @@ -131,6 +144,24 @@ def test_evaluate_list(self): with self.assertRaises(TypeError, msg=case): evaluate(case) + def test_evaluate_tuple(self): + A = TypeVar("A", int, str) + + self.assertEqual((tuple, int, str), evaluate(Tuple[A])) + self.assertEqual((tuple, int), evaluate(Tuple[int])) + self.assertEqual((tuple, int), evaluate(Tuple[int, ...])) + self.assertEqual((tuple, list, int), evaluate(Tuple[List[int], ...])) + self.assertEqual((tuple, float, str), evaluate(Tuple[Union[float, str]])) + self.assertEqual((tuple, int), evaluate(Tuple[Optional[int]])) + self.assertEqual( + (tuple, tuple, bool, str), evaluate(Tuple[Tuple[Union[bool, str]]]) + ) + + unsupported_cases = [Tuple, Tuple[Dict[str, str]]] + for case in unsupported_cases: + with self.assertRaises(TypeError, msg=case): + evaluate(case) + def test_evaluate_type(self): self.assertEqual((str,), evaluate(Type["str"])) diff --git a/tests/integration/test_series.py b/tests/integration/test_series.py index 1fe6f7f6b..69444e30a 100644 --- a/tests/integration/test_series.py +++ b/tests/integration/test_series.py @@ -17,7 +17,9 @@ def test_json_documents(): filepath = fixtures_dir.joinpath("series") package = "tests.fixtures.series" runner = CliRunner() - result = runner.invoke(cli, [str(filepath), "--package", package]) + result = runner.invoke( + cli, [str(filepath.joinpath("samples")), "--package", package] + ) if result.exception: raise result.exception @@ -28,7 +30,7 @@ def test_json_documents(): serializer = JsonSerializer(indent=4) for i in range(1, 3): - ori = filepath.joinpath(f"show{i}.json").read_text() + ori = filepath.joinpath(f"samples/show{i}.json").read_text() obj = parser.from_string(ori, clazz) actual = serializer.render(obj) diff --git a/tests/integration/test_stripe.py b/tests/integration/test_stripe.py new file mode 100644 index 000000000..26373a89c --- /dev/null +++ b/tests/integration/test_stripe.py @@ -0,0 +1,49 @@ +import json +import os + +from click.testing import CliRunner + +from tests import fixtures_dir +from tests import root +from xsdata.cli import cli +from xsdata.formats.dataclass.parsers import JsonParser +from xsdata.formats.dataclass.serializers import JsonSerializer +from xsdata.utils.testing import load_class + +os.chdir(root) + + +def test_json_documents(): + + filepath = fixtures_dir.joinpath("stripe") + package = "tests.fixtures.series" + runner = CliRunner() + result = runner.invoke( + cli, + [ + str(filepath.joinpath("samples")), + f"--config={str(filepath.joinpath('.xsdata.xml'))}", + ], + ) + + if result.exception: + raise result.exception + + clazz = load_class(result.output, "Balance") + + parser = JsonParser() + serializer = JsonSerializer(indent=4) + + for sample in filepath.joinpath("samples").glob("*.json"): + ori = sample.read_text() + obj = parser.from_string(ori, clazz) + actual = serializer.render(obj) + + assert filter_none(json.loads(ori)) == filter_none(json.loads(actual)) + + +def filter_none(d): + if isinstance(d, dict): + return {k: filter_none(v) for k, v in d.items() if v is not None} + else: + return d diff --git a/tests/models/test_config.py b/tests/models/test_config.py index bc72f960e..6f83a5e15 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -3,8 +3,10 @@ from unittest import TestCase from xsdata import __version__ +from xsdata.exceptions import GeneratorConfigError from xsdata.exceptions import ParserError from xsdata.models.config import GeneratorConfig +from xsdata.models.config import OutputFormat class GeneratorConfigTests(TestCase): @@ -22,9 +24,11 @@ def test_create(self): f'\n' ' \n' " generated\n" - ' dataclasses\n' + ' dataclasses\n' " filenames\n" " reStructuredText\n" + " false\n" " false\n" " \n" " \n" @@ -70,9 +74,11 @@ def test_read(self): f'\n' ' \n' " foo.bar\n" - ' dataclasses\n' + ' dataclasses\n' " filenames\n" " reStructuredText\n" + " false\n" " false\n" " \n" " \n" @@ -100,3 +106,10 @@ def test_read_with_wrong_value(self): file_path.write_text(existing, encoding="utf-8") with self.assertRaises(ParserError): GeneratorConfig.read(file_path) + + def test_format_with_invalid_state(self): + + with self.assertRaises(GeneratorConfigError) as cm: + OutputFormat(eq=False, order=True) + + self.assertEqual("eq must be true if order is true", str(cm.exception)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 1ae58b4d2..0d30b2248 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -46,7 +46,7 @@ def test_generate_with_default_output(self, mock_init, mock_process): self.assertFalse(mock_init.call_args[1]["print"]) self.assertEqual("foo", config.output.package) self.assertEqual("dataclasses", config.output.format.value) - self.assertFalse(config.output.format.relative_imports) + self.assertFalse(config.output.relative_imports) self.assertEqual(StructureStyle.FILENAMES, config.output.structure) self.assertEqual([source.as_uri()], mock_process.call_args[0][0]) @@ -142,7 +142,7 @@ def test_generate_with_configuration_file_and_overriding_args(self, mock_init, _ config = mock_init.call_args[1]["config"] self.assertEqual("foo", config.output.package) - self.assertTrue(config.output.format.relative_imports) + self.assertTrue(config.output.relative_imports) self.assertEqual(StructureStyle.NAMESPACES, config.output.structure) file_path.unlink() diff --git a/tests/utils/test_collections.py b/tests/utils/test_collections.py index 2e3b60657..fb899faa8 100644 --- a/tests/utils/test_collections.py +++ b/tests/utils/test_collections.py @@ -1,6 +1,4 @@ -import itertools -from typing import Generator -from typing import Hashable +from collections import namedtuple from unittest import TestCase from xsdata.utils import collections @@ -24,3 +22,11 @@ def test_remove(self): self.assertEqual([2, 2, 3], collections.remove([1, 2, 2, 3], lambda x: x == 1)) self.assertEqual([3], collections.remove([1, 2, 2, 3], lambda x: x < 3)) + + def test_is_array(self): + fixture = namedtuple("fixture", ["a", "b"]) + + self.assertFalse(collections.is_array(1)) + self.assertFalse(collections.is_array(fixture(1, 2))) + self.assertTrue(collections.is_array([])) + self.assertTrue(collections.is_array(tuple())) diff --git a/xsdata/cli.py b/xsdata/cli.py index 82238599b..6e3418e74 100644 --- a/xsdata/cli.py +++ b/xsdata/cli.py @@ -171,10 +171,9 @@ def generate(**kwargs: Any): config.output.package = kwargs["package"] else: config = GeneratorConfig() - config.output.format = OutputFormat( - value=kwargs["output"], relative_imports=kwargs["relative_imports"] - ) + config.output.format = OutputFormat(value=kwargs["output"]) config.output.package = kwargs["package"] + config.output.relative_imports = kwargs["relative_imports"] config.output.compound_fields = kwargs["compound_fields"] config.output.docstring_style = DocstringStyle(kwargs["docstring_style"]) @@ -185,7 +184,7 @@ def generate(**kwargs: Any): config.output.format.value = kwargs["output"] if kwargs["relative_imports"]: - config.output.format.relative_imports = True + config.output.relative_imports = True uris = resolve_source(kwargs["source"]) transformer = SchemaTransformer(config=config, print=kwargs["print"]) diff --git a/xsdata/exceptions.py b/xsdata/exceptions.py index 04f3a11e3..350218c7a 100644 --- a/xsdata/exceptions.py +++ b/xsdata/exceptions.py @@ -2,6 +2,10 @@ class CodeGenerationError(TypeError): """Unexpected state during code generation related errors.""" +class GeneratorConfigError(CodeGenerationError): + """Unexpected state during generator config related errors.""" + + class ConverterError(ValueError): """Converting values between document/python types related errors.""" diff --git a/xsdata/formats/converter.py b/xsdata/formats/converter.py index 70ecb6420..25b52ac61 100644 --- a/xsdata/formats/converter.py +++ b/xsdata/formats/converter.py @@ -31,6 +31,7 @@ from xsdata.models.datatype import XmlHexBinary from xsdata.models.datatype import XmlPeriod from xsdata.models.datatype import XmlTime +from xsdata.utils import collections from xsdata.utils import namespaces from xsdata.utils import text @@ -398,7 +399,7 @@ def deserialize( if data_type is None or not isinstance(data_type, EnumMeta): raise ConverterError(f"'{data_type}' is not an enum") - if isinstance(value, (list, tuple)) and not hasattr(value, "_fields"): + if collections.is_array(value): values = value elif isinstance(value, str): value = value.strip() diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index ca3952e95..072043440 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -22,8 +22,9 @@ from xsdata.models.config import DocstringStyle from xsdata.models.config import GeneratorAlias from xsdata.models.config import GeneratorConfig +from xsdata.models.config import OutputFormat +from xsdata.utils import collections from xsdata.utils import text -from xsdata.utils.collections import unique_sequence from xsdata.utils.namespaces import clean_uri from xsdata.utils.namespaces import local_name @@ -52,6 +53,7 @@ class Filters: "docstring_style", "max_line_length", "relative_imports", + "format", ) def __init__(self, config: GeneratorConfig): @@ -71,9 +73,16 @@ def __init__(self, config: GeneratorConfig): self.module_safe_prefix: str = config.conventions.module_name.safe_prefix self.docstring_style: DocstringStyle = config.output.docstring_style self.max_line_length: int = config.output.max_line_length - self.relative_imports: bool = config.output.format.relative_imports + self.relative_imports: bool = config.output.relative_imports + self.format = config.output.format def register(self, env: Environment): + env.globals.update( + { + "docstring_name": self.docstring_style.name.lower(), + "class_annotation": self.build_class_annotation(self.format), + } + ) env.filters.update( { "field_name": self.field_name, @@ -96,6 +105,23 @@ def register(self, env: Environment): } ) + @classmethod + def build_class_annotation(cls, format: OutputFormat) -> str: + args = [] + + if not format.repr: + args.append("repr=False") + if not format.eq: + args.append("eq=False") + if format.order: + args.append("order=True") + if format.unsafe_hash: + args.append("unsafe_hash=True") + if format.frozen: + args.append("frozen=True") + + return f"@dataclass({', '.join(args)})" if args else "@dataclass" + def class_name(self, name: str) -> str: """Convert the given string to a class name according to the selected conventions or use an existing alias.""" @@ -293,7 +319,7 @@ def format_metadata(self, data: Any, indent: int = 0, key: str = "") -> str: if isinstance(data, dict): return self.format_dict(data, indent) - if isinstance(data, (list, tuple)) and not hasattr(data, "_fields"): + if collections.is_array(data): return self.format_iterable(data, indent) if isinstance(data, str): @@ -436,7 +462,7 @@ def format_docstring(self, doc_string: str, level: int) -> str: def field_default_value(self, attr: Attr, ns_map: Optional[Dict] = None) -> Any: """Generate the field default value/factory for the given attribute.""" if attr.is_list or (attr.is_tokens and not attr.default): - return "list" + return "tuple" if self.format.frozen else "list" if attr.is_dict: return "dict" if attr.default is None: @@ -480,10 +506,11 @@ def field_default_tokens( assert isinstance(attr.default, str) fmt = attr.restrictions.format - tokens = [ + factory = tuple if self.format.frozen else list + tokens = factory( converter.deserialize(val, types, ns_map=ns_map, format=fmt) for val in attr.default.split() - ] + ) if attr.is_enumeration: return self.format_metadata(tuple(tokens), indent=8) @@ -493,7 +520,7 @@ def field_default_tokens( def field_type(self, attr: Attr, parents: List[str]) -> str: """Generate type hints for the given attribute.""" - type_names = unique_sequence( + type_names = collections.unique_sequence( self.field_type_name(x, parents) for x in attr.types ) @@ -502,10 +529,14 @@ def field_type(self, attr: Attr, parents: List[str]) -> str: result = f"Union[{result}]" if attr.is_tokens: - result = f"List[{result}]" + result = ( + f"Tuple[{result}, ...]" if self.format.frozen else f"List[{result}]" + ) if attr.is_list: - result = f"List[{result}]" + result = ( + f"Tuple[{result}, ...]" if self.format.frozen else f"List[{result}]" + ) elif attr.is_dict: result = "Dict[str, str]" elif attr.default is None and not attr.is_factory: @@ -524,7 +555,7 @@ def choice_type(self, choice: Attr, parents: List[str]) -> str: compound field that might be a list, that's why list restriction is also ignored. """ - type_names = unique_sequence( + type_names = collections.unique_sequence( self.field_type_name(x, parents) for x in choice.types ) @@ -533,7 +564,9 @@ def choice_type(self, choice: Attr, parents: List[str]) -> str: result = f"Union[{result}]" if choice.is_tokens: - result = f"List[{result}]" + result = ( + f"Tuple[{result}, ...]" if self.format.frozen else f"List[{result}]" + ) return f"Type[{result}]" @@ -590,6 +623,7 @@ def type_patterns(x: str) -> Tuple: "Dict": [": Dict"], "List": [": List["], "Optional": ["Optional["], + "Tuple": ["Tuple["], "Type": ["Type["], "Union": ["Union["], }, diff --git a/xsdata/formats/dataclass/generator.py b/xsdata/formats/dataclass/generator.py index ac53e7b15..f67ff4bcf 100644 --- a/xsdata/formats/dataclass/generator.py +++ b/xsdata/formats/dataclass/generator.py @@ -116,27 +116,21 @@ def render_classes( ) -> str: """Render the source code of the classes.""" load = self.env.get_template - config = self.config def render_class(obj: Class) -> str: """Render class or enumeration.""" if obj.is_enumeration: - template = "enum.jinja2" + template = load("enum.jinja2") elif obj.is_service: - template = "service.jinja2" + template = load("service.jinja2") else: - template = "class.jinja2" - - return ( - load(template) - .render( - obj=obj, - module_namespace=module_namespace, - docstring_style=config.output.docstring_style.name.lower(), - ) - .strip() - ) + template = load("class.jinja2") + + return template.render( + obj=obj, + module_namespace=module_namespace, + ).strip() return "\n\n\n".join(map(render_class, classes)) + "\n" diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index 4c14d41af..ce2950d9d 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -223,6 +223,7 @@ def build( init: bool, default_value: Any, globalns: Any, + factory: Optional[Callable] = None, ) -> Optional[XmlVar]: """Build the binding metadata for a dataclass field.""" xml_type = metadata.get("type", self.default_xml_type) @@ -246,7 +247,14 @@ def build( ) local_name = self.build_local_name(xml_type, local_name, name) - element_list = self.is_element_list(origin, sub_origin, tokens) + + if tokens and sub_origin is None: + sub_origin = origin + origin = None + + if origin is None: + origin = factory + any_type = self.is_any_type(types, xml_type) clazz = first(tp for tp in types if self.class_type.is_model(tp)) namespaces = self.resolve_namespaces(xml_type, namespace) @@ -255,7 +263,7 @@ def build( elements = {} wildcards = [] - for choice in self.build_choices(name, choices, globalns): + for choice in self.build_choices(name, choices, origin, globalns): if choice.is_element: elements[choice.qname] = choice else: # choice.is_wildcard: @@ -268,12 +276,12 @@ def build( init=init, mixed=mixed, format=format_str, - tokens=tokens, clazz=clazz, any_type=any_type, nillable=nillable, sequential=sequential, - list_element=element_list, + factory=origin, + tokens_factory=sub_origin, default=default_value, types=types, elements=elements, @@ -284,7 +292,7 @@ def build( ) def build_choices( - self, name: str, choices: List[Dict], globalns: Any + self, name: str, choices: List[Dict], factory: Callable, globalns: Any ) -> Iterator[XmlVar]: """Build the binding metadata for a compound dataclass field.""" existing_types: Set[type] = set() @@ -302,14 +310,19 @@ def build_choices( metadata["type"] = XmlType.ELEMENT var = self.build( - index, name, type_hint, metadata, True, default_value, globalns + index, + name, + type_hint, + metadata, + True, + default_value, + globalns, + factory, ) # It's impossible for choice elements to be ignorable, read above! assert var is not None - var.list_element = True - if var.any_type or any(True for tp in var.types if tp in existing_types): var.derived = True @@ -383,19 +396,6 @@ def default_namespace(cls, namespaces: Sequence[str]) -> Optional[str]: return None - @classmethod - def is_element_list(cls, origin: Any, sub_origin: Any, is_tokens: bool) -> bool: - """ - Return whether the field is a list element. - - If the field is derived from xs:NMTOKENS both origins have to be - lists. - """ - if origin is list: - return not is_tokens or sub_origin is list - - return False - @classmethod def is_any_type(cls, types: Sequence[Type], xml_type: str) -> bool: """Return whether the given xml type supports derived values.""" @@ -422,7 +422,7 @@ def analyze_types( origin = None sub_origin = None - while types[0] in (list, dict): + while types[0] in (tuple, list, dict): if origin is None: origin = types[0] elif sub_origin is None: @@ -455,7 +455,7 @@ def is_valid( # Attributes need origin dict, no sub origin and tokens if origin is not dict or sub_origin or tokens: return False - elif origin is dict or tokens and origin is not list: + elif origin is dict or tokens and origin not in (list, tuple): # Origin dict is only supported by Attributes # xs:NMTOKENS need origin list return False diff --git a/xsdata/formats/dataclass/models/elements.py b/xsdata/formats/dataclass/models/elements.py index 34dd6c6ec..5b1bd7a82 100644 --- a/xsdata/formats/dataclass/models/elements.py +++ b/xsdata/formats/dataclass/models/elements.py @@ -2,6 +2,7 @@ import operator import sys from typing import Any +from typing import Callable from typing import Dict from typing import Iterator from typing import List @@ -14,6 +15,7 @@ from xsdata.formats.converter import converter from xsdata.models.enums import NamespaceType +from xsdata.utils import collections from xsdata.utils.namespaces import local_name from xsdata.utils.namespaces import target_uri @@ -81,18 +83,20 @@ class XmlVar(MetaMixin): "clazz", "init", "mixed", - "tokens", + "factory", + "tokens_factory", "format", "derived", "any_type", "nillable", "sequential", - "list_element", "default", "namespaces", "elements", "wildcards", # Calculated + "tokens", + "list_element", "is_text", "is_element", "is_elements", @@ -113,13 +117,13 @@ def __init__( clazz: Optional[Type], init: bool, mixed: bool, - tokens: bool, + factory: Optional[Callable], + tokens_factory: Optional[Callable], format: Optional[str], derived: bool, any_type: bool, nillable: bool, sequential: bool, - list_element: bool, default: Any, xml_type: str, namespaces: Sequence[str], @@ -134,18 +138,21 @@ def __init__( self.clazz = clazz self.init = init self.mixed = mixed - self.tokens = tokens + self.tokens = tokens_factory is not None self.format = format self.derived = derived self.any_type = any_type self.nillable = nillable self.sequential = sequential - self.list_element = list_element + self.list_element = factory in (list, tuple) self.default = default self.namespaces = namespaces self.elements = elements self.wildcards = wildcards + self.factory = factory + self.tokens_factory = tokens_factory + self.namespace_matches: Optional[Dict[str, bool]] = None self.is_clazz_union = self.clazz and len(types) > 1 @@ -184,7 +191,7 @@ def find_value_choice(self, value: Any, check_subclass: bool) -> Optional["XmlVa """Match and return a choice field that matches the given value type.""" - if isinstance(value, list): + if collections.is_array(value): tp = type(None) if not value else type(value[0]) tokens = True check_subclass = False diff --git a/xsdata/formats/dataclass/parsers/json.py b/xsdata/formats/dataclass/parsers/json.py index a598b1ff2..0612b96af 100644 --- a/xsdata/formats/dataclass/parsers/json.py +++ b/xsdata/formats/dataclass/parsers/json.py @@ -23,6 +23,7 @@ from xsdata.formats.dataclass.parsers.utils import ParserUtils from xsdata.formats.dataclass.typing import get_args from xsdata.formats.dataclass.typing import get_origin +from xsdata.utils import collections from xsdata.utils.constants import EMPTY_MAP @@ -117,8 +118,8 @@ def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T: params = {} for key, value in data.items(): - is_list = isinstance(value, list) - var = self.find_var(xml_vars, key, is_list) + is_array = collections.is_array(value) + var = self.find_var(xml_vars, key, is_array) if var is None and self.config.fail_on_unknown_properties: raise ParserError(f"Unknown property {clazz.__qualname__}.{key}") @@ -202,7 +203,8 @@ def bind_value( # Repeating element, recursively bind the values if not recursive and var.list_element and isinstance(value, list): - return [self.bind_value(meta, var, val, True) for val in value] + assert var.factory is not None + return var.factory(self.bind_value(meta, var, val, True) for val in value) # If not dict this is an text or tokens value. if not isinstance(value, dict): @@ -241,7 +243,12 @@ def bind_text(self, meta: XmlMeta, var: XmlVar, value: Any) -> Any: # Convert value according to the field types return ParserUtils.parse_value( - value, var.types, var.default, EMPTY_MAP, var.tokens, var.format + value=value, + types=var.types, + default=var.default, + ns_map=EMPTY_MAP, + tokens_factory=var.tokens_factory, + format=var.format, ) def bind_complex_type(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: diff --git a/xsdata/formats/dataclass/parsers/nodes/element.py b/xsdata/formats/dataclass/parsers/nodes/element.py index 50b8ed286..80a3d586d 100644 --- a/xsdata/formats/dataclass/parsers/nodes/element.py +++ b/xsdata/formats/dataclass/parsers/nodes/element.py @@ -14,6 +14,7 @@ from xsdata.formats.dataclass.parsers.config import ParserConfig from xsdata.formats.dataclass.parsers.mixins import XmlNode from xsdata.formats.dataclass.parsers.utils import ParserUtils +from xsdata.formats.dataclass.parsers.utils import PendingCollection from xsdata.logger import logger from xsdata.models.enums import DataType @@ -86,6 +87,10 @@ def bind( if not bind_text and wild_var: self.bind_wild_content(params, wild_var, text, tail) + for key in params.keys(): + if isinstance(params[key], PendingCollection): + params[key] = params[key].evaluate() + obj = self.meta.clazz(**params) if self.derived_factory: obj = self.derived_factory(qname=qname, value=obj, type=self.xsi_type) @@ -119,12 +124,12 @@ def bind_attrs(self, params: Dict): def bind_attr(self, params: Dict, var: XmlVar, value: Any): if var.init: params[var.name] = ParserUtils.parse_value( - value, - var.types, - var.default, - self.ns_map, - var.tokens, - var.format, + value=value, + types=var.types, + default=var.default, + ns_map=self.ns_map, + tokens_factory=var.tokens_factory, + format=var.format, ) def bind_any_attr(self, params: Dict, var: XmlVar, qname: str, value: Any): @@ -169,7 +174,7 @@ def bind_var(cls, params: Dict, var: XmlVar, value: Any) -> bool: if var.list_element: items = params.get(var.name) if items is None: - params[var.name] = [value] + params[var.name] = PendingCollection([value], var.factory) else: items.append(value) elif var.name not in params: @@ -194,7 +199,7 @@ def bind_wild_var(self, params: Dict, var: XmlVar, qname: str, value: Any) -> bo if var.list_element: items = params.get(var.name) if items is None: - params[var.name] = [value] + params[var.name] = PendingCollection([value], var.factory) else: items.append(value) elif var.name in params: @@ -246,7 +251,12 @@ def bind_content(self, params: Dict, txt: Optional[str]) -> bool: var = self.meta.text if var and var.init: params[var.name] = ParserUtils.parse_value( - txt, var.types, var.default, self.ns_map, var.tokens, var.format + value=txt, + types=var.types, + default=var.default, + ns_map=self.ns_map, + tokens_factory=var.tokens_factory, + format=var.format, ) return True @@ -272,7 +282,7 @@ def bind_wild_content( if var.list_element: items = params.get(var.name) if items is None: - params[var.name] = items = [] + params[var.name] = items = PendingCollection(None, var.factory) if txt: items.insert(0, txt) diff --git a/xsdata/formats/dataclass/parsers/nodes/primitive.py b/xsdata/formats/dataclass/parsers/nodes/primitive.py index c47ffc5c5..028be68bc 100644 --- a/xsdata/formats/dataclass/parsers/nodes/primitive.py +++ b/xsdata/formats/dataclass/parsers/nodes/primitive.py @@ -28,12 +28,12 @@ def bind( self, qname: str, text: Optional[str], tail: Optional[str], objects: List ) -> bool: obj = ParserUtils.parse_value( - text, - self.var.types, - self.var.default, - self.ns_map, - self.var.tokens, - self.var.format, + value=text, + types=self.var.types, + default=self.var.default, + ns_map=self.ns_map, + tokens_factory=self.var.tokens_factory, + format=self.var.format, ) if obj is None and not self.var.nillable: diff --git a/xsdata/formats/dataclass/parsers/nodes/standard.py b/xsdata/formats/dataclass/parsers/nodes/standard.py index 88bb10028..3542e41c5 100644 --- a/xsdata/formats/dataclass/parsers/nodes/standard.py +++ b/xsdata/formats/dataclass/parsers/nodes/standard.py @@ -38,12 +38,10 @@ def bind( self, qname: str, text: Optional[str], tail: Optional[str], objects: List ) -> bool: obj = ParserUtils.parse_value( - text, - [self.datatype.type], - None, - self.ns_map, - False, - self.datatype.format, + value=text, + types=[self.datatype.type], + ns_map=self.ns_map, + format=self.datatype.format, ) if obj is None and not self.nillable: diff --git a/xsdata/formats/dataclass/parsers/nodes/union.py b/xsdata/formats/dataclass/parsers/nodes/union.py index 27eda9f34..dd5e66c1d 100644 --- a/xsdata/formats/dataclass/parsers/nodes/union.py +++ b/xsdata/formats/dataclass/parsers/nodes/union.py @@ -104,6 +104,8 @@ def parse_value(self, value: Any, types: List[Type]) -> Any: try: with warnings.catch_warnings(): warnings.filterwarnings("error", category=ConverterWarning) - return ParserUtils.parse_value(value, types, ns_map=self.ns_map) + return ParserUtils.parse_value( + value=value, types=types, ns_map=self.ns_map + ) except Exception: return None diff --git a/xsdata/formats/dataclass/parsers/tree.py b/xsdata/formats/dataclass/parsers/tree.py index 2e063cc76..3e4134c83 100644 --- a/xsdata/formats/dataclass/parsers/tree.py +++ b/xsdata/formats/dataclass/parsers/tree.py @@ -47,13 +47,13 @@ def start( clazz=None, init=True, mixed=False, - tokens=False, + factory=None, + tokens_factory=None, format=None, derived=False, any_type=False, nillable=False, sequential=False, - list_element=False, default=None, namespaces=(), elements={}, diff --git a/xsdata/formats/dataclass/parsers/utils.py b/xsdata/formats/dataclass/parsers/utils.py index 9a6018579..8d4d97c5d 100644 --- a/xsdata/formats/dataclass/parsers/utils.py +++ b/xsdata/formats/dataclass/parsers/utils.py @@ -1,5 +1,8 @@ +from collections import UserList from typing import Any +from typing import Callable from typing import Dict +from typing import Iterable from typing import Optional from typing import Sequence from typing import Type @@ -7,11 +10,21 @@ from xsdata.formats.converter import converter from xsdata.formats.converter import QNameConverter from xsdata.models.enums import QNames +from xsdata.utils import collections from xsdata.utils import constants from xsdata.utils import text from xsdata.utils.namespaces import build_qname +class PendingCollection(UserList): + def __init__(self, initlist: Optional[Iterable], factory: Optional[Callable]): + super().__init__(initlist) + self.factory = factory or list + + def evaluate(self) -> Iterable: + return self.factory(self.data) + + class ParserUtils: @classmethod def xsi_type(cls, attrs: Dict, ns_map: Dict) -> Optional[str]: @@ -35,26 +48,26 @@ def parse_value( types: Sequence[Type], default: Any = None, ns_map: Optional[Dict] = None, - tokens: bool = False, - _format: Optional[str] = None, + tokens_factory: Callable = None, + format: Optional[str] = None, ) -> Any: """Convert xml string values to s python primitive type.""" if value is None: if callable(default): - return default() if tokens else None + return default() if tokens_factory else None return default - if tokens: - value = value if isinstance(value, list) else value.split() - return [ - converter.deserialize(val, types, ns_map=ns_map, format=_format) + if tokens_factory: + value = value if collections.is_array(value) else value.split() + return tokens_factory( + converter.deserialize(val, types, ns_map=ns_map, format=format) for val in value - ] + ) - return converter.deserialize(value, types, ns_map=ns_map, format=_format) + return converter.deserialize(value, types, ns_map=ns_map, format=format) @classmethod def normalize_content(cls, value: Optional[str]) -> Optional[str]: diff --git a/xsdata/formats/dataclass/serializers/json.py b/xsdata/formats/dataclass/serializers/json.py index d773d24f9..2389f911d 100644 --- a/xsdata/formats/dataclass/serializers/json.py +++ b/xsdata/formats/dataclass/serializers/json.py @@ -14,6 +14,7 @@ from xsdata.formats.converter import converter from xsdata.formats.dataclass.context import XmlContext from xsdata.formats.dataclass.models.elements import XmlVar +from xsdata.utils import collections def filter_none(x: Tuple) -> Dict: @@ -60,7 +61,7 @@ def write(self, out: TextIO, obj: Any): def convert(self, obj: Any, var: Optional[XmlVar] = None) -> Any: if var is None or self.context.class_type.is_model(obj): - if isinstance(obj, list): + if collections.is_array(obj): return [self.convert(o) for o in obj] return self.dict_factory( @@ -70,10 +71,7 @@ def convert(self, obj: Any, var: Optional[XmlVar] = None) -> Any: ] ) - if isinstance(obj, tuple) and hasattr(obj, "_fields"): - return converter.serialize(obj, format=var.format) - - if isinstance(obj, (list, tuple)): + if collections.is_array(obj): return type(obj)(self.convert(v, var) for v in obj) if isinstance(obj, (dict, int, float, str, bool)): diff --git a/xsdata/formats/dataclass/serializers/xml.py b/xsdata/formats/dataclass/serializers/xml.py index 95d3a911f..c6f321b96 100644 --- a/xsdata/formats/dataclass/serializers/xml.py +++ b/xsdata/formats/dataclass/serializers/xml.py @@ -5,6 +5,7 @@ from typing import Any from typing import Dict from typing import Generator +from typing import Iterable from typing import List from typing import Optional from typing import TextIO @@ -23,6 +24,7 @@ from xsdata.formats.dataclass.serializers.writers import default_writer from xsdata.models.enums import DataType from xsdata.models.enums import QNames +from xsdata.utils import collections from xsdata.utils import namespaces from xsdata.utils.constants import EMPTY_MAP from xsdata.utils.namespaces import split_qname @@ -141,12 +143,14 @@ def write_value(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: yield from self.write_tokens(value, var, namespace) elif var.is_elements: yield from self.write_elements(value, var, namespace) - elif var.list_element and isinstance(value, list): + elif var.list_element and collections.is_array(value): yield from self.write_list(value, var, namespace) else: yield from self.write_any_type(value, var, namespace) - def write_list(self, values: List, var: XmlVar, namespace: NoneStr) -> Generator: + def write_list( + self, values: Iterable, var: XmlVar, namespace: NoneStr + ) -> Generator: """Produce an events stream for the given list of values.""" for value in values: yield from self.write_value(value, var, namespace) @@ -155,7 +159,7 @@ def write_tokens(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator """Produce an events stream for the given tokens list or list of tokens lists.""" if value or var.nillable: - if value and isinstance(value[0], list): + if value and collections.is_array(value[0]): for val in value: yield from self.write_element(val, var, namespace) else: @@ -241,7 +245,7 @@ def xsi_type(self, var: XmlVar, value: Any, namespace: NoneStr) -> Optional[QNam def write_elements(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: """Produce an events stream from compound elements field.""" - if isinstance(value, list): + if collections.is_array(value): for choice in value: yield from self.write_choice(choice, var, namespace) else: @@ -350,7 +354,7 @@ def next_attribute( for var in meta.get_attribute_vars(): if var.is_attribute: value = getattr(obj, var.name) - if value is None or isinstance(value, list) and not value: + if value is None or collections.is_array(value) and not value: continue yield var.qname, cls.encode(value, var) diff --git a/xsdata/formats/dataclass/templates/class.jinja2 b/xsdata/formats/dataclass/templates/class.jinja2 index 202a36252..30fc5bb47 100644 --- a/xsdata/formats/dataclass/templates/class.jinja2 +++ b/xsdata/formats/dataclass/templates/class.jinja2 @@ -1,6 +1,6 @@ {% set level = level|default(0) -%} {% set help | format_docstring(level + 1) %} - {%- include "docstrings." + docstring_style + ".jinja2" -%} + {%- include "docstrings." + docstring_name + ".jinja2" -%} {% endset -%} {% set parent_namespace = obj.namespace if obj.namespace is not none else parent_namespace|default(None) -%} {% set parents = parents|default([obj.name]) -%} @@ -10,7 +10,7 @@ {% set base_classes = obj.extensions|map(attribute='type')|map('type_name')|join(', ') -%} {% set target_namespace = obj.target_namespace if level == 0 and module_namespace != obj.target_namespace else None %} -@dataclass +{{ class_annotation }} class {{ class_name }}{{"({})".format(base_classes) if base_classes }}: {%- if help %} {{ help|indent(4, first=True) }} diff --git a/xsdata/formats/dataclass/templates/enum.jinja2 b/xsdata/formats/dataclass/templates/enum.jinja2 index 59d1c627d..1c32de56c 100644 --- a/xsdata/formats/dataclass/templates/enum.jinja2 +++ b/xsdata/formats/dataclass/templates/enum.jinja2 @@ -1,6 +1,6 @@ {% set level = level | default(0) -%} {% set help | format_docstring(level + 1) %} - {%- include "docstrings." + docstring_style + ".jinja2" -%} + {%- include "docstrings." + docstring_name + ".jinja2" -%} {% endset -%} {% set class_name = obj.name | class_name %} @@ -11,7 +11,7 @@ class {{ class_name }}(Enum): {%- for attr in obj.attrs %} {{ attr.name | constant_name(obj.name) }} = {{ attr | field_default(obj.ns_map) }} {%- endfor -%} -{% if docstring_style == "accessible" -%} +{% if docstring_name == "accessible" -%} {{ "\n\n" if level == 0 else "\n" }} {%- for attr in obj.attrs if attr.help %} {% set member_name = "{}.{}.__doc__ = ".format(class_name, attr.name | constant_name(obj.name)) -%} diff --git a/xsdata/formats/dataclass/typing.py b/xsdata/formats/dataclass/typing.py index d98fbe791..bf236a759 100644 --- a/xsdata/formats/dataclass/typing.py +++ b/xsdata/formats/dataclass/typing.py @@ -17,7 +17,7 @@ def get_origin(tp: Any) -> Any: if tp is Dict: return Dict - if tp in (List, Union): + if tp in (Tuple, List, Union): raise TypeError() if isinstance(tp, TypeVar): @@ -28,6 +28,9 @@ def get_origin(tp: Any) -> Any: if origin in (list, List): return List + if origin in (tuple, Tuple): + return Tuple + if origin in (dict, Dict): return Dict @@ -53,23 +56,21 @@ def evaluate(tp: Any, globalns: Any = None, localns: Any = None) -> Tuple[Type, def _evaluate(tp: Any) -> Iterator[Type]: origin = get_origin(tp) - if origin is List: - yield from _evaluate_list(tp) - elif origin is Dict: - yield from _evaluate_mapping(tp) - elif origin is Union: - yield from _evaluate_union(tp) - elif origin is Type: - args = get_args(tp) - if not args or isinstance(args[0], TypeVar): - raise TypeError() - yield from _evaluate(args[0]) - elif origin is TypeVar: - yield from _evaluate_typevar(tp) + + func = __evaluations__.get(origin) + if func: + yield from func(tp) else: yield tp +def _evaluate_type(tp: Any) -> Iterator[Type]: + args = get_args(tp) + if not args or isinstance(args[0], TypeVar): + raise TypeError() + yield from _evaluate(args[0]) + + def _evaluate_mapping(tp: Any) -> Iterator[Type]: yield dict args = get_args(tp) @@ -102,12 +103,26 @@ def _evaluate_list(tp: Any) -> Iterator[Type]: if origin is None: yield arg - elif origin is Union: - yield from _evaluate_union(arg) - elif origin is List: - yield from _evaluate_list(arg) - elif origin is TypeVar: - yield from _evaluate_typevar(arg) + elif origin in (Union, List, Tuple, TypeVar): + yield from __evaluations__[origin](arg) + else: + raise TypeError() + + +def _evaluate_tuple(tp: Any) -> Iterator[Type]: + yield tuple + + args = get_args(tp) + for arg in args: + + if arg is Ellipsis: + continue + + origin = get_origin(arg) + if origin is None: + yield arg + elif origin in (Union, List, Tuple, TypeVar): + yield from __evaluations__[origin](arg) else: raise TypeError() @@ -136,3 +151,13 @@ def _evaluate_typevar(tp: TypeVar): yield from _evaluate(arg) else: raise TypeError() + + +__evaluations__ = { + Tuple: _evaluate_tuple, + List: _evaluate_list, + Dict: _evaluate_mapping, + Union: _evaluate_union, + Type: _evaluate_type, + TypeVar: _evaluate_typevar, +} diff --git a/xsdata/formats/dataclass/utils.py b/xsdata/formats/dataclass/utils.py index 4d168859f..1f45b084b 100644 --- a/xsdata/formats/dataclass/utils.py +++ b/xsdata/formats/dataclass/utils.py @@ -15,6 +15,7 @@ "false", "none", "yield", + "object", "break", "for", "not", diff --git a/xsdata/models/config.py b/xsdata/models/config.py index 0e7e57513..d23a3dd9b 100644 --- a/xsdata/models/config.py +++ b/xsdata/models/config.py @@ -9,6 +9,7 @@ from typing import TextIO from xsdata import __version__ +from xsdata.exceptions import GeneratorConfigError from xsdata.formats.dataclass.context import XmlContext from xsdata.formats.dataclass.parsers import XmlParser from xsdata.formats.dataclass.parsers.config import ParserConfig @@ -129,11 +130,23 @@ class OutputFormat: Output format options. :param value: Name of the format - :param relative_imports: Enable relative imports + :param repr: Generate repr methods + :param eq: Generate equal method + :param order: Generate rich comparison methods + :param unsafe_hash: Generate hash method when frozen is false + :param frozen: Enable read only properties with immutable containers """ value: str = field(default="dataclasses") - relative_imports: bool = attribute(default=False) + repr: bool = attribute(default=True) + eq: bool = attribute(default=True) + order: bool = attribute(default=False) + unsafe_hash: bool = attribute(default=False) + frozen: bool = attribute(default=False) + + def __post_init__(self): + if self.order and not self.eq: + raise GeneratorConfigError("eq must be true if order is true") @dataclass @@ -146,6 +159,7 @@ class GeneratorOutput: :param format: Code generator output format name :param structure: Select an output structure :param docstring_style: Select a docstring style + :param relative_imports: Enable relative imports :param compound_fields: Use compound fields for repeating choices. Enable if elements ordering matters for your case. """ @@ -155,6 +169,7 @@ class GeneratorOutput: format: OutputFormat = element(default_factory=OutputFormat) structure: StructureStyle = element(default=StructureStyle.FILENAMES) docstring_style: DocstringStyle = element(default=DocstringStyle.RST) + relative_imports: bool = element(default=False) compound_fields: bool = element(default=False) diff --git a/xsdata/utils/collections.py b/xsdata/utils/collections.py index 2939a95f0..9030105e0 100644 --- a/xsdata/utils/collections.py +++ b/xsdata/utils/collections.py @@ -9,6 +9,13 @@ from typing import Sequence +def is_array(value: Any) -> bool: + if isinstance(value, (list, tuple)): + return not hasattr(value, "_fields") + + return False + + def unique_sequence(items: Iterable, key: Optional[str] = None) -> List: """ Return a new list with the unique values from an iterable. diff --git a/xsdata/utils/testing.py b/xsdata/utils/testing.py index 3be78204a..e227e9211 100644 --- a/xsdata/utils/testing.py +++ b/xsdata/utils/testing.py @@ -5,6 +5,7 @@ import unittest from dataclasses import is_dataclass from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -344,7 +345,8 @@ def create( clazz: Optional[Type] = None, init: bool = True, mixed: bool = False, - tokens: bool = False, + factory: Optional[Callable] = None, + tokens_factory: Optional[Callable] = None, format: Optional[str] = None, derived: bool = False, any_type: bool = False, @@ -381,7 +383,8 @@ def create( clazz=clazz or first(tp for tp in types if is_dataclass(tp)), init=init, mixed=mixed, - tokens=tokens, + factory=factory, + tokens_factory=tokens_factory, format=format, derived=derived, any_type=any_type,