Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replace_subgroups function #215

Merged
merged 31 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
667affb
add feature replace_subgroups
zhiruiluo Feb 3, 2023
61c9b5f
add test
zhiruiluo Feb 3, 2023
6f67d8f
merge master
zhiruiluo Feb 4, 2023
70dbe42
fix replace_subgroups
zhiruiluo Feb 4, 2023
3520cd7
add comment
zhiruiluo Feb 4, 2023
f42b0ad
reduce if else
zhiruiluo Feb 4, 2023
49dc575
update replace_subgroups
zhiruiluo Feb 5, 2023
12bae9a
rename replace_selections
zhiruiluo Feb 5, 2023
c0e88b2
fix doctest unflatten_keyword
zhiruiluo Feb 5, 2023
e437433
add doctest for replace_selections
zhiruiluo Feb 5, 2023
d90fe36
fix docstring unflatten_keyword
zhiruiluo Feb 5, 2023
cd38e0c
add test for invalid case
zhiruiluo Feb 5, 2023
b7b6405
modify docstring
zhiruiluo Feb 5, 2023
12c1d2a
Merge branch 'master' into add_replace_subgroups
zhiruiluo Mar 27, 2023
14cf1a0
keep only replace_subgroups function and move it into replace.py
zhiruiluo Mar 27, 2023
25eb328
remove @overload for replace_subgroups
zhiruiluo Mar 27, 2023
dd3d2ed
apply pre-commit check
zhiruiluo Mar 27, 2023
1acfe43
Merge branch 'master' into add_replace_subgroups
zhiruiluo Apr 3, 2023
c71c606
Update simple_parsing/replace.py
zhiruiluo Apr 5, 2023
2477fec
Update simple_parsing/replace.py
zhiruiluo Apr 5, 2023
c9a68eb
address some issues
zhiruiluo Apr 5, 2023
5973d05
rename some variables
zhiruiluo Apr 5, 2023
37814fa
rm replace_subgroups.py
zhiruiluo Apr 5, 2023
74b52a2
Fix unflatten_selection_dict
zhiruiluo Apr 6, 2023
d0b970e
Merge branch 'master' into add_replace_subgroups
zhiruiluo Apr 19, 2023
22930d8
Merge branch 'master' into add_replace_subgroups
zhiruiluo Jun 1, 2023
4b24983
change as per suggestion
zhiruiluo Jun 1, 2023
7f6ead6
Merge branch 'master' into add_replace_subgroups
zhiruiluo Jun 5, 2023
591a4a7
Merge branch 'master' into add_replace_subgroups
zhiruiluo Jul 1, 2023
d22d9cd
Merge branch 'master' into add_replace_subgroups
zhiruiluo Jul 10, 2023
41e9419
Merge branch 'master' into add_replace_subgroups
lebrice Jul 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion simple_parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
parse,
parse_known_args,
)
from .replace import replace
from .replace import replace, replace_subgroups
from .utils import InconsistentArgumentError

