diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index 32715c88f..58e00f232 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -128,6 +128,8 @@ def register(self, env: Environment): { "field_name": self.field_name, "field_type": self.field_type, + "field_type_names": self.field_type_names, + "field_type_names_join": self.field_type_names_join, "field_default": self.field_default_value, "field_metadata": self.field_metadata, "field_definition": self.field_definition, @@ -823,6 +825,23 @@ def choice_type(self, obj: Class, choice: Attr) -> str: return f"Type[{result}]" + def field_type_names( + self, + obj: Class, + attr: Attr, + choice: bool = False, + ) -> List[str]: + return [self._field_type_name(obj, x, choice=choice) for x in attr.types] + + def field_type_names_join( + self, + obj: Class, + attr: Attr, + choice: bool = False, + ) -> str: + type_names = [self._field_type_name(obj, x, choice=choice) for x in attr.types] + return self._join_type_names(type_names) + def _field_type_names( self, obj: Class, diff --git a/xsdata/formats/dataclass/templates/class.jinja2 b/xsdata/formats/dataclass/templates/class.jinja2 index 86249b93a..127ec4c46 100644 --- a/xsdata/formats/dataclass/templates/class.jinja2 +++ b/xsdata/formats/dataclass/templates/class.jinja2 @@ -43,7 +43,26 @@ class {{ class_name }}{{"({})".format(base_classes) if base_classes }}: {%- for attr in obj.attrs %} {%- set field_typing = obj|field_type(attr) %} {%- set field_definition = obj|field_definition(attr, parent_namespace) %} - {{ attr.name|field_name(obj.name) }}: {{ field_typing }} = {{ field_definition }} + {%- set field_name = attr.name|field_name(obj.name) %} + {{ field_name }}: {{ field_typing }} = {{ field_definition }} +{%- for attr2 in attr.choices %} + {%- set field_type_names = obj|field_type_names(attr2) %} + {# TODO: how to get the string without quotes? #} + {%- set field_type_names_fmt = obj|field_type_names_join(attr2)|replace("\"", "") %} + {%- set field_name2 = attr2.name|field_name(obj.name) %} + {% if 'Any' not in field_typing %} + @property + def {{ field_name2 }}(self) -> {{ field_type_names_fmt }}: + {%- for field_type_name in field_type_names %} + {# TODO: how to get the string without quotes? #} + if isinstance(self.{{ field_name }}, {{ field_type_name|replace("\"", "") }}): + return self.{{ field_name }} + {% endfor %} + @{{ field_name2 }}.setter + def {{ field_name2 }}(self, value: {{ field_type_names_fmt|safe }}) -> None: + self.{{ field_name }} = value + {% endif %} +{% endfor %} {%- endfor -%} {%- for inner in obj.inner %} {%- set tpl = "enum.jinja2" if inner.is_enumeration else "class.jinja2" -%}