Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: draft solution for PEP 563 support #47

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 71 additions & 34 deletions dataconf/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import contextlib
import os
from typing import Any, Optional
Expand All @@ -19,11 +20,40 @@
YAML = 2


def inject_callee_scope(func):
def inner(*args, **kwargs):
noglobals = "globalns" not in kwargs
nolocals = "localns" not in kwargs

if noglobals or nolocals:
frame = inspect.stack()[1][0]

if noglobals:
kwargs["globalns"] = frame.f_globals
if nolocals:
kwargs["localns"] = frame.f_locals

return func(
*args,
**kwargs,
)

return inner


@inject_callee_scope
def parse(
conf: ConfigTree, clazz, strict: bool = True, ignore_unexpected: bool = False
conf: ConfigTree,
clazz,
strict: bool = True,
ignore_unexpected: bool = False,
globalns=None,
localns=None,
):
try:
return utils.__parse(conf, clazz, "", strict, ignore_unexpected)
return utils.__parse(
conf, clazz, "", strict, ignore_unexpected, globalns, localns
)
except pyparsing.ParseSyntaxException as e:
raise MalformedConfigException(
f'parsing failure line {e.lineno} character {e.col}, got "{e.line}"'
Expand All @@ -39,93 +69,100 @@ def cli_parse(*args, **kwargs):


class Multi:
def __init__(self, confs: List[ConfigTree], strict: bool = True, **kwargs) -> None:
def __init__(self, confs: List[ConfigTree], strict: bool = True) -> None:
self.confs = confs
self.strict = strict
self.kwargs = kwargs

def env(self, prefix: str, **kwargs) -> "Multi":
self.strict = False
def env(self, prefix: str) -> "Multi":
data = env_vars_parse(prefix, os.environ)
return self.dict(data, **kwargs)
return Multi(self.confs, strict=False).dict(data)

def dict(self, obj: Dict[str, Any], **kwargs) -> "Multi":
def dict(self, obj: str) -> "Multi":
conf = ConfigFactory.from_dict(obj)
return Multi(self.confs + [conf], self.strict, **kwargs)
return Multi(self.confs + [conf], self.strict)

def string(self, s: str, loader: str = HOCON, **kwargs) -> "Multi":
def string(self, s: str, loader: str = HOCON) -> "Multi":
if loader == YAML:
data = safe_load(s)
return self.dict(data, **kwargs)
return self.dict(data)

conf = ConfigFactory.parse_string(s)
return Multi(self.confs + [conf], self.strict, **kwargs)
return Multi(self.confs + [conf], self.strict)

def url(self, uri: str, timeout: int = 10, **kwargs) -> "Multi":
def url(self, uri: str, timeout: int = 10) -> "Multi":
path = urlparse(uri).path
if path.endswith(".yaml") or path.endswith(".yml"):
with contextlib.closing(urlopen(uri, timeout=timeout)) as fd:
s = fd.read().decode("utf-8")
return self.string(s, loader=YAML, **kwargs)
return self.string(s, loader=YAML)

conf = ConfigFactory.parse_URL(uri, timeout=timeout, required=True)
return Multi(self.confs + [conf], self.strict, **kwargs)
return Multi(self.confs + [conf], self.strict)

def file(self, path: str, loader: Optional[str] = None, **kwargs) -> "Multi":
def file(self, path: str, loader: Optional[str] = None) -> "Multi":
if loader == YAML or (
loader is None and (path.endswith(".yaml") or path.endswith(".yml"))
):
with open(path, "r") as f:
data = safe_load(f)
return self.dict(data, **kwargs)
return self.dict(data)

conf = ConfigFactory.parse_file(path)
return Multi(self.confs + [conf], self.strict, **kwargs)
return Multi(self.confs + [conf], self.strict)

def cli(self, argv: List[str], **kwargs) -> "Multi":
def cli(self, argv: List[str]) -> "Multi":
data = cli_parse(argv)
return self.dict(data, **kwargs)
return self.dict(data)

def on(self, clazz: Type):
@inject_callee_scope
def on(self, clazz: Type, **kwargs):
conf, *nxts = self.confs
for nxt in nxts:
conf = ConfigTree.merge_configs(conf, nxt)
return parse(conf, clazz, self.strict, **self.kwargs)
return parse(conf, clazz, self.strict, **kwargs)


multi = Multi([])


@inject_callee_scope
def env(prefix: str, clazz: Type, **kwargs):
return multi.env(prefix, **kwargs).on(clazz)
return multi.env(prefix).on(clazz, **kwargs)


@inject_callee_scope
def dict(obj: Dict[str, Any], clazz: Type, **kwargs):
return multi.dict(obj, **kwargs).on(clazz)
return multi.dict(obj).on(clazz, **kwargs)


def string(s: str, clazz: Type, **kwargs):
return multi.string(s, **kwargs).on(clazz)
@inject_callee_scope
def string(s: str, clazz: Type, loader: str = HOCON, **kwargs):
return multi.string(s, loader).on(clazz, **kwargs)


@inject_callee_scope
def url(uri: str, clazz: Type, **kwargs):
return multi.url(uri, **kwargs).on(clazz)
return multi.url(uri).on(clazz, **kwargs)


def file(path: str, clazz: Type, **kwargs):
return multi.file(path, **kwargs).on(clazz)
@inject_callee_scope
def file(path: str, clazz: Type, loader: Optional[str] = None, **kwargs):
return multi.file(path, loader).on(clazz, **kwargs)


@inject_callee_scope
def cli(argv: List[str], clazz: Type, **kwargs):
return multi.cli(argv, **kwargs).on(clazz)
return multi.cli(argv).on(clazz, **kwargs)


def load(path: str, clazz: Type, **kwargs):
return file(path, clazz, **kwargs)
@inject_callee_scope
def load(path: str, clazz: Type, loader: Optional[str] = None, **kwargs):
return file(path, clazz, loader, **kwargs)


def loads(s: str, clazz: Type, **kwargs):
return string(s, clazz, **kwargs)
@inject_callee_scope
def loads(s: str, clazz: Type, loader: str = HOCON, **kwargs):
return string(s, clazz, loader, **kwargs)


def dump(file: str, instance: object, out: str):
Expand Down
70 changes: 59 additions & 11 deletions dataconf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Dict
from typing import get_args
from typing import get_origin
from typing import get_type_hints
from typing import List
from typing import Type
from typing import Union
Expand Down Expand Up @@ -58,7 +59,15 @@ def is_optional(type: Type):
return is_union(get_origin(type)) and NoneType in get_args(type)


def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected: bool):
def __parse(
value: any,
clazz: Type,
path: str,
strict: bool,
ignore_unexpected: bool,
globalns,
localns,
):
if is_dataclass(clazz):
if not isinstance(value, ConfigTree):
raise TypeConfigException(
Expand All @@ -68,6 +77,7 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:
fs = {}
renamings = dict()

type_hints = get_type_hints(clazz, globalns, localns)
for f in fields(clazz):
if f.name in value:
val = value[f.name]
Expand All @@ -85,10 +95,16 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:

if not isinstance(val, _MISSING_TYPE):
fs[f.name] = __parse(
val, f.type, f"{path}.{f.name}", strict, ignore_unexpected
val,
type_hints[f.name],
f"{path}.{f.name}",
strict,
ignore_unexpected,
globalns,
localns,
)

elif is_optional(f.type):
elif is_optional(type_hints[f.name]):
# Optional not found
fs[f.name] = None

Expand All @@ -115,7 +131,15 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:
if value is not None:
parse_candidate = args[0]
return [
__parse(v, parse_candidate, f"{path}[]", strict, ignore_unexpected)
__parse(
v,
args[0],
f"{path}[]",
strict,
ignore_unexpected,
globalns,
localns,
)
for v in value
]
return None
Expand All @@ -129,7 +153,15 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:
# ignore key type
parse_candidate = args[1]
return {
k: __parse(v, parse_candidate, f"{path}.{k}", strict, ignore_unexpected)
k: __parse(
v,
args[1],
f"{path}.{k}",
strict,
ignore_unexpected,
globalns,
localns,
)
for k, v in value.items()
}
return None
Expand All @@ -143,7 +175,13 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:
else:
try:
return __parse(
value, parse_candidate, path, strict, ignore_unexpected
value,
parse_candidate,
path,
strict,
ignore_unexpected,
globalns,
localns,
)
except TypeConfigException:
continue
Expand All @@ -156,23 +194,25 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:
)

