Skip to content

Commit

Permalink
Fix annotation evaluation on inherited dataclasses
Browse files Browse the repository at this point in the history
The `_get_fields()` function evaluated annotations in the context of
the given dataclass, but each annotation should be evaluated in the
context of its surrounding class.
  • Loading branch information
mthuurne committed Jun 27, 2023
1 parent 3f54912 commit 0558c3d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def _get_fields(cls: type) -> Iterator[tuple[str, type]]:

fields_by_name = {field.name: field for field in fields(cls)}

# Note: getmodule() can return None, but the end result is still fine.
cls_globals = getattr(getmodule(cls), "__dict__", {})
cls_locals = vars(cls)

for field_container in reversed(cls.__mro__):
# Note: getmodule() can return None, but the end result is still fine.
cls_globals = getattr(getmodule(field_container), "__dict__", {})
cls_locals = vars(field_container)

for name, annotation in get_annotations(field_container).items():
field = fields_by_name[name]
if not field.init:
Expand Down
5 changes: 4 additions & 1 deletion tests/example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TypeAlias

URL: TypeAlias = str


@dataclass(frozen=True)
class Config:
"""Configuration for an example service."""

database_url: str
database_url: URL
"""The URL of the database to connect to."""

port: int = 12345
Expand Down
12 changes: 6 additions & 6 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,21 @@ def test_bind_inheritance() -> None:
"""A dataclass inheriting from another dataclass accepts fields from both the base and the subclass."""

@dataclass(frozen=True)
class ExtendedConfig(Config):
class ExtendedConfig(example.Config):
"""Inheriting from a class in another module complicates the annotation evaluation."""

dry_run: bool = False

with stream_text(
"""
rest-api-port = 6000
feed-job-prefixes = ["MIX1:", "MIX2:", "MIX3:"]
database-url = "postgresql://smaug:gold@mountain/hoard"
dry-run = true
"""
) as stream:
config = Binder(ExtendedConfig).parse_toml(stream)

assert config.rest_api_port == 6000
assert config.feed_job_prefixes == ("MIX1:", "MIX2:", "MIX3:")
assert config.import_max_nr_hours == 24
assert config.database_url == "postgresql://smaug:gold@mountain/hoard"
assert config.port == 12345
assert config.dry_run is True


Expand Down

0 comments on commit 0558c3d

Please sign in to comment.