Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TypedDict coverage #450

Merged
merged 11 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:

steps:
- uses: "actions/checkout@v3"

- uses: "actions/setup-python@v4"
with:
cache: "pip"
Expand All @@ -71,12 +71,15 @@ jobs:
export TOTAL=$(python -c "import json;print(json.load(open('coverage.json'))['totals']['percent_covered_display'])")
echo "total=$TOTAL" >> $GITHUB_ENV

# Report again and fail if under the threshold.
python -Im coverage report --fail-under=97

- name: "Upload HTML report."
uses: "actions/upload-artifact@v3"
with:
name: "html-report"
path: "htmlcov"

- name: "Make badge"
if: github.ref == 'refs/heads/main'
uses: "schneegans/[email protected]"
Expand Down
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# History

## 24.1.0 (UNRELEASED)

- More robust support for `Annotated` and `NotRequired` in TypedDicts.
([#450](https://github.com/python-attrs/cattrs/pull/450))

## 23.2.1 (2023-11-18)

- Fix unnecessary `typing_extensions` import on Python 3.11.
Expand Down
11 changes: 8 additions & 3 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import AbstractSet as TypingAbstractSet
from typing import Any, Deque, Dict, Final, FrozenSet, List
from typing import Any, Deque, Dict, Final, FrozenSet, List, Literal
from typing import Mapping as TypingMapping
from typing import MutableMapping as TypingMutableMapping
from typing import MutableSequence as TypingMutableSequence
Expand Down Expand Up @@ -243,6 +243,9 @@ def get_newtype_base(typ: Any) -> Optional[type]:
return None

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if is_annotated(type):
# Handle `Annotated[NotRequired[int]]`
type = get_args(type)[0]
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
return NOTHING
Expand Down Expand Up @@ -438,8 +441,6 @@ def is_counter(type):
or getattr(type, "__origin__", None) is ColCounter
)

from typing import Literal

def is_literal(type) -> bool:
return type.__class__ is _GenericAlias and type.__origin__ is Literal

Expand All @@ -453,6 +454,10 @@ def copy_with(type, args):
return type.copy_with(args)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if is_annotated(type):
# Handle `Annotated[NotRequired[int]]`
type = get_origin(type)

if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
return NOTHING
Expand Down
10 changes: 8 additions & 2 deletions src/cattrs/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@
is_union_type,
)
from .disambiguators import create_default_dis_func, is_supported_union
from .dispatch import HookFactory, MultiStrategyDispatch, StructureHook, UnstructureHook
from .dispatch import (
HookFactory,
MultiStrategyDispatch,
StructureHook,
UnstructuredValue,
UnstructureHook,
)
from .errors import (
IterableValidationError,
IterableValidationNote,
Expand Down Expand Up @@ -327,7 +333,7 @@ def register_structure_hook_factory(
"""
self._structure_func.register_func_list([(predicate, factory, True)])

def structure(self, obj: Any, cl: Type[T]) -> T:
def structure(self, obj: UnstructuredValue, cl: Type[T]) -> T:
"""Convert unstructured Python data structures to structured data."""
return self._structure_func.dispatch(cl)(obj, cl)

Expand Down
126 changes: 60 additions & 66 deletions src/cattrs/gen/typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,20 @@ def make_dict_unstructure_fn(
break
handler = None
t = a.type
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb

if isinstance(t, TypeVar):
if t.__name__ in mapping:
t = mapping[t.__name__]
else:
# Unbound typevars use late binding.
handler = converter.unstructure
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)

if handler is None:
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
try:
handler = converter._unstructure_func.dispatch(t)
except RecursionError:
Expand Down Expand Up @@ -171,9 +172,6 @@ def make_dict_unstructure_fn(
handler = override.unstruct_hook
else:
t = a.type
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb

if isinstance(t, TypeVar):
if t.__name__ in mapping:
Expand All @@ -184,6 +182,9 @@ def make_dict_unstructure_fn(
t = deep_copy_with(t, mapping)

if handler is None:
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
try:
handler = converter._unstructure_func.dispatch(t)
except RecursionError:
Expand Down Expand Up @@ -282,9 +283,6 @@ def make_dict_structure_fn(
mapping = generate_mapping(base, mapping)
break

if isinstance(cl, TypeVar):
cl = mapping.get(cl.__name__, cl)

cl_name = cl.__name__
fn_name = "structure_" + cl_name

Expand Down Expand Up @@ -337,6 +335,12 @@ def make_dict_structure_fn(
if override.omit:
continue
t = a.type

if isinstance(t, TypeVar):
t = mapping.get(t.__name__, t)
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)

nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
Expand Down Expand Up @@ -370,16 +374,11 @@ def make_dict_structure_fn(
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t

if handler:
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'])")
else:
lines.append(
f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})"
)
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'])")
else:
lines.append(f"{i}res['{an}'] = o['{kn}']")
lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})")
if override.rename is not None:
lines.append(f"{i}del res['{kn}']")
i = i[:-2]
Expand Down Expand Up @@ -415,42 +414,38 @@ def make_dict_structure_fn(
continue

t = a.type
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb

if isinstance(t, TypeVar):
t = mapping.get(t.__name__, t)
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)

# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
if t is not None:
handler = converter._structure_func.dispatch(t)
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb

if override.struct_hook is not None:
handler = override.struct_hook
else:
handler = converter.structure
# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
handler = converter._structure_func.dispatch(t)

kn = an if override.rename is None else override.rename
allowed_fields.add(kn)

if handler:
struct_handler_name = f"__c_structure_{ix}"
internal_arg_parts[struct_handler_name] = handler
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
invocation_line = (
f" res['{an}'] = {struct_handler_name}(o['{kn}'])"
)
else:
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
invocation_line = (
f" res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})"
)
struct_handler_name = f"__c_structure_{ix}"
internal_arg_parts[struct_handler_name] = handler
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
invocation_line = f" res['{an}'] = {struct_handler_name}(o['{kn}'])"
else:
invocation_line = f" res['{an}'] = o['{kn}']"
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
invocation_line = (
f" res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})"
)

lines.append(invocation_line)
if override.rename is not None:
Expand All @@ -472,13 +467,13 @@ def make_dict_structure_fn(
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)

# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
if t is not None:
handler = converter._structure_func.dispatch(t)
if override.struct_hook is not None:
handler = override.struct_hook
else:
handler = converter.structure
# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
handler = converter._structure_func.dispatch(t)

struct_handler_name = f"__c_structure_{ix}"
internal_arg_parts[struct_handler_name] = handler
Expand All @@ -487,20 +482,17 @@ def make_dict_structure_fn(
kn = an if override.rename is None else override.rename
allowed_fields.add(kn)
post_lines.append(f" if '{kn}' in o:")
if handler:
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
post_lines.append(
f" res['{ian}'] = {struct_handler_name}(o['{kn}'])"
)
else:
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
post_lines.append(
f" res['{ian}'] = {struct_handler_name}(o['{kn}'], {tn})"
)
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
post_lines.append(
f" res['{ian}'] = {struct_handler_name}(o['{kn}'])"
)
else:
post_lines.append(f" res['{ian}'] = o['{kn}']")
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
post_lines.append(
f" res['{ian}'] = {struct_handler_name}(o['{kn}'], {tn})"
)
if override.rename is not None:
lines.append(f" res.pop('{override.rename}', None)")

Expand Down Expand Up @@ -568,6 +560,7 @@ def _required_keys(cls: type) -> set[str]:
from typing_extensions import Annotated, NotRequired, Required, get_args

def _required_keys(cls: type) -> set[str]:
"""Own own processor for required keys."""
if _is_extensions_typeddict(cls):
return cls.__required_keys__

Expand Down Expand Up @@ -600,6 +593,7 @@ def _required_keys(cls: type) -> set[str]:
# On 3.8, typing.TypedDicts do not have __required_keys__.

def _required_keys(cls: type) -> set[str]:
"""Own own processor for required keys."""
if _is_extensions_typeddict(cls):
return cls.__required_keys__

Expand All @@ -613,12 +607,12 @@ def _required_keys(cls: type) -> set[str]:
if key in superclass_keys:
continue
annotation_type = own_annotations[key]

if is_annotated(annotation_type):
# If this is `Annotated`, we need to get the origin twice.
annotation_type = get_origin(annotation_type)

annotation_origin = get_origin(annotation_type)
if annotation_origin is Annotated:
annotation_args = get_args(annotation_type)
if annotation_args:
annotation_type = annotation_args[0]
annotation_origin = get_origin(annotation_type)

if annotation_origin is Required:
required_keys.add(key)
Expand Down
Loading