Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disambiguate dataclasses too #477

Merged
merged 1 commit into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Our backwards-compatibility policy can be found [here](https://github.com/python
([#432](https://github.com/python-attrs/cattrs/issues/432) [#472](https://github.com/python-attrs/cattrs/pull/472))
- The default union handler now properly takes renamed fields into account.
([#472](https://github.com/python-attrs/cattrs/pull/472))
- The default union handler now also handles dataclasses.
([#](https://github.com/python-attrs/cattrs/pull/))
- Add support for [PEP 695](https://peps.python.org/pep-0695/) type aliases.
([#452](https://github.com/python-attrs/cattrs/pull/452))
- The `include_subclasses` strategy now fetches the member hooks from the converter (making use of converter defaults) if overrides are not provided, instead of generating new hooks with no overrides.
Expand Down
12 changes: 8 additions & 4 deletions docs/unions.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Handling Unions

_cattrs_ is able to handle simple unions of _attrs_ classes [automatically](#default-union-strategy).
_cattrs_ is able to handle simple unions of _attrs_ classes and dataclasses [automatically](#default-union-strategy).
More complex cases require converter customization (since there are many ways of handling unions).

_cattrs_ also comes with a number of strategies to help handle unions:
_cattrs_ also comes with a number of optional strategies to help handle unions:

- [tagged unions strategy](strategies.md#tagged-unions-strategy) mentioned below
- [union passthrough strategy](strategies.md#union-passthrough), which is preapplied to all the [preconfigured](preconf.md) converters
Expand All @@ -12,10 +12,10 @@ _cattrs_ also comes with a number of strategies to help handle unions:

For convenience, _cattrs_ includes a default union structuring strategy which is a little more opinionated.

Given a union of several _attrs_ classes, the default union strategy will attempt to handle it in several ways.
Given a union of several _attrs_ classes and/or dataclasses, the default union strategy will attempt to handle it in several ways.

First, it will look for `Literal` fields.
If all members of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field.
If _all members_ of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field.

```python
from typing import Literal
Expand Down Expand Up @@ -68,6 +68,10 @@ The field `field_with_default` will not be considered since it has a default val
Literals can now be potentially used to disambiguate.
```

```{versionchanged} 24.1.0
Dataclasses are now supported in addition to _attrs_ classes.
```

## Unstructuring Unions with Extra Metadata

```{note}
Expand Down
12 changes: 11 additions & 1 deletion src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import deque
from collections.abc import MutableSet as AbcMutableSet
from collections.abc import Set as AbcSet
from dataclasses import MISSING, is_dataclass
from dataclasses import MISSING, Field, is_dataclass
from dataclasses import fields as dataclass_fields
from typing import AbstractSet as TypingAbstractSet
from typing import (
Expand All @@ -18,6 +18,7 @@
Protocol,
Tuple,
Type,
Union,
get_args,
get_origin,
get_type_hints,
Expand All @@ -31,9 +32,11 @@

from attrs import NOTHING, Attribute, Factory, resolve_types
from attrs import fields as attrs_fields
from attrs import fields_dict as attrs_fields_dict

__all__ = [
"adapted_fields",
"fields_dict",
"ExceptionGroup",
"ExtensionsTypedDict",
"get_type_alias_base",
Expand Down Expand Up @@ -119,6 +122,13 @@ def fields(type):
raise Exception("Not an attrs or dataclass class.") from None


def fields_dict(type) -> Dict[str, Union[Attribute, Field]]:
"""Return the fields_dict for attrs and dataclasses."""
if is_dataclass(type):
return {f.name: f for f in dataclass_fields(type)}
return attrs_fields_dict(type)


def adapted_fields(cl) -> List[Attribute]:
"""Return the attrs format of `fields()` for attrs and dataclasses."""
if is_dataclass(cl):
Expand Down
35 changes: 26 additions & 9 deletions src/cattrs/disambiguators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import MISSING
from functools import reduce
from operator import or_
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Union

from attrs import NOTHING, Attribute, AttrsInstance, fields, fields_dict

from ._compat import NoneType, get_args, get_origin, has, is_literal, is_union_type
from attrs import NOTHING, Attribute, AttrsInstance

from ._compat import (
NoneType,
adapted_fields,
fields_dict,
get_args,
get_origin,
has,
is_literal,
is_union_type,
)
from .gen import AttributeOverride

if TYPE_CHECKING:
Expand All @@ -31,13 +41,16 @@ def create_default_dis_func(
overrides: dict[str, AttributeOverride]
| Literal["from_converter"] = "from_converter",
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
"""Given attrs classes, generate a disambiguation function.
"""Given attrs classes or dataclasses, generate a disambiguation function.

The function is based on unique fields without defaults or unique values.

:param use_literals: Whether to try using fields annotated as literals for
disambiguation.
:param overrides: Attribute overrides to apply.

.. versionchanged:: 24.1.0
Dataclasses are now supported.
"""
if len(classes) < 2:
raise ValueError("At least two classes required.")
Expand All @@ -55,7 +68,11 @@ def create_default_dis_func(
# (... TODO: a single fallback is OK)
# - it must always be enumerated
cls_candidates = [
{at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)}
{
at.name
for at in adapted_fields(get_origin(cl) or cl)
if is_literal(at.type)
}
for cl in classes
]

Expand Down Expand Up @@ -128,10 +145,10 @@ def dis_func(data: Mapping[Any, Any]) -> type | None:
uniq = cl_reqs - other_reqs

# We want a unique attribute with no default.
cl_fields = fields(get_origin(cl) or cl)
cl_fields = fields_dict(get_origin(cl) or cl)
for maybe_renamed_attr_name in uniq:
orig_name = back_map[maybe_renamed_attr_name]
if getattr(cl_fields, orig_name).default is NOTHING:
if cl_fields[orig_name].default in (NOTHING, MISSING):
break
else:
if fallback is None:
Expand Down Expand Up @@ -173,13 +190,13 @@ def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:


def _usable_attribute_names(
cl: type[AttrsInstance], overrides: dict[str, AttributeOverride]
cl: type[Any], overrides: dict[str, AttributeOverride]
) -> tuple[set[str], dict[str, str]]:
"""Return renamed fields and a mapping to original field names."""
res = set()
mapping = {}

for at in fields(get_origin(cl) or cl):
for at in adapted_fields(get_origin(cl) or cl):
res.add(n := _overriden_name(at, overrides.get(at.name)))
mapping[n] = at.name

Expand Down
49 changes: 38 additions & 11 deletions tests/test_disambiguators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for auto-disambiguators."""
from dataclasses import dataclass
from functools import partial
from typing import Literal, Union

Expand All @@ -7,11 +8,7 @@
from hypothesis import HealthCheck, assume, given, settings

from cattrs import Converter
from cattrs.disambiguators import (
create_default_dis_func,
create_uniq_field_dis_func,
is_supported_union,
)
from cattrs.disambiguators import create_default_dis_func, is_supported_union
from cattrs.gen import make_dict_structure_fn, override

from .untyped import simple_classes
Expand All @@ -27,7 +24,7 @@ class A:

with pytest.raises(ValueError):
# Can't generate for only one class.
create_uniq_field_dis_func(c, A)
create_default_dis_func(c, A)

with pytest.raises(ValueError):
create_default_dis_func(c, A)
Expand All @@ -38,7 +35,7 @@ class B:

with pytest.raises(TypeError):
# No fields on either class.
create_uniq_field_dis_func(c, A, B)
create_default_dis_func(c, A, B)

@define
class C:
Expand All @@ -50,7 +47,7 @@ class D:

with pytest.raises(TypeError):
# No unique fields on either class.
create_uniq_field_dis_func(c, C, D)
create_default_dis_func(c, C, D)

with pytest.raises(TypeError):
# No discriminator candidates
Expand All @@ -66,7 +63,7 @@ class F:

with pytest.raises(TypeError):
# no usable non-default attributes
create_uniq_field_dis_func(c, E, F)
create_default_dis_func(c, E, F)

@define
class G:
Expand All @@ -93,7 +90,7 @@ def test_fallback(cl_and_vals):
class A:
pass

fn = create_uniq_field_dis_func(c, A, cl)
fn = create_default_dis_func(c, A, cl)

assert fn({}) is A
assert fn(asdict(cl(*vals, **kwargs))) is cl
Expand Down Expand Up @@ -124,7 +121,7 @@ def test_disambiguation(cl_and_vals_a, cl_and_vals_b):
for attr_name in req_b - req_a:
assume(getattr(fields(cl_b), attr_name).default is NOTHING)

fn = create_uniq_field_dis_func(c, cl_a, cl_b)
fn = create_default_dis_func(c, cl_a, cl_b)

assert fn(asdict(cl_a(*vals_a, **kwargs_a))) is cl_a

Expand Down Expand Up @@ -271,3 +268,33 @@ class B:

assert converter.structure({"a": 1}, Union[A, B]) == A(1)
assert converter.structure({"b": 1}, Union[A, B]) == B(1)


def test_dataclasses(converter):
"""The default strategy works for dataclasses too."""

@define
class A:
a: int

@dataclass
class B:
b: int

assert converter.structure({"a": 1}, Union[A, B]) == A(1)
assert converter.structure({"b": 1}, Union[A, B]) == B(1)


def test_dataclasses_literals(converter):
"""The default strategy works for dataclasses too."""

@define
class A:
a: Literal["a"] = "a"

@dataclass
class B:
b: Literal["b"]

assert converter.structure({"a": "a"}, Union[A, B]) == A()
assert converter.structure({"b": "b"}, Union[A, B]) == B("b")