if clazz is bool:
if not strict and value is not None:
if not strict and isinstance(value, str):
try:
value = bool(value)
except ValueError:
pass
return __parse_type(value, clazz, path, isinstance(value, bool))

if clazz is int:
if not strict and value is not None:
if not strict and isinstance(value, str):
try:
value = int(value)
cast = int(value)
if float(cast) == float(value):
value = cast
except ValueError:
pass
return __parse_type(value, clazz, path, isinstance(value, int))

if clazz is float:
if not strict and value is not None:
if not strict and isinstance(value, str):
try:
value = float(value)
except ValueError:
Expand Down Expand Up @@ -224,7 +264,15 @@ def __parse(value: any, clazz: Type, path: str, strict: bool, ignore_unexpected:
child_successes.append(
(
child_clazz,
__parse(value, child_clazz, path, strict, ignore_unexpected),
__parse(
value,
child_clazz,
path,
strict,
ignore_unexpected,
globalns,
localns,
),
)
)
except (
Expand Down
35 changes: 35 additions & 0 deletions tests/test_futur_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from dataclasses import dataclass
import os
from typing import get_type_hints
from typing import Text

import dataconf
from dataconf.main import inject_callee_scope


@inject_callee_scope
def out_of_scope_assert(clazz, expected, globalns, localns):
assert get_type_hints(clazz, globalns, localns)["a"] is expected


class TestFuturAnnotations:
def test_43(self) -> None:
@dataclass
class Model:
token: str

os.environ["TEST_token"] = "1"
dataconf.env("TEST_", Model)

def test_repro(self) -> None:
@dataclass
class A:
value: Text

@dataclass
class B:
a: A

out_of_scope_assert(B, A, globalns={})