__all__ = [
Expand All @@ -49,6 +49,7 @@
"ParsingError",
"Partial",
"replace",
"replace_subgroups",
"Serializable",
"SimpleHelpFormatter",
"subgroups",
Expand Down
147 changes: 144 additions & 3 deletions simple_parsing/replace.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
from __future__ import annotations

import copy
import dataclasses
from typing import Any, overload

from simple_parsing.utils import DataclassT, is_dataclass_instance, unflatten_split
import logging
from typing import Any, Mapping, overload

from simple_parsing.annotation_utils.get_field_annotations import (
get_field_type_from_annotations,
)
from simple_parsing.helpers.subgroups import Key
from simple_parsing.utils import (
DataclassT,
PossiblyNestedDict,
V,
contains_dataclass_type_arg,
is_dataclass_instance,
is_dataclass_type,
is_optional,
unflatten_split,
)

logger = logging.getLogger(__name__)


@overload
Expand Down Expand Up @@ -90,3 +107,127 @@ def replace(obj: DataclassT, changes_dict: dict[str, Any] | None = None, **chang
replace_kwargs.update(changes)

return dataclasses.replace(obj, **replace_kwargs)


def replace_subgroups(
obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None
) -> DataclassT:
"""
This function replaces the dataclass of subgroups, union, and optional union.
The `selections` dict can be in flat format or in nested format.

The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance.
"""
keyword = "__key__"

if not selections:
return obj
selections = _unflatten_selection_dict(selections, keyword, recursive=False)

replace_kwargs = {}
for field in dataclasses.fields(obj):
if not field.init:
raise ValueError(f"Cannot replace value of non-init field {field.name}.")

if field.name not in selections:
continue

field_value = getattr(obj, field.name)
field_annotation = get_field_type_from_annotations(obj.__class__, field.name)

new_value = None
# Replace subgroup is allowed when the type annotation contains dataclass
if not contains_dataclass_type_arg(field_annotation):
raise ValueError(
f"The replaced subgroups contains no dataclass in its annotation {field_annotation}"
)

selection = selections.pop(field.name)
if isinstance(selection, dict):
value_of_selection = selection.pop(keyword, None)
child_selections = selection
else:
value_of_selection = selection
child_selections = None

if is_dataclass_type(value_of_selection):
field_value = value_of_selection()
elif is_dataclass_instance(value_of_selection):
field_value = copy.deepcopy(value_of_selection)
elif field.metadata.get("subgroups", None):
assert isinstance(value_of_selection, str)
subgroup_selection = field.metadata["subgroups"][value_of_selection]
if is_dataclass_instance(subgroup_selection):
# when the subgroup selection is a frozen dataclass instance
field_value = subgroup_selection
else:
# when the subgroup selection is a dataclass type
field_value = field.metadata["subgroups"][value_of_selection]()
elif is_optional(field_annotation) and value_of_selection is None:
field_value = None
elif contains_dataclass_type_arg(field_annotation) and value_of_selection is None:
field_value = field.default_factory()
else:
raise ValueError(
f"invalid selection key '{value_of_selection}' for field '{field.name}'"
)

if child_selections:
new_value = replace_subgroups(field_value, child_selections)
else:
new_value = field_value

replace_kwargs[field.name] = new_value
return dataclasses.replace(obj, **replace_kwargs)


def _unflatten_selection_dict(
flattened: Mapping[str, V], keyword: str = "__key__", sep: str = ".", recursive: bool = True
) -> PossiblyNestedDict[str, V]:
"""
This function convert a flattened dict into a nested dict
and it inserts the `keyword` as the selection into the nested dict.

>>> _unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'})
{'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}}

>>> _unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'})
{'lv1': {'__key__': 'a', 'lv2': {'__key__': 'b', 'lv3': 'c'}}}

>>> _unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}, recursive=False)
{'lv1': {'__key__': 'a', 'lv2': 'b', 'lv2.lv3': 'c'}}

>>> _unflatten_selection_dict({'ab_or_cd.c_or_d': 'd'})
{'ab_or_cd': {'c_or_d': 'd'}}

>>> _unflatten_selection_dict({"a": 1, "b": 2})
{'a': 1, 'b': 2}
"""
dc = {}

unflatten_those_top_level_keys = set()
for k, v in flattened.items():
splited_keys = k.split(sep)
if len(splited_keys) >= 2:
unflatten_those_top_level_keys.add(splited_keys[0])

for k, v in flattened.items():
keys = k.split(sep)
top_level_key = keys[0]
rest_keys = keys[1:]
if top_level_key in unflatten_those_top_level_keys:
sub_dc = dc.get(top_level_key, {})
if len(rest_keys) == 0:
sub_dc[keyword] = v
else:
sub_dc[".".join(rest_keys)] = v
dc[top_level_key] = sub_dc
else:
dc[k] = v

if recursive:
for k in unflatten_those_top_level_keys:
v = dc.pop(k)
unflatten_v = _unflatten_selection_dict(v, recursive=recursive)
dc[k] = unflatten_v
return dc
55 changes: 55 additions & 0 deletions test/test_replace_subgroups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from dataclasses import dataclass, field

from simple_parsing import replace_subgroups, subgroups


@dataclass
class A:
a: float = 0.0


@dataclass
class B:
b: str = "bar"


@dataclass
class AorB:
a_or_b: A | B = subgroups({"a": A, "b": B}, default_factory=A)


@dataclass(frozen=True)
class FrozenConfig:
a: int = 1
b: str = "bob"


odd = FrozenConfig(a=1, b="odd")
even = FrozenConfig(a=2, b="even")


@dataclass
class Config:
subgroup: A | B = subgroups({"a": A, "b": B}, default_factory=A)
frozen_subgroup: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd)
optional: A | None = None
implicit_optional: A = None
union: A | B = field(default_factory=A)
nested_subgroup: AorB = field(default_factory=AorB)


def test_replace_subgroups():
c = Config()
assert replace_subgroups(c, {"subgroup": "b"}) == Config(subgroup=B())
assert replace_subgroups(c, {"frozen_subgroup": "odd"}) == Config(frozen_subgroup=odd)
assert replace_subgroups(c, {"optional": A}) == Config(optional=A())
assert replace_subgroups(c, {"implicit_optional": A}) == Config(implicit_optional=A())
assert replace_subgroups(c, {"union": B}) == Config(union=B())
assert replace_subgroups(c, {"nested_subgroup.a_or_b": "b"}) == Config(
nested_subgroup=AorB(a_or_b=B())
)
assert replace_subgroups(c, {"nested_subgroup": {"a_or_b": "b"}}) == Config(
nested_subgroup=AorB(a_or_b=B())
)