From 667affbbff655cecd412f0818a48b09d2a462cde Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Thu, 2 Feb 2023 17:16:30 -0700 Subject: [PATCH 01/22] add feature replace_subgroups --- simple_parsing/__init__.py | 2 ++ simple_parsing/replace_subgroups.py | 47 +++++++++++++++++++++++++++++ test/test_replace_subgroups.py | 39 ++++++++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 simple_parsing/replace_subgroups.py create mode 100644 test/test_replace_subgroups.py diff --git a/simple_parsing/__init__.py b/simple_parsing/__init__.py index b12be30a..6b406cdb 100644 --- a/simple_parsing/__init__.py +++ b/simple_parsing/__init__.py @@ -24,6 +24,7 @@ parse, parse_known_args, ) +from .replace_subgroups import replace_subgroups from .utils import InconsistentArgumentError __all__ = [ @@ -43,6 +44,7 @@ "parse_known_args", "parse", "ParsingError", + "replace_subgroups", "Serializable", "SimpleHelpFormatter", "subgroups", diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py new file mode 100644 index 00000000..302a574e --- /dev/null +++ b/simple_parsing/replace_subgroups.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import dataclasses +from simple_parsing.utils import is_dataclass_instance, DataclassT, unflatten_split +from typing import Any, overload +from simple_parsing.helpers.subgroups import Key + + +@overload +def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any], subgroup_changes: dict[str, Key] | None = None) -> DataclassT: + ... + + +@overload +def replace_subgroups(obj: DataclassT, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: + ... + + +def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = None, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: + if changes_dict and changes: + raise ValueError("Cannot pass both `changes_dict` and `changes`") + changes = changes_dict or changes + # changes can be given in a 'flat' format in `changes_dict`, e.g. {"a.b.c": 123}. + # Unflatten them back to a nested dict (e.g. {"a": {"b": {"c": 123}}}) + changes = unflatten_split(changes) + + replace_kwargs = {} + for field in dataclasses.fields(obj): + if field.name not in changes: + continue + if not field.init: + raise ValueError( + f"Cannot replace value of non-init field {field.name}.") + + field_value = getattr(obj, field.name) + + if is_dataclass_instance(field_value): + if field.metadata.get('subgroups', None) and subgroup_changes is not None: + field_value = field.metadata['subgroups'][subgroup_changes[field.name]]() + field_changes = changes[field.name] + new_value = replace_subgroups(field_value, field_changes, subgroup_changes) + elif isinstance(changes[field.name], dict): + field_changes = changes[field.name] + new_value = replace_subgroups(field_value, field_changes, subgroup_changes) + else: + new_value = changes[field.name] + replace_kwargs[field.name] = new_value + return dataclasses.replace(obj, **replace_kwargs) \ No newline at end of file diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py new file mode 100644 index 00000000..2758f673 --- /dev/null +++ b/test/test_replace_subgroups.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import functools +import logging +from dataclasses import dataclass, field +from typing import Union + +import pytest + +from simple_parsing import replace_subgroups, subgroups + +logger = logging.getLogger(__name__) + + +@dataclass +class A: + a: float = 0.0 + + +@dataclass +class B: + b: str = "bar" + b_post_init: str = field(init=False) + + def __post_init__(self): + self.b_post_init = self.b + "_post" + + +@dataclass +class AB: + a_or_b: A | B = subgroups({"A": A, "B": B},default_factory=A) + + +def test_replace_subgroups(): + src_config = AB() + dest_config = AB(a_or_b=B(b='bob')) + + assert replace_subgroups(src_config, {'a_or_b.b': "bob"}, subgroup_changes={'a_or_b': 'B'}) == dest_config + \ No newline at end of file From 61c9b5ff809751352ed81f42a743d200febfa9b6 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Thu, 2 Feb 2023 21:04:54 -0700 Subject: [PATCH 02/22] add test --- simple_parsing/replace_subgroups.py | 9 ++-- test/test_replace_subgroups.py | 82 +++++++++++++++++++++++++---- 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py index 302a574e..2d99d94b 100644 --- a/simple_parsing/replace_subgroups.py +++ b/simple_parsing/replace_subgroups.py @@ -22,6 +22,8 @@ def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = Non # changes can be given in a 'flat' format in `changes_dict`, e.g. {"a.b.c": 123}. # Unflatten them back to a nested dict (e.g. {"a": {"b": {"c": 123}}}) changes = unflatten_split(changes) + if subgroup_changes: + subgroup_changes = unflatten_split(subgroup_changes) replace_kwargs = {} for field in dataclasses.fields(obj): @@ -36,11 +38,12 @@ def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = Non if is_dataclass_instance(field_value): if field.metadata.get('subgroups', None) and subgroup_changes is not None: field_value = field.metadata['subgroups'][subgroup_changes[field.name]]() + + if isinstance(changes[field.name], dict): field_changes = changes[field.name] new_value = replace_subgroups(field_value, field_changes, subgroup_changes) - elif isinstance(changes[field.name], dict): - field_changes = changes[field.name] - new_value = replace_subgroups(field_value, field_changes, subgroup_changes) + else: + new_value = changes[field.name] else: new_value = changes[field.name] replace_kwargs[field.name] = new_value diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py index 2758f673..5a107b91 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_subgroups.py @@ -13,12 +13,12 @@ @dataclass -class A: +class A(): a: float = 0.0 @dataclass -class B: +class B(): b: str = "bar" b_post_init: str = field(init=False) @@ -27,13 +27,75 @@ def __post_init__(self): @dataclass -class AB: - a_or_b: A | B = subgroups({"A": A, "B": B},default_factory=A) - +class AB(): + integer_only_by_post_init: int = field(init=False) + integer_in_string: str = "1" + a_or_b: A | B = subgroups( + { + "a": A, + "a_1.23": functools.partial(A, a=1.23), + "b": B, + "b_bob": functools.partial(B, b="bob"), + }, + default="a", + ) + + def __post_init__(self): + self.integer_only_by_post_init = int(self.integer_in_string) + + +@dataclass +class C: + c: bool = False + + +@dataclass +class D: + d: int = 0 + + +@dataclass +class CD: + c_or_d: C | D = subgroups({"c": C, "d": D}, default="c") + + other_arg: str = "bob" + + +@dataclass +class NestedSubgroupsConfig: + ab_or_cd: AB | CD = subgroups( + {"ab": AB, "cd": CD}, + default_factory=AB, + ) -def test_replace_subgroups(): - src_config = AB() - dest_config = AB(a_or_b=B(b='bob')) - assert replace_subgroups(src_config, {'a_or_b.b': "bob"}, subgroup_changes={'a_or_b': 'B'}) == dest_config - \ No newline at end of file +@pytest.mark.parametrize( + ("dest_config", "src_config", "changes_dict", "subgroup_changes"), + [ + (AB(a_or_b=A(a=1.0)), AB(), {"a_or_b": {"a": 1.0}}, None), + (AB(a_or_b=B(b="foo")), AB(a_or_b=B()), {"a_or_b": {"b": "foo"}}, None), + (AB(a_or_b=B(b="bob")), AB(), {"a_or_b": B(b="bob")}, None), + ( + NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), + NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), + {"ab_or_cd": {"integer_in_string": "2", "a_or_b": {"b": "bob"}}}, + None + ), + ( + NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), + NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), + {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, + None + ), + ( + NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), + NestedSubgroupsConfig(), + {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, + {"ab_or_cd": 'ab', "ab_or_cd.a_or_b": 'b'}, + ), + ], +) +def test_replace_nested_subgroups(dest_config: object, src_config: object, changes_dict: dict, subgroup_changes: dict): + config_replaced = replace_subgroups(src_config, changes_dict, subgroup_changes) + assert config_replaced == dest_config + From 70dbe42446c362d6ee68bfd3fec9a800b04f3795 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 00:28:50 -0700 Subject: [PATCH 03/22] fix replace_subgroups --- simple_parsing/replace_subgroups.py | 68 +++++++++++++++++++-------- test/test_replace_subgroups.py | 73 ++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 35 deletions(-) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py index 2d99d94b..42223357 100644 --- a/simple_parsing/replace_subgroups.py +++ b/simple_parsing/replace_subgroups.py @@ -5,6 +5,17 @@ from simple_parsing.helpers.subgroups import Key +def unflatten_subgroup_changes(subgroup_changes: dict[str, Key]): + dc = {} + for k, v in subgroup_changes.items(): + if '__key__' != k and '.' not in k: + dc[k+'.__key__'] = v + else: + dc[k] = v + + return unflatten_split(dc) + + @overload def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any], subgroup_changes: dict[str, Key] | None = None) -> DataclassT: ... @@ -22,29 +33,48 @@ def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = Non # changes can be given in a 'flat' format in `changes_dict`, e.g. {"a.b.c": 123}. # Unflatten them back to a nested dict (e.g. {"a": {"b": {"c": 123}}}) changes = unflatten_split(changes) - if subgroup_changes: - subgroup_changes = unflatten_split(subgroup_changes) + if subgroup_changes: + subgroup_changes = unflatten_subgroup_changes(subgroup_changes) + replace_kwargs = {} for field in dataclasses.fields(obj): - if field.name not in changes: - continue - if not field.init: - raise ValueError( - f"Cannot replace value of non-init field {field.name}.") + if field.name in changes: + if not field.init: + raise ValueError(f"Cannot replace value of non-init field {field.name}.") - field_value = getattr(obj, field.name) + field_value = getattr(obj, field.name) - if is_dataclass_instance(field_value): - if field.metadata.get('subgroups', None) and subgroup_changes is not None: - field_value = field.metadata['subgroups'][subgroup_changes[field.name]]() - - if isinstance(changes[field.name], dict): - field_changes = changes[field.name] - new_value = replace_subgroups(field_value, field_changes, subgroup_changes) + if is_dataclass_instance(field_value) and isinstance(changes[field.name], dict): + field_changes = changes.pop(field.name) + sub_subgroup_changes = subgroup_changes + if subgroup_changes and field.name in subgroup_changes: + sub_subgroup_changes = subgroup_changes.pop(field.name) + key = sub_subgroup_changes.pop("__key__", None) + if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: + field_value = field.metadata['subgroups'][key]() + new_value = replace_subgroups(field_value, field_changes, sub_subgroup_changes) + else: + new_value = changes.pop(field.name) + replace_kwargs[field.name] = new_value + elif subgroup_changes and field.name in subgroup_changes: + if not field.init: + raise ValueError(f"Cannot replace value of non-init field {field.name}.") + + sub_subgroup_changes = subgroup_changes.pop(field.name) + key = sub_subgroup_changes.pop("__key__", None) + if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: + field_value = field.metadata['subgroups'][key]() + new_value = replace_subgroups(field_value, None, sub_subgroup_changes) else: - new_value = changes[field.name] + field_value = getattr(obj, field.name) + new_value = replace_subgroups(field_value, None, sub_subgroup_changes) + replace_kwargs[field.name] = new_value else: - new_value = changes[field.name] - replace_kwargs[field.name] = new_value - return dataclasses.replace(obj, **replace_kwargs) \ No newline at end of file + continue + + # note: there may be some leftover values in `changes` that are not fields of this dataclass. + # we still pass those. + replace_kwargs.update(changes) + + return dataclasses.replace(obj, **replace_kwargs) diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py index 5a107b91..9159d405 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_subgroups.py @@ -3,11 +3,11 @@ import functools import logging from dataclasses import dataclass, field -from typing import Union import pytest from simple_parsing import replace_subgroups, subgroups +from simple_parsing.utils import Dataclass, DataclassT logger = logging.getLogger(__name__) @@ -26,6 +26,11 @@ def __post_init__(self): self.b_post_init = self.b + "_post" +@dataclass +class WithOptional: + optional_a: A | None = None + + @dataclass class AB(): integer_only_by_post_init: int = field(init=False) @@ -67,35 +72,71 @@ class NestedSubgroupsConfig: {"ab": AB, "cd": CD}, default_factory=AB, ) - - + + @pytest.mark.parametrize( - ("dest_config", "src_config", "changes_dict", "subgroup_changes"), + ("start", "changes", "subgroup_changes", "expected"), [ - (AB(a_or_b=A(a=1.0)), AB(), {"a_or_b": {"a": 1.0}}, None), - (AB(a_or_b=B(b="foo")), AB(a_or_b=B()), {"a_or_b": {"b": "foo"}}, None), - (AB(a_or_b=B(b="bob")), AB(), {"a_or_b": B(b="bob")}, None), ( - NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), + AB(), + {"a_or_b": {"a": 1.0}}, + None, + AB(a_or_b=A(a=1.0)), + ), + ( + AB(a_or_b=B()), + {"a_or_b": {"b": "foo"}}, + None, + AB(a_or_b=B(b="foo")) + ), + ( + AB(), + {"a_or_b": {"b": "foo"}}, + {"a_or_b": "b"}, + AB(a_or_b=B(b="foo")) + ), + ( + AB(), + {"a_or_b": B(b="bob")}, + None, + AB(a_or_b=B(b="bob")), + ), + ( NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), {"ab_or_cd": {"integer_in_string": "2", "a_or_b": {"b": "bob"}}}, - None + None, + NestedSubgroupsConfig(ab_or_cd=AB( + integer_in_string="2", a_or_b=B(b="bob"))), ), ( - NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, - None + None, + NestedSubgroupsConfig(ab_or_cd=AB( + integer_in_string="2", a_or_b=B(b="bob"))), ), ( - NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), NestedSubgroupsConfig(), {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, {"ab_or_cd": 'ab', "ab_or_cd.a_or_b": 'b'}, + NestedSubgroupsConfig(ab_or_cd=AB( + integer_in_string="2", a_or_b=B(b="bob"))), + ), + ( + NestedSubgroupsConfig(), + None, + {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, + NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())), + ), + ( + NestedSubgroupsConfig(), + {"ab_or_cd.c_or_d.d": 1}, + {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, + NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D(d=1))), ), ], ) -def test_replace_nested_subgroups(dest_config: object, src_config: object, changes_dict: dict, subgroup_changes: dict): - config_replaced = replace_subgroups(src_config, changes_dict, subgroup_changes) - assert config_replaced == dest_config - +def test_replace_subgroups(start: DataclassT, changes: dict, subgroup_changes: dict, expected: DataclassT): + actual = replace_subgroups( + start, changes, subgroup_changes) + assert actual == expected From 3520cd7ee2b49e2f169cccfe3f116cbf92131ae7 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 00:35:46 -0700 Subject: [PATCH 04/22] add comment --- simple_parsing/replace_subgroups.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py index 42223357..10e07675 100644 --- a/simple_parsing/replace_subgroups.py +++ b/simple_parsing/replace_subgroups.py @@ -6,6 +6,10 @@ def unflatten_subgroup_changes(subgroup_changes: dict[str, Key]): + """ + This function convert subgroup_changes = {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"} + into {"ab_or_cd": {"__key__": "cd", "c_or_d": {"__key__": "d"}}} + """ dc = {} for k, v in subgroup_changes.items(): if '__key__' != k and '.' not in k: @@ -35,6 +39,7 @@ def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = Non changes = unflatten_split(changes) if subgroup_changes: + # subgroup_changes is in a flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} subgroup_changes = unflatten_subgroup_changes(subgroup_changes) replace_kwargs = {} From f42b0adcb639d626ff6addff925c9dcce2a4c1f4 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 00:41:38 -0700 Subject: [PATCH 05/22] reduce if else --- simple_parsing/replace_subgroups.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py index 10e07675..ddade0aa 100644 --- a/simple_parsing/replace_subgroups.py +++ b/simple_parsing/replace_subgroups.py @@ -75,8 +75,6 @@ def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = Non field_value = getattr(obj, field.name) new_value = replace_subgroups(field_value, None, sub_subgroup_changes) replace_kwargs[field.name] = new_value - else: - continue # note: there may be some leftover values in `changes` that are not fields of this dataclass. # we still pass those. From 49dc575bae087f2cefa2c6c0a9813a3a4d030ac1 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 17:32:07 -0700 Subject: [PATCH 06/22] update replace_subgroups --- simple_parsing/replace_subgroups.py | 140 +++++++++++++++++++--------- test/test_replace_subgroups.py | 93 +++++++++++++++++- 2 files changed, 187 insertions(+), 46 deletions(-) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py index ddade0aa..10dfb52e 100644 --- a/simple_parsing/replace_subgroups.py +++ b/simple_parsing/replace_subgroups.py @@ -1,9 +1,15 @@ from __future__ import annotations import dataclasses -from simple_parsing.utils import is_dataclass_instance, DataclassT, unflatten_split -from typing import Any, overload +from simple_parsing.utils import is_dataclass_instance, DataclassT, unflatten_split, contains_dataclass_type_arg, is_dataclass_type, is_union, is_optional +from typing import Any, overload, Union, Tuple from simple_parsing.helpers.subgroups import Key +from enum import Enum +from simple_parsing.annotation_utils.get_field_annotations import get_field_type_from_annotations +from simple_parsing.replace import replace +import logging +import copy +logger = logging.getLogger(__name__) def unflatten_subgroup_changes(subgroup_changes: dict[str, Key]): """ @@ -29,55 +35,99 @@ def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any], subgroup_c def replace_subgroups(obj: DataclassT, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: ... + +def replace_union_dataclasses(obj: DataclassT, selections: dict[str, Key|DataclassT] | None = None): + if selections: + selections = unflatten_subgroup_changes(selections) + else: + return obj + + replace_kwargs = {} + for field in dataclasses.fields(obj): + if field.name not in selections: + continue + + field_value = getattr(obj, field.name) + t = get_field_type_from_annotations(obj.__class__, field.name) + if contains_dataclass_type_arg(t) and is_union(t): + child_selections = selections.pop(field.name) + key = child_selections.pop("__key__", None) + logger.debug(key) + if is_dataclass_type(key): + field_value = key() + logger.debug('is_dataclass_type') + elif is_dataclass_instance(key): + field_value = copy.deepcopy(key) + logger.debug('is_dataclass_instance') + elif field.metadata.get("subgroups", None): + field_value = field.metadata["subgroups"][key]() + logger.debug('is_subgroups') + elif is_optional(t) and key is None: + field_value = None + logger.debug("key is None") + else: + logger.debug('default') + if child_selections: + new_value = replace_union_dataclasses(field_value, child_selections) + else: + new_value = field_value + replace_kwargs[field.name] = new_value + + return dataclasses.replace(obj, **replace_kwargs) def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = None, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: - if changes_dict and changes: - raise ValueError("Cannot pass both `changes_dict` and `changes`") - changes = changes_dict or changes - # changes can be given in a 'flat' format in `changes_dict`, e.g. {"a.b.c": 123}. - # Unflatten them back to a nested dict (e.g. {"a": {"b": {"c": 123}}}) - changes = unflatten_split(changes) - if subgroup_changes: - # subgroup_changes is in a flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} - subgroup_changes = unflatten_subgroup_changes(subgroup_changes) + obj = replace_union_dataclasses(obj, subgroup_changes) + return replace(obj, changes_dict, **changes) + +# def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = None, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: +# if changes_dict and changes: +# raise ValueError("Cannot pass both `changes_dict` and `changes`") +# changes = changes_dict or changes +# # changes can be given in a 'flat' format in `changes_dict`, e.g. {"a.b.c": 123}. +# # Unflatten them back to a nested dict (e.g. {"a": {"b": {"c": 123}}}) +# changes = unflatten_split(changes) + +# if subgroup_changes: +# # subgroup_changes is in a flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} +# subgroup_changes = unflatten_subgroup_changes(subgroup_changes) - replace_kwargs = {} - for field in dataclasses.fields(obj): - if field.name in changes: - if not field.init: - raise ValueError(f"Cannot replace value of non-init field {field.name}.") +# replace_kwargs = {} +# for field in dataclasses.fields(obj): +# if field.name in changes: +# if not field.init: +# raise ValueError(f"Cannot replace value of non-init field {field.name}.") - field_value = getattr(obj, field.name) +# field_value = getattr(obj, field.name) - if is_dataclass_instance(field_value) and isinstance(changes[field.name], dict): - field_changes = changes.pop(field.name) - sub_subgroup_changes = subgroup_changes - if subgroup_changes and field.name in subgroup_changes: - sub_subgroup_changes = subgroup_changes.pop(field.name) - key = sub_subgroup_changes.pop("__key__", None) - if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: - field_value = field.metadata['subgroups'][key]() - new_value = replace_subgroups(field_value, field_changes, sub_subgroup_changes) - else: - new_value = changes.pop(field.name) - replace_kwargs[field.name] = new_value - elif subgroup_changes and field.name in subgroup_changes: - if not field.init: - raise ValueError(f"Cannot replace value of non-init field {field.name}.") +# if is_dataclass_instance(field_value) and isinstance(changes[field.name], dict): +# field_changes = changes.pop(field.name) +# sub_subgroup_changes = None +# if subgroup_changes and field.name in subgroup_changes: +# sub_subgroup_changes = subgroup_changes.pop(field.name) +# key = sub_subgroup_changes.pop("__key__", None) +# if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: +# field_value = field.metadata['subgroups'][key]() +# new_value = replace_subgroups(field_value, field_changes, sub_subgroup_changes) +# else: +# new_value = changes.pop(field.name) +# replace_kwargs[field.name] = new_value +# elif subgroup_changes and field.name in subgroup_changes: +# if not field.init: +# raise ValueError(f"Cannot replace value of non-init field {field.name}.") - sub_subgroup_changes = subgroup_changes.pop(field.name) - key = sub_subgroup_changes.pop("__key__", None) - if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: - field_value = field.metadata['subgroups'][key]() - new_value = replace_subgroups(field_value, None, sub_subgroup_changes) - else: - field_value = getattr(obj, field.name) - new_value = replace_subgroups(field_value, None, sub_subgroup_changes) - replace_kwargs[field.name] = new_value +# sub_subgroup_changes = subgroup_changes.pop(field.name) +# key = sub_subgroup_changes.pop("__key__", None) +# if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: +# field_value = field.metadata['subgroups'][key]() +# new_value = replace_subgroups(field_value, None, sub_subgroup_changes) +# else: +# field_value = getattr(obj, field.name) +# new_value = replace_subgroups(field_value, None, sub_subgroup_changes) +# replace_kwargs[field.name] = new_value - # note: there may be some leftover values in `changes` that are not fields of this dataclass. - # we still pass those. - replace_kwargs.update(changes) +# # note: there may be some leftover values in `changes` that are not fields of this dataclass. +# # we still pass those. +# replace_kwargs.update(changes) - return dataclasses.replace(obj, **replace_kwargs) +# return dataclasses.replace(obj, **replace_kwargs) diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py index 9159d405..f6557296 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_subgroups.py @@ -2,12 +2,17 @@ import functools import logging -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields import pytest from simple_parsing import replace_subgroups, subgroups +from simple_parsing.replace_subgroups import replace_union_dataclasses from simple_parsing.utils import Dataclass, DataclassT +from enum import Enum +from pathlib import Path +from typing import Tuple +from simple_parsing.helpers.subgroups import Key logger = logging.getLogger(__name__) @@ -73,6 +78,32 @@ class NestedSubgroupsConfig: default_factory=AB, ) +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + +@dataclass +class AllTypes: + arg_int : int = 0 + arg_float: float = 1.0 + arg_str: str = 'foo' + arg_list: list = field(default_factory= lambda: [1,2]) + arg_dict : dict = field(default_factory=lambda : {"a":1 , "b": 2}) + arg_union: str| Path = './' + arg_tuple: Tuple[int, int] = (1,1) + arg_enum: Color = Color.BLUE + arg_dataclass: A = field(default_factory=A) + arg_subgroups: A | B = subgroups( + { + "a": A, + "a_1.23": functools.partial(A, a=1.23), + "b": B, + "b_bob": functools.partial(B, b="bob"), + }, + default="a", + ) + arg_optional: A | None = None @pytest.mark.parametrize( ("start", "changes", "subgroup_changes", "expected"), @@ -134,9 +165,69 @@ class NestedSubgroupsConfig: {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D(d=1))), ), + ( + AllTypes(), + {"arg_subgroups.b": "foo","arg_optional.a": 1.0}, + {"arg_subgroups": "b", "arg_optional": A}, + AllTypes(arg_subgroups=B(b="foo"), arg_optional=A(a=1.0)) + ), ], ) def test_replace_subgroups(start: DataclassT, changes: dict, subgroup_changes: dict, expected: DataclassT): actual = replace_subgroups( start, changes, subgroup_changes) assert actual == expected + + +@pytest.mark.parametrize( + ('start', 'changes', 'expected'), + [ + ( + AllTypes(), + {'arg_subgroups':'b',"arg_optional": A}, + AllTypes(arg_subgroups=B(), arg_optional=A()) + ), + ( + AllTypes(), + {'arg_subgroups': B,"arg_optional": A}, + AllTypes(arg_subgroups=B(), arg_optional=A()) + ), + ( + AllTypes(arg_optional=A()), + {'arg_subgroups': B,"arg_optional": None}, + AllTypes(arg_subgroups=B(), arg_optional=None) + ), + ( + AllTypes(arg_optional=A(a=1.0)), + {"arg_optional": A}, + AllTypes(arg_optional=A()) + ), + ( + AllTypes(arg_optional=None), + {"arg_optional": A(a=1.2)}, + AllTypes(arg_optional=A(a=1.2)) + ), + ( + AllTypes(arg_subgroups=A(a=1.0)), + {'arg_subgroups': 'a'}, + AllTypes(arg_subgroups=A()) + ), + ( + AllTypes(arg_subgroups=A(a=1.0)), + None, + AllTypes(arg_subgroups=A(a=1.0)) + ), + ( + AllTypes(arg_subgroups=A(a=1.0)), + {}, + AllTypes(arg_subgroups=A(a=1.0)) + ), + ( + NestedSubgroupsConfig(), + {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, + NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())) + ), + ] +) +def test_replace_union_dataclasses(start: DataclassT, changes:dict[str, Key|DataclassT], expected: DataclassT): + assert replace_union_dataclasses(start, changes) == expected \ No newline at end of file From 12bae9a06ab29c902d5530b7584ce766c45aa662 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 22:12:15 -0700 Subject: [PATCH 07/22] rename replace_selections --- simple_parsing/__init__.py | 4 +- simple_parsing/replace_selections.py | 92 ++++++++++++ simple_parsing/replace_subgroups.py | 133 ------------------ simple_parsing/utils.py | 24 ++++ ...ubgroups.py => test_replace_selections.py} | 29 ++-- 5 files changed, 136 insertions(+), 146 deletions(-) create mode 100644 simple_parsing/replace_selections.py delete mode 100644 simple_parsing/replace_subgroups.py rename test/{test_replace_subgroups.py => test_replace_selections.py} (88%) diff --git a/simple_parsing/__init__.py b/simple_parsing/__init__.py index f2b30db7..6559a038 100644 --- a/simple_parsing/__init__.py +++ b/simple_parsing/__init__.py @@ -24,7 +24,7 @@ parse, parse_known_args, ) -from .replace_subgroups import replace_subgroups +from .replace_selections import replace_selections from .replace import replace from .utils import InconsistentArgumentError @@ -45,7 +45,7 @@ "parse_known_args", "parse", "ParsingError", - "replace_subgroups", + "replace_selections", "replace", "Serializable", "SimpleHelpFormatter", diff --git a/simple_parsing/replace_selections.py b/simple_parsing/replace_selections.py new file mode 100644 index 00000000..ba664cb9 --- /dev/null +++ b/simple_parsing/replace_selections.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import copy +import dataclasses +import logging +from typing import Any, overload + +from simple_parsing.annotation_utils.get_field_annotations import \ + get_field_type_from_annotations +from simple_parsing.helpers.subgroups import Key +from simple_parsing.replace import replace +from simple_parsing.utils import (DataclassT, contains_dataclass_type_arg, + is_dataclass_instance, is_dataclass_type, + is_optional, is_union, unflatten_keyword) + +logger = logging.getLogger(__name__) + + +@overload +def replace_selections(obj: DataclassT, changes_dict: dict[str, Any], selections: dict[str, Key | DataclassT] | None = None) -> DataclassT: + ... + + +@overload +def replace_selections(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes) -> DataclassT: + ... + + +def replace_selected_dataclass(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None): + """ + This function replaces the dataclass of subgroups, union, and optional union. + The `selections` is in flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} + + The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. + """ + keyword = '__key__' + + if selections: + selections = unflatten_keyword(selections, keyword) + else: + return obj + + replace_kwargs = {} + for field in dataclasses.fields(obj): + if field.name not in selections: + continue + + field_value = getattr(obj, field.name) + t = get_field_type_from_annotations(obj.__class__, field.name) + + if contains_dataclass_type_arg(t) and is_union(t): + child_selections = selections.pop(field.name) + key = child_selections.pop(keyword, None) + + if is_dataclass_type(key): + field_value = key() + logger.debug('is_dataclass_type') + elif is_dataclass_instance(key): + field_value = copy.deepcopy(key) + logger.debug('is_dataclass_instance') + elif field.metadata.get("subgroups", None): + field_value = field.metadata["subgroups"][key]() + logger.debug('is_subgroups') + elif is_optional(t) and key is None: + field_value = None + logger.debug("key is None") + else: + logger.warn('Not Implemented') + raise TypeError(f"Not Implemented for field {field.name}!") + + if child_selections: + new_value = replace_selected_dataclass( + field_value, child_selections) + else: + new_value = field_value + + if not field.init: + raise ValueError( + f"Cannot replace value of non-init field {field.name}.") + + replace_kwargs[field.name] = new_value + return dataclasses.replace(obj, **replace_kwargs) + + +def replace_selections(obj: DataclassT, changes_dict: dict[str, Any] | None = None, selections: dict[str, Key | DataclassT] | None = None, **changes) -> DataclassT: + """Replace some values in a dataclass and replace dataclass type in nested union of dataclasses or subgroups. + + Compared to `simple_replace.replace`, this calls `replace_selected_dataclass` before calling `simple_parsing.replace`. + """ + if selections: + obj = replace_selected_dataclass(obj, selections) + return replace(obj, changes_dict, **changes) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py deleted file mode 100644 index 10dfb52e..00000000 --- a/simple_parsing/replace_subgroups.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import annotations -import dataclasses -from simple_parsing.utils import is_dataclass_instance, DataclassT, unflatten_split, contains_dataclass_type_arg, is_dataclass_type, is_union, is_optional -from typing import Any, overload, Union, Tuple -from simple_parsing.helpers.subgroups import Key -from enum import Enum -from simple_parsing.annotation_utils.get_field_annotations import get_field_type_from_annotations -from simple_parsing.replace import replace -import logging -import copy - -logger = logging.getLogger(__name__) - -def unflatten_subgroup_changes(subgroup_changes: dict[str, Key]): - """ - This function convert subgroup_changes = {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"} - into {"ab_or_cd": {"__key__": "cd", "c_or_d": {"__key__": "d"}}} - """ - dc = {} - for k, v in subgroup_changes.items(): - if '__key__' != k and '.' not in k: - dc[k+'.__key__'] = v - else: - dc[k] = v - - return unflatten_split(dc) - - -@overload -def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any], subgroup_changes: dict[str, Key] | None = None) -> DataclassT: - ... - - -@overload -def replace_subgroups(obj: DataclassT, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: - ... - - -def replace_union_dataclasses(obj: DataclassT, selections: dict[str, Key|DataclassT] | None = None): - if selections: - selections = unflatten_subgroup_changes(selections) - else: - return obj - - replace_kwargs = {} - for field in dataclasses.fields(obj): - if field.name not in selections: - continue - - field_value = getattr(obj, field.name) - t = get_field_type_from_annotations(obj.__class__, field.name) - if contains_dataclass_type_arg(t) and is_union(t): - child_selections = selections.pop(field.name) - key = child_selections.pop("__key__", None) - logger.debug(key) - if is_dataclass_type(key): - field_value = key() - logger.debug('is_dataclass_type') - elif is_dataclass_instance(key): - field_value = copy.deepcopy(key) - logger.debug('is_dataclass_instance') - elif field.metadata.get("subgroups", None): - field_value = field.metadata["subgroups"][key]() - logger.debug('is_subgroups') - elif is_optional(t) and key is None: - field_value = None - logger.debug("key is None") - else: - logger.debug('default') - if child_selections: - new_value = replace_union_dataclasses(field_value, child_selections) - else: - new_value = field_value - replace_kwargs[field.name] = new_value - - return dataclasses.replace(obj, **replace_kwargs) - -def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = None, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: - if subgroup_changes: - obj = replace_union_dataclasses(obj, subgroup_changes) - return replace(obj, changes_dict, **changes) - -# def replace_subgroups(obj: DataclassT, changes_dict: dict[str, Any] | None = None, subgroup_changes: dict[str, Key] | None = None, **changes) -> DataclassT: -# if changes_dict and changes: -# raise ValueError("Cannot pass both `changes_dict` and `changes`") -# changes = changes_dict or changes -# # changes can be given in a 'flat' format in `changes_dict`, e.g. {"a.b.c": 123}. -# # Unflatten them back to a nested dict (e.g. {"a": {"b": {"c": 123}}}) -# changes = unflatten_split(changes) - -# if subgroup_changes: -# # subgroup_changes is in a flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} -# subgroup_changes = unflatten_subgroup_changes(subgroup_changes) - -# replace_kwargs = {} -# for field in dataclasses.fields(obj): -# if field.name in changes: -# if not field.init: -# raise ValueError(f"Cannot replace value of non-init field {field.name}.") - -# field_value = getattr(obj, field.name) - -# if is_dataclass_instance(field_value) and isinstance(changes[field.name], dict): -# field_changes = changes.pop(field.name) -# sub_subgroup_changes = None -# if subgroup_changes and field.name in subgroup_changes: -# sub_subgroup_changes = subgroup_changes.pop(field.name) -# key = sub_subgroup_changes.pop("__key__", None) -# if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: -# field_value = field.metadata['subgroups'][key]() -# new_value = replace_subgroups(field_value, field_changes, sub_subgroup_changes) -# else: -# new_value = changes.pop(field.name) -# replace_kwargs[field.name] = new_value -# elif subgroup_changes and field.name in subgroup_changes: -# if not field.init: -# raise ValueError(f"Cannot replace value of non-init field {field.name}.") - -# sub_subgroup_changes = subgroup_changes.pop(field.name) -# key = sub_subgroup_changes.pop("__key__", None) -# if key and field.metadata.get('subgroups') and key in field.metadata['subgroups']: -# field_value = field.metadata['subgroups'][key]() -# new_value = replace_subgroups(field_value, None, sub_subgroup_changes) -# else: -# field_value = getattr(obj, field.name) -# new_value = replace_subgroups(field_value, None, sub_subgroup_changes) -# replace_kwargs[field.name] = new_value - -# # note: there may be some leftover values in `changes` that are not fields of this dataclass. -# # we still pass those. -# replace_kwargs.update(changes) - -# return dataclasses.replace(obj, **replace_kwargs) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 755194ba..d609b37b 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -938,6 +938,30 @@ def unflatten_split( return unflatten({tuple(key.split(sep)): value for key, value in flattened.items()}) +def unflatten_keyword( + flattened: Mapping[str, V], keyword: str = "__key__", sep='.' +) -> PossiblyNestedDict[str, V]: + """ + This function convert flattened = + into the nested dict + differentiating by the `keyword`. + + >>> unflatten_keyword({"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}) + {"ab_or_cd": {"__key__": "cd", "c_or_d": {"__key__": "d"}}} + + >>> unflatten_keyword({"a": 1, "b": 2}) + {"a": {"__key__": 1}, "b": {"__key__": 2}} + """ + dc = {} + for k, v in flattened.items(): + if keyword != k and sep not in k: + dc[k+sep+keyword] = v + else: + dc[k] = v + + return unflatten_split(dc) + + @overload def getitem_recursive(d: PossiblyNestedDict[K, V], keys: Iterable[K]) -> V: ... diff --git a/test/test_replace_subgroups.py b/test/test_replace_selections.py similarity index 88% rename from test/test_replace_subgroups.py rename to test/test_replace_selections.py index f6557296..7803b7ca 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_selections.py @@ -1,18 +1,19 @@ from __future__ import annotations +import copy import functools import logging from dataclasses import dataclass, field, fields - -import pytest - -from simple_parsing import replace_subgroups, subgroups -from simple_parsing.replace_subgroups import replace_union_dataclasses -from simple_parsing.utils import Dataclass, DataclassT from enum import Enum from pathlib import Path from typing import Tuple + +import pytest + +from simple_parsing import replace_selections, subgroups from simple_parsing.helpers.subgroups import Key +from simple_parsing.replace_selections import replace_selected_dataclass +from simple_parsing.utils import DataclassT logger = logging.getLogger(__name__) @@ -104,9 +105,15 @@ class AllTypes: default="a", ) arg_optional: A | None = None + arg_union_dataclass: A | B = field(default_factory=A) + arg_union_dataclass_init_false: A | B = field(init=False) + + def __post_init__(self): + self.arg_union_dataclass_init_false = copy.copy(self.arg_union_dataclass) + @pytest.mark.parametrize( - ("start", "changes", "subgroup_changes", "expected"), + ("start", "changes", "selections", "expected"), [ ( AB(), @@ -173,9 +180,9 @@ class AllTypes: ), ], ) -def test_replace_subgroups(start: DataclassT, changes: dict, subgroup_changes: dict, expected: DataclassT): - actual = replace_subgroups( - start, changes, subgroup_changes) +def test_replace_selections(start: DataclassT, changes: dict, selections: dict, expected: DataclassT): + actual = replace_selections( + start, changes, selections) assert actual == expected @@ -230,4 +237,4 @@ def test_replace_subgroups(start: DataclassT, changes: dict, subgroup_changes: d ] ) def test_replace_union_dataclasses(start: DataclassT, changes:dict[str, Key|DataclassT], expected: DataclassT): - assert replace_union_dataclasses(start, changes) == expected \ No newline at end of file + assert replace_selected_dataclass(start, changes) == expected \ No newline at end of file From c0e88b205944e0c7f7a24dedb1e8c48ba2b273a0 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 22:22:07 -0700 Subject: [PATCH 08/22] fix doctest unflatten_keyword --- simple_parsing/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index d609b37b..00e7ac7d 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -946,11 +946,14 @@ def unflatten_keyword( into the nested dict differentiating by the `keyword`. - >>> unflatten_keyword({"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}) - {"ab_or_cd": {"__key__": "cd", "c_or_d": {"__key__": "d"}}} + >>> unflatten_keyword({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) + {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} >>> unflatten_keyword({"a": 1, "b": 2}) - {"a": {"__key__": 1}, "b": {"__key__": 2}} + {'a': {'__key__': 1}, 'b': {'__key__': 2}} + + NOTE: This function expects the input to be flat. It does *not* unflatten nested dicts: + """ dc = {} for k, v in flattened.items(): From e4374331f3948d26c1d30ac8606c1985122b7204 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 22:57:17 -0700 Subject: [PATCH 09/22] add doctest for replace_selections --- simple_parsing/__init__.py | 4 +- simple_parsing/replace_selections.py | 80 ++++++++++++---- simple_parsing/utils.py | 17 ++-- test/test_replace_selections.py | 133 +++++++++++---------------- 4 files changed, 124 insertions(+), 110 deletions(-) diff --git a/simple_parsing/__init__.py b/simple_parsing/__init__.py index 6559a038..fb1c90ce 100644 --- a/simple_parsing/__init__.py +++ b/simple_parsing/__init__.py @@ -24,8 +24,8 @@ parse, parse_known_args, ) -from .replace_selections import replace_selections from .replace import replace +from .replace_selections import replace_selections from .utils import InconsistentArgumentError __all__ = [ @@ -45,8 +45,8 @@ "parse_known_args", "parse", "ParsingError", - "replace_selections", "replace", + "replace_selections", "Serializable", "SimpleHelpFormatter", "subgroups", diff --git a/simple_parsing/replace_selections.py b/simple_parsing/replace_selections.py index ba664cb9..c75535dc 100644 --- a/simple_parsing/replace_selections.py +++ b/simple_parsing/replace_selections.py @@ -5,35 +5,50 @@ import logging from typing import Any, overload -from simple_parsing.annotation_utils.get_field_annotations import \ - get_field_type_from_annotations +from simple_parsing.annotation_utils.get_field_annotations import ( + get_field_type_from_annotations, +) from simple_parsing.helpers.subgroups import Key from simple_parsing.replace import replace -from simple_parsing.utils import (DataclassT, contains_dataclass_type_arg, - is_dataclass_instance, is_dataclass_type, - is_optional, is_union, unflatten_keyword) +from simple_parsing.utils import ( + DataclassT, + contains_dataclass_type_arg, + is_dataclass_instance, + is_dataclass_type, + is_optional, + is_union, + unflatten_keyword, +) logger = logging.getLogger(__name__) @overload -def replace_selections(obj: DataclassT, changes_dict: dict[str, Any], selections: dict[str, Key | DataclassT] | None = None) -> DataclassT: +def replace_selections( + obj: DataclassT, + changes_dict: dict[str, Any], + selections: dict[str, Key | DataclassT] | None = None, +) -> DataclassT: ... @overload -def replace_selections(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes) -> DataclassT: +def replace_selections( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes +) -> DataclassT: ... -def replace_selected_dataclass(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None): +def replace_selected_dataclass( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None +): """ This function replaces the dataclass of subgroups, union, and optional union. The `selections` is in flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. """ - keyword = '__key__' + keyword = "__key__" if selections: selections = unflatten_keyword(selections, keyword) @@ -54,38 +69,67 @@ def replace_selected_dataclass(obj: DataclassT, selections: dict[str, Key | Data if is_dataclass_type(key): field_value = key() - logger.debug('is_dataclass_type') + logger.debug("is_dataclass_type") elif is_dataclass_instance(key): field_value = copy.deepcopy(key) - logger.debug('is_dataclass_instance') + logger.debug("is_dataclass_instance") elif field.metadata.get("subgroups", None): field_value = field.metadata["subgroups"][key]() - logger.debug('is_subgroups') + logger.debug("is_subgroups") elif is_optional(t) and key is None: field_value = None logger.debug("key is None") else: - logger.warn('Not Implemented') + logger.warn("Not Implemented") raise TypeError(f"Not Implemented for field {field.name}!") if child_selections: - new_value = replace_selected_dataclass( - field_value, child_selections) + new_value = replace_selected_dataclass(field_value, child_selections) else: new_value = field_value if not field.init: - raise ValueError( - f"Cannot replace value of non-init field {field.name}.") + raise ValueError(f"Cannot replace value of non-init field {field.name}.") replace_kwargs[field.name] = new_value return dataclasses.replace(obj, **replace_kwargs) -def replace_selections(obj: DataclassT, changes_dict: dict[str, Any] | None = None, selections: dict[str, Key | DataclassT] | None = None, **changes) -> DataclassT: +def replace_selections( + obj: DataclassT, + changes_dict: dict[str, Any] | None = None, + selections: dict[str, Key | DataclassT] | None = None, + **changes, +) -> DataclassT: """Replace some values in a dataclass and replace dataclass type in nested union of dataclasses or subgroups. Compared to `simple_replace.replace`, this calls `replace_selected_dataclass` before calling `simple_parsing.replace`. + + ## Examples + >>> import dataclasses + >>> from simple_parsing import replace_selections, subgroups + >>> from typing import Union + >>> @dataclasses.dataclass + ... class A: + ... a: int = 0 + >>> @dataclasses.dataclass + ... class B: + ... b: str = "b" + >>> @dataclasses.dataclass + ... class Config: + ... a_or_b: Union[A, B] = subgroups({'a': A, 'b': B}, default_factory=A) + ... a_or_b_union: Union[A, B] = dataclasses.field(default_factory=A) + ... a_optional: Union[A, None] = None + + >>> base_config = Config(a_or_b=A(a=1)) + >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": "b"}) + Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) + + >>> replace_selections(base_config, {"a_or_b_union.b": "bob"}, {"a_or_b_union": B}) + Config(a_or_b=A(a=1), a_or_b_union=B(b='bob'), a_optional=None) + + >>> replace_selections(base_config, {"a_optional.a": 2}, {"a_optional": A}) + Config(a_or_b=A(a=1), a_or_b_union=A(a=0), a_optional=A(a=2)) """ if selections: obj = replace_selected_dataclass(obj, selections) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 00e7ac7d..46254d6d 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -939,29 +939,28 @@ def unflatten_split( def unflatten_keyword( - flattened: Mapping[str, V], keyword: str = "__key__", sep='.' + flattened: Mapping[str, V], keyword: str = "__key__", sep="." ) -> PossiblyNestedDict[str, V]: """ - This function convert flattened = - into the nested dict + This function convert flattened = + into the nested dict differentiating by the `keyword`. - + >>> unflatten_keyword({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} - + >>> unflatten_keyword({"a": 1, "b": 2}) {'a': {'__key__': 1}, 'b': {'__key__': 2}} - + NOTE: This function expects the input to be flat. It does *not* unflatten nested dicts: - """ dc = {} for k, v in flattened.items(): if keyword != k and sep not in k: - dc[k+sep+keyword] = v + dc[k + sep + keyword] = v else: dc[k] = v - + return unflatten_split(dc) diff --git a/test/test_replace_selections.py b/test/test_replace_selections.py index 7803b7ca..c5a7656a 100644 --- a/test/test_replace_selections.py +++ b/test/test_replace_selections.py @@ -3,10 +3,9 @@ import copy import functools import logging -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Tuple import pytest @@ -19,12 +18,12 @@ @dataclass -class A(): +class A: a: float = 0.0 @dataclass -class B(): +class B: b: str = "bar" b_post_init: str = field(init=False) @@ -38,7 +37,7 @@ class WithOptional: @dataclass -class AB(): +class AB: integer_only_by_post_init: int = field(init=False) integer_in_string: str = "1" a_or_b: A | B = subgroups( @@ -79,20 +78,22 @@ class NestedSubgroupsConfig: default_factory=AB, ) + class Color(Enum): RED = 1 GREEN = 2 BLUE = 3 + @dataclass class AllTypes: - arg_int : int = 0 + arg_int: int = 0 arg_float: float = 1.0 - arg_str: str = 'foo' - arg_list: list = field(default_factory= lambda: [1,2]) - arg_dict : dict = field(default_factory=lambda : {"a":1 , "b": 2}) - arg_union: str| Path = './' - arg_tuple: Tuple[int, int] = (1,1) + arg_str: str = "foo" + arg_list: list = field(default_factory=lambda: [1, 2]) + arg_dict: dict = field(default_factory=lambda: {"a": 1, "b": 2}) + arg_union: str | Path = "./" + arg_tuple: tuple[int, int] = (1, 1) arg_enum: Color = Color.BLUE arg_dataclass: A = field(default_factory=A) arg_subgroups: A | B = subgroups( @@ -107,10 +108,10 @@ class AllTypes: arg_optional: A | None = None arg_union_dataclass: A | B = field(default_factory=A) arg_union_dataclass_init_false: A | B = field(init=False) - + def __post_init__(self): self.arg_union_dataclass_init_false = copy.copy(self.arg_union_dataclass) - + @pytest.mark.parametrize( ("start", "changes", "selections", "expected"), @@ -119,122 +120,92 @@ def __post_init__(self): AB(), {"a_or_b": {"a": 1.0}}, None, - AB(a_or_b=A(a=1.0)), + AB(a_or_b=A(a=1.0)), ), - ( - AB(a_or_b=B()), - {"a_or_b": {"b": "foo"}}, - None, - AB(a_or_b=B(b="foo")) - ), - ( - AB(), - {"a_or_b": {"b": "foo"}}, - {"a_or_b": "b"}, - AB(a_or_b=B(b="foo")) - ), + (AB(a_or_b=B()), {"a_or_b": {"b": "foo"}}, None, AB(a_or_b=B(b="foo"))), + (AB(), {"a_or_b": {"b": "foo"}}, {"a_or_b": "b"}, AB(a_or_b=B(b="foo"))), ( AB(), {"a_or_b": B(b="bob")}, None, - AB(a_or_b=B(b="bob")), + AB(a_or_b=B(b="bob")), ), ( NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), {"ab_or_cd": {"integer_in_string": "2", "a_or_b": {"b": "bob"}}}, None, - NestedSubgroupsConfig(ab_or_cd=AB( - integer_in_string="2", a_or_b=B(b="bob"))), + NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), ), ( NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, None, - NestedSubgroupsConfig(ab_or_cd=AB( - integer_in_string="2", a_or_b=B(b="bob"))), + NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), ), ( NestedSubgroupsConfig(), {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, - {"ab_or_cd": 'ab', "ab_or_cd.a_or_b": 'b'}, - NestedSubgroupsConfig(ab_or_cd=AB( - integer_in_string="2", a_or_b=B(b="bob"))), + {"ab_or_cd": "ab", "ab_or_cd.a_or_b": "b"}, + NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), ), ( NestedSubgroupsConfig(), None, - {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, + {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}, NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())), ), ( NestedSubgroupsConfig(), {"ab_or_cd.c_or_d.d": 1}, - {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, + {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}, NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D(d=1))), ), ( AllTypes(), - {"arg_subgroups.b": "foo","arg_optional.a": 1.0}, + {"arg_subgroups.b": "foo", "arg_optional.a": 1.0}, {"arg_subgroups": "b", "arg_optional": A}, - AllTypes(arg_subgroups=B(b="foo"), arg_optional=A(a=1.0)) + AllTypes(arg_subgroups=B(b="foo"), arg_optional=A(a=1.0)), ), ], ) -def test_replace_selections(start: DataclassT, changes: dict, selections: dict, expected: DataclassT): - actual = replace_selections( - start, changes, selections) +def test_replace_selections( + start: DataclassT, changes: dict, selections: dict, expected: DataclassT +): + actual = replace_selections(start, changes, selections) assert actual == expected @pytest.mark.parametrize( - ('start', 'changes', 'expected'), + ("start", "changes", "expected"), [ ( - AllTypes(), - {'arg_subgroups':'b',"arg_optional": A}, - AllTypes(arg_subgroups=B(), arg_optional=A()) - ), - ( - AllTypes(), - {'arg_subgroups': B,"arg_optional": A}, - AllTypes(arg_subgroups=B(), arg_optional=A()) - ), - ( - AllTypes(arg_optional=A()), - {'arg_subgroups': B,"arg_optional": None}, - AllTypes(arg_subgroups=B(), arg_optional=None) - ), - ( - AllTypes(arg_optional=A(a=1.0)), - {"arg_optional": A}, - AllTypes(arg_optional=A()) - ), - ( - AllTypes(arg_optional=None), - {"arg_optional": A(a=1.2)}, - AllTypes(arg_optional=A(a=1.2)) - ), - ( - AllTypes(arg_subgroups=A(a=1.0)), - {'arg_subgroups': 'a'}, - AllTypes(arg_subgroups=A()) + AllTypes(), + {"arg_subgroups": "b", "arg_optional": A}, + AllTypes(arg_subgroups=B(), arg_optional=A()), ), ( - AllTypes(arg_subgroups=A(a=1.0)), - None, - AllTypes(arg_subgroups=A(a=1.0)) + AllTypes(), + {"arg_subgroups": B, "arg_optional": A}, + AllTypes(arg_subgroups=B(), arg_optional=A()), ), ( - AllTypes(arg_subgroups=A(a=1.0)), - {}, - AllTypes(arg_subgroups=A(a=1.0)) + AllTypes(arg_optional=A()), + {"arg_subgroups": B, "arg_optional": None}, + AllTypes(arg_subgroups=B(), arg_optional=None), ), + (AllTypes(arg_optional=A(a=1.0)), {"arg_optional": A}, AllTypes(arg_optional=A())), + (AllTypes(arg_optional=None), {"arg_optional": A(a=1.2)}, AllTypes(arg_optional=A(a=1.2))), + (AllTypes(arg_subgroups=A(a=1.0)), {"arg_subgroups": "a"}, AllTypes(arg_subgroups=A())), + (AllTypes(arg_subgroups=A(a=1.0)), None, AllTypes(arg_subgroups=A(a=1.0))), + (AllTypes(arg_subgroups=A(a=1.0)), {}, AllTypes(arg_subgroups=A(a=1.0))), ( NestedSubgroupsConfig(), - {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}, - NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())) + {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}, + NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())), ), - ] + ], ) -def test_replace_union_dataclasses(start: DataclassT, changes:dict[str, Key|DataclassT], expected: DataclassT): - assert replace_selected_dataclass(start, changes) == expected \ No newline at end of file +def test_replace_union_dataclasses( + start: DataclassT, changes: dict[str, Key | DataclassT], expected: DataclassT +): + assert replace_selected_dataclass(start, changes) == expected From d90fe3633b9a81af948bab3860ae23188892df2d Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 23:02:04 -0700 Subject: [PATCH 10/22] fix docstring unflatten_keyword --- simple_parsing/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 46254d6d..5a65cd53 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -952,7 +952,7 @@ def unflatten_keyword( >>> unflatten_keyword({"a": 1, "b": 2}) {'a': {'__key__': 1}, 'b': {'__key__': 2}} - NOTE: This function expects the input to be flat. It does *not* unflatten nested dicts: + NOTE: This function expects the input to be flat. It does *not* unflatten and add keyword more than one level: """ dc = {} for k, v in flattened.items(): From cd38e0c78fe198a3196d4fd9791e2551b02e79d9 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sat, 4 Feb 2023 23:30:02 -0700 Subject: [PATCH 11/22] add test for invalid case --- simple_parsing/replace_selections.py | 11 ++++----- test/test_replace_selections.py | 36 ++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/simple_parsing/replace_selections.py b/simple_parsing/replace_selections.py index c75535dc..3af5bc20 100644 --- a/simple_parsing/replace_selections.py +++ b/simple_parsing/replace_selections.py @@ -80,8 +80,7 @@ def replace_selected_dataclass( field_value = None logger.debug("key is None") else: - logger.warn("Not Implemented") - raise TypeError(f"Not Implemented for field {field.name}!") + raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") if child_selections: new_value = replace_selected_dataclass(field_value, child_selections) @@ -104,7 +103,7 @@ def replace_selections( """Replace some values in a dataclass and replace dataclass type in nested union of dataclasses or subgroups. Compared to `simple_replace.replace`, this calls `replace_selected_dataclass` before calling `simple_parsing.replace`. - + ## Examples >>> import dataclasses >>> from simple_parsing import replace_selections, subgroups @@ -120,14 +119,14 @@ def replace_selections( ... a_or_b: Union[A, B] = subgroups({'a': A, 'b': B}, default_factory=A) ... a_or_b_union: Union[A, B] = dataclasses.field(default_factory=A) ... a_optional: Union[A, None] = None - + >>> base_config = Config(a_or_b=A(a=1)) >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": "b"}) Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) - + >>> replace_selections(base_config, {"a_or_b_union.b": "bob"}, {"a_or_b_union": B}) Config(a_or_b=A(a=1), a_or_b_union=B(b='bob'), a_optional=None) - + >>> replace_selections(base_config, {"a_optional.a": 2}, {"a_optional": A}) Config(a_or_b=A(a=1), a_or_b_union=A(a=0), a_optional=A(a=2)) """ diff --git a/test/test_replace_selections.py b/test/test_replace_selections.py index c5a7656a..7e917557 100644 --- a/test/test_replace_selections.py +++ b/test/test_replace_selections.py @@ -106,6 +106,7 @@ class AllTypes: default="a", ) arg_optional: A | None = None + arg_optional_1: A | None = None arg_union_dataclass: A | B = field(default_factory=A) arg_union_dataclass_init_false: A | B = field(init=False) @@ -168,10 +169,21 @@ def __post_init__(self): ), ], ) +@pytest.mark.parametrize("pass_dict_as_kwargs", [True, False]) def test_replace_selections( - start: DataclassT, changes: dict, selections: dict, expected: DataclassT + start: DataclassT, + changes: dict, + selections: dict, + expected: DataclassT, + pass_dict_as_kwargs: bool, ): - actual = replace_selections(start, changes, selections) + if pass_dict_as_kwargs: + if changes is not None: + actual = replace_selections(start, selections=selections, **changes) + else: + actual = replace_selections(start, selections=selections) + else: + actual = replace_selections(start, changes, selections) assert actual == expected @@ -195,6 +207,7 @@ def test_replace_selections( ), (AllTypes(arg_optional=A(a=1.0)), {"arg_optional": A}, AllTypes(arg_optional=A())), (AllTypes(arg_optional=None), {"arg_optional": A(a=1.2)}, AllTypes(arg_optional=A(a=1.2))), + (AllTypes(arg_optional_1=A(a=1.0)), {"arg_optional_1": A}, AllTypes(arg_optional_1=A())), (AllTypes(arg_subgroups=A(a=1.0)), {"arg_subgroups": "a"}, AllTypes(arg_subgroups=A())), (AllTypes(arg_subgroups=A(a=1.0)), None, AllTypes(arg_subgroups=A(a=1.0))), (AllTypes(arg_subgroups=A(a=1.0)), {}, AllTypes(arg_subgroups=A(a=1.0))), @@ -209,3 +222,22 @@ def test_replace_union_dataclasses( start: DataclassT, changes: dict[str, Key | DataclassT], expected: DataclassT ): assert replace_selected_dataclass(start, changes) == expected + + +@pytest.mark.parametrize( + ("start", "changes", "selections", "exception_type", "match"), + [ + ( + AllTypes(arg_union_dataclass=A(a=1.0)), + {}, + {"arg_union_dataclass": "b"}, + ValueError, + "invalid selection key 'b' for field 'arg_union_dataclass'", + ), + ], +) +def test_replace_selection_invalid( + start: DataclassT, changes: dict, selections: dict, exception_type: type[Exception], match: str +): + with pytest.raises(exception_type, match=match): + replace_selections(start, changes, selections) From b7b64050f021c1c5847901bc921072fe73f466b4 Mon Sep 17 00:00:00 2001 From: zhiruiluo Date: Sun, 5 Feb 2023 10:45:47 -0700 Subject: [PATCH 12/22] modify docstring --- simple_parsing/replace_selections.py | 18 ++++++++++++------ simple_parsing/utils.py | 7 ++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/simple_parsing/replace_selections.py b/simple_parsing/replace_selections.py index 3af5bc20..2cfbb13e 100644 --- a/simple_parsing/replace_selections.py +++ b/simple_parsing/replace_selections.py @@ -107,7 +107,7 @@ def replace_selections( ## Examples >>> import dataclasses >>> from simple_parsing import replace_selections, subgroups - >>> from typing import Union + >>> from typing import Union, Optional >>> @dataclasses.dataclass ... class A: ... a: int = 0 @@ -118,17 +118,23 @@ def replace_selections( ... class Config: ... a_or_b: Union[A, B] = subgroups({'a': A, 'b': B}, default_factory=A) ... a_or_b_union: Union[A, B] = dataclasses.field(default_factory=A) - ... a_optional: Union[A, None] = None + ... a_optional: Optional[A] = None >>> base_config = Config(a_or_b=A(a=1)) + + Replace subgroups field by subgroup `Key`, dataclass type, or dataclass instance >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": "b"}) Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) - + >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": B}) + Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) + >>> replace_selections(base_config, {}, {"a_or_b": B(b="bob")}) + Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) + + Replace union of dataclasses and optional dataclass >>> replace_selections(base_config, {"a_or_b_union.b": "bob"}, {"a_or_b_union": B}) Config(a_or_b=A(a=1), a_or_b_union=B(b='bob'), a_optional=None) - - >>> replace_selections(base_config, {"a_optional.a": 2}, {"a_optional": A}) - Config(a_or_b=A(a=1), a_or_b_union=A(a=0), a_optional=A(a=2)) + >>> replace_selections(base_config, {"a_optional.a": 10}, {"a_optional": A}) + Config(a_or_b=A(a=1), a_or_b_union=A(a=0), a_optional=A(a=10)) """ if selections: obj = replace_selected_dataclass(obj, selections) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 5a65cd53..48098af0 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -942,17 +942,14 @@ def unflatten_keyword( flattened: Mapping[str, V], keyword: str = "__key__", sep="." ) -> PossiblyNestedDict[str, V]: """ - This function convert flattened = - into the nested dict - differentiating by the `keyword`. + This function convert flattened dict into additional layer of nested dict + and it adds the `keyword` as the key to the unpaired value. >>> unflatten_keyword({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} >>> unflatten_keyword({"a": 1, "b": 2}) {'a': {'__key__': 1}, 'b': {'__key__': 2}} - - NOTE: This function expects the input to be flat. It does *not* unflatten and add keyword more than one level: """ dc = {} for k, v in flattened.items(): From 14cf1a023c98b99a7fc0126941e426820ebbd4c4 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Mon, 27 Mar 2023 00:40:22 -0600 Subject: [PATCH 13/22] keep only replace_subgroups function and move it into replace.py --- simple_parsing/__init__.py | 5 +- simple_parsing/replace.py | 114 ++++++++++++- simple_parsing/replace_selections.py | 141 ---------------- simple_parsing/utils.py | 23 --- test/test_replace_selections.py | 243 --------------------------- test/test_replace_subgroups.py | 46 +++++ 6 files changed, 160 insertions(+), 412 deletions(-) delete mode 100644 simple_parsing/replace_selections.py delete mode 100644 test/test_replace_selections.py create mode 100644 test/test_replace_subgroups.py diff --git a/simple_parsing/__init__.py b/simple_parsing/__init__.py index fb1c90ce..7ba73327 100644 --- a/simple_parsing/__init__.py +++ b/simple_parsing/__init__.py @@ -24,8 +24,7 @@ parse, parse_known_args, ) -from .replace import replace -from .replace_selections import replace_selections +from .replace import replace, replace_subgroups from .utils import InconsistentArgumentError __all__ = [ @@ -46,7 +45,7 @@ "parse", "ParsingError", "replace", - "replace_selections", + "replace_subgroups", "Serializable", "SimpleHelpFormatter", "subgroups", diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index 3ada2b6d..ad475328 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -1,10 +1,25 @@ from __future__ import annotations import dataclasses -from typing import Any, overload +from typing import Any, overload, Mapping +import copy +from simple_parsing.annotation_utils.get_field_annotations import ( + get_field_type_from_annotations, +) +from simple_parsing.helpers.subgroups import Key from simple_parsing.utils import DataclassT, is_dataclass_instance, unflatten_split - +from simple_parsing.utils import ( + DataclassT, + contains_dataclass_type_arg, + is_dataclass_instance, + is_dataclass_type, + is_optional, + is_union, + PossiblyNestedDict, + V, + unflatten_split +) @overload def replace(obj: DataclassT, changes_dict: dict[str, Any]) -> DataclassT: @@ -90,3 +105,98 @@ def replace(obj: DataclassT, changes_dict: dict[str, Any] | None = None, **chang replace_kwargs.update(changes) return dataclasses.replace(obj, **replace_kwargs) + + +def unflatten_selection_dict( + flattened: Mapping[str, V], keyword: str = "__key__", sep="." +) -> PossiblyNestedDict[str, V]: + """ + This function convert a flattened dict into a nested dict + and it inserts the `keyword` as the selection into the nested dict. + + >>> unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) + {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} + + >>> unflatten_selection_dict({"a": 1, "b": 2}) + {'a': {'__key__': 1}, 'b': {'__key__': 2}} + """ + dc = {} + for k, v in flattened.items(): + if keyword != k and sep not in k and not isinstance(v, dict): + dc[k + sep + keyword] = v + else: + dc[k] = v + return unflatten_split(dc) + + +@overload +def replace_subgroups( + obj: DataclassT, + changes_dict: dict[str, Any], + selections: dict[str, Key | DataclassT] | None = None, +) -> DataclassT: + ... + + +@overload +def replace_subgroups( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes +) -> DataclassT: + ... + + +def replace_subgroups( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None +): + """ + This function replaces the dataclass of subgroups, union, and optional union. + The `selections` dict can be in flat format or in nested format. + + The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. + """ + keyword = "__key__" + + if selections: + selections = unflatten_selection_dict(selections, keyword) + else: + return obj + + replace_kwargs = {} + for field in dataclasses.fields(obj): + if field.name not in selections: + continue + + field_value = getattr(obj, field.name) + t = get_field_type_from_annotations(obj.__class__, field.name) + + new_value = None + # Replace subgroup is allowed when the type annotation contains dataclass + if contains_dataclass_type_arg(t): + child_selections = selections.pop(field.name) + key = child_selections.pop(keyword, None) + + if is_dataclass_type(key): + field_value = key() + elif is_dataclass_instance(key): + field_value = copy.deepcopy(key) + elif field.metadata.get("subgroups", None): + field_value = field.metadata["subgroups"][key]() + elif is_optional(t) and key is None: + field_value = None + elif contains_dataclass_type_arg(t) and key is None: + field_value = field.default_factory() + else: + raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") + + if child_selections: + new_value = replace_subgroups(field_value, child_selections) + else: + new_value = field_value + else: + raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") + + if not field.init: + raise ValueError(f"Cannot replace value of non-init field {field.name}.") + + replace_kwargs[field.name] = new_value + return dataclasses.replace(obj, **replace_kwargs) \ No newline at end of file diff --git a/simple_parsing/replace_selections.py b/simple_parsing/replace_selections.py deleted file mode 100644 index 2cfbb13e..00000000 --- a/simple_parsing/replace_selections.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import annotations - -import copy -import dataclasses -import logging -from typing import Any, overload - -from simple_parsing.annotation_utils.get_field_annotations import ( - get_field_type_from_annotations, -) -from simple_parsing.helpers.subgroups import Key -from simple_parsing.replace import replace -from simple_parsing.utils import ( - DataclassT, - contains_dataclass_type_arg, - is_dataclass_instance, - is_dataclass_type, - is_optional, - is_union, - unflatten_keyword, -) - -logger = logging.getLogger(__name__) - - -@overload -def replace_selections( - obj: DataclassT, - changes_dict: dict[str, Any], - selections: dict[str, Key | DataclassT] | None = None, -) -> DataclassT: - ... - - -@overload -def replace_selections( - obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes -) -> DataclassT: - ... - - -def replace_selected_dataclass( - obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None -): - """ - This function replaces the dataclass of subgroups, union, and optional union. - The `selections` is in flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'} - - The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. - """ - keyword = "__key__" - - if selections: - selections = unflatten_keyword(selections, keyword) - else: - return obj - - replace_kwargs = {} - for field in dataclasses.fields(obj): - if field.name not in selections: - continue - - field_value = getattr(obj, field.name) - t = get_field_type_from_annotations(obj.__class__, field.name) - - if contains_dataclass_type_arg(t) and is_union(t): - child_selections = selections.pop(field.name) - key = child_selections.pop(keyword, None) - - if is_dataclass_type(key): - field_value = key() - logger.debug("is_dataclass_type") - elif is_dataclass_instance(key): - field_value = copy.deepcopy(key) - logger.debug("is_dataclass_instance") - elif field.metadata.get("subgroups", None): - field_value = field.metadata["subgroups"][key]() - logger.debug("is_subgroups") - elif is_optional(t) and key is None: - field_value = None - logger.debug("key is None") - else: - raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") - - if child_selections: - new_value = replace_selected_dataclass(field_value, child_selections) - else: - new_value = field_value - - if not field.init: - raise ValueError(f"Cannot replace value of non-init field {field.name}.") - - replace_kwargs[field.name] = new_value - return dataclasses.replace(obj, **replace_kwargs) - - -def replace_selections( - obj: DataclassT, - changes_dict: dict[str, Any] | None = None, - selections: dict[str, Key | DataclassT] | None = None, - **changes, -) -> DataclassT: - """Replace some values in a dataclass and replace dataclass type in nested union of dataclasses or subgroups. - - Compared to `simple_replace.replace`, this calls `replace_selected_dataclass` before calling `simple_parsing.replace`. - - ## Examples - >>> import dataclasses - >>> from simple_parsing import replace_selections, subgroups - >>> from typing import Union, Optional - >>> @dataclasses.dataclass - ... class A: - ... a: int = 0 - >>> @dataclasses.dataclass - ... class B: - ... b: str = "b" - >>> @dataclasses.dataclass - ... class Config: - ... a_or_b: Union[A, B] = subgroups({'a': A, 'b': B}, default_factory=A) - ... a_or_b_union: Union[A, B] = dataclasses.field(default_factory=A) - ... a_optional: Optional[A] = None - - >>> base_config = Config(a_or_b=A(a=1)) - - Replace subgroups field by subgroup `Key`, dataclass type, or dataclass instance - >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": "b"}) - Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) - >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": B}) - Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) - >>> replace_selections(base_config, {}, {"a_or_b": B(b="bob")}) - Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None) - - Replace union of dataclasses and optional dataclass - >>> replace_selections(base_config, {"a_or_b_union.b": "bob"}, {"a_or_b_union": B}) - Config(a_or_b=A(a=1), a_or_b_union=B(b='bob'), a_optional=None) - >>> replace_selections(base_config, {"a_optional.a": 10}, {"a_optional": A}) - Config(a_or_b=A(a=1), a_or_b_union=A(a=0), a_optional=A(a=10)) - """ - if selections: - obj = replace_selected_dataclass(obj, selections) - return replace(obj, changes_dict, **changes) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 0f4863e4..fc93e450 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -939,29 +939,6 @@ def unflatten_split( return unflatten({tuple(key.split(sep)): value for key, value in flattened.items()}) -def unflatten_keyword( - flattened: Mapping[str, V], keyword: str = "__key__", sep="." -) -> PossiblyNestedDict[str, V]: - """ - This function convert flattened dict into additional layer of nested dict - and it adds the `keyword` as the key to the unpaired value. - - >>> unflatten_keyword({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) - {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} - - >>> unflatten_keyword({"a": 1, "b": 2}) - {'a': {'__key__': 1}, 'b': {'__key__': 2}} - """ - dc = {} - for k, v in flattened.items(): - if keyword != k and sep not in k: - dc[k + sep + keyword] = v - else: - dc[k] = v - - return unflatten_split(dc) - - @overload def getitem_recursive(d: PossiblyNestedDict[K, V], keys: Iterable[K]) -> V: ... diff --git a/test/test_replace_selections.py b/test/test_replace_selections.py deleted file mode 100644 index 7e917557..00000000 --- a/test/test_replace_selections.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import annotations - -import copy -import functools -import logging -from dataclasses import dataclass, field -from enum import Enum -from pathlib import Path - -import pytest - -from simple_parsing import replace_selections, subgroups -from simple_parsing.helpers.subgroups import Key -from simple_parsing.replace_selections import replace_selected_dataclass -from simple_parsing.utils import DataclassT - -logger = logging.getLogger(__name__) - - -@dataclass -class A: - a: float = 0.0 - - -@dataclass -class B: - b: str = "bar" - b_post_init: str = field(init=False) - - def __post_init__(self): - self.b_post_init = self.b + "_post" - - -@dataclass -class WithOptional: - optional_a: A | None = None - - -@dataclass -class AB: - integer_only_by_post_init: int = field(init=False) - integer_in_string: str = "1" - a_or_b: A | B = subgroups( - { - "a": A, - "a_1.23": functools.partial(A, a=1.23), - "b": B, - "b_bob": functools.partial(B, b="bob"), - }, - default="a", - ) - - def __post_init__(self): - self.integer_only_by_post_init = int(self.integer_in_string) - - -@dataclass -class C: - c: bool = False - - -@dataclass -class D: - d: int = 0 - - -@dataclass -class CD: - c_or_d: C | D = subgroups({"c": C, "d": D}, default="c") - - other_arg: str = "bob" - - -@dataclass -class NestedSubgroupsConfig: - ab_or_cd: AB | CD = subgroups( - {"ab": AB, "cd": CD}, - default_factory=AB, - ) - - -class Color(Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - -@dataclass -class AllTypes: - arg_int: int = 0 - arg_float: float = 1.0 - arg_str: str = "foo" - arg_list: list = field(default_factory=lambda: [1, 2]) - arg_dict: dict = field(default_factory=lambda: {"a": 1, "b": 2}) - arg_union: str | Path = "./" - arg_tuple: tuple[int, int] = (1, 1) - arg_enum: Color = Color.BLUE - arg_dataclass: A = field(default_factory=A) - arg_subgroups: A | B = subgroups( - { - "a": A, - "a_1.23": functools.partial(A, a=1.23), - "b": B, - "b_bob": functools.partial(B, b="bob"), - }, - default="a", - ) - arg_optional: A | None = None - arg_optional_1: A | None = None - arg_union_dataclass: A | B = field(default_factory=A) - arg_union_dataclass_init_false: A | B = field(init=False) - - def __post_init__(self): - self.arg_union_dataclass_init_false = copy.copy(self.arg_union_dataclass) - - -@pytest.mark.parametrize( - ("start", "changes", "selections", "expected"), - [ - ( - AB(), - {"a_or_b": {"a": 1.0}}, - None, - AB(a_or_b=A(a=1.0)), - ), - (AB(a_or_b=B()), {"a_or_b": {"b": "foo"}}, None, AB(a_or_b=B(b="foo"))), - (AB(), {"a_or_b": {"b": "foo"}}, {"a_or_b": "b"}, AB(a_or_b=B(b="foo"))), - ( - AB(), - {"a_or_b": B(b="bob")}, - None, - AB(a_or_b=B(b="bob")), - ), - ( - NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), - {"ab_or_cd": {"integer_in_string": "2", "a_or_b": {"b": "bob"}}}, - None, - NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), - ), - ( - NestedSubgroupsConfig(ab_or_cd=AB(a_or_b=B())), - {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, - None, - NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), - ), - ( - NestedSubgroupsConfig(), - {"ab_or_cd.integer_in_string": "2", "ab_or_cd.a_or_b.b": "bob"}, - {"ab_or_cd": "ab", "ab_or_cd.a_or_b": "b"}, - NestedSubgroupsConfig(ab_or_cd=AB(integer_in_string="2", a_or_b=B(b="bob"))), - ), - ( - NestedSubgroupsConfig(), - None, - {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}, - NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())), - ), - ( - NestedSubgroupsConfig(), - {"ab_or_cd.c_or_d.d": 1}, - {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}, - NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D(d=1))), - ), - ( - AllTypes(), - {"arg_subgroups.b": "foo", "arg_optional.a": 1.0}, - {"arg_subgroups": "b", "arg_optional": A}, - AllTypes(arg_subgroups=B(b="foo"), arg_optional=A(a=1.0)), - ), - ], -) -@pytest.mark.parametrize("pass_dict_as_kwargs", [True, False]) -def test_replace_selections( - start: DataclassT, - changes: dict, - selections: dict, - expected: DataclassT, - pass_dict_as_kwargs: bool, -): - if pass_dict_as_kwargs: - if changes is not None: - actual = replace_selections(start, selections=selections, **changes) - else: - actual = replace_selections(start, selections=selections) - else: - actual = replace_selections(start, changes, selections) - assert actual == expected - - -@pytest.mark.parametrize( - ("start", "changes", "expected"), - [ - ( - AllTypes(), - {"arg_subgroups": "b", "arg_optional": A}, - AllTypes(arg_subgroups=B(), arg_optional=A()), - ), - ( - AllTypes(), - {"arg_subgroups": B, "arg_optional": A}, - AllTypes(arg_subgroups=B(), arg_optional=A()), - ), - ( - AllTypes(arg_optional=A()), - {"arg_subgroups": B, "arg_optional": None}, - AllTypes(arg_subgroups=B(), arg_optional=None), - ), - (AllTypes(arg_optional=A(a=1.0)), {"arg_optional": A}, AllTypes(arg_optional=A())), - (AllTypes(arg_optional=None), {"arg_optional": A(a=1.2)}, AllTypes(arg_optional=A(a=1.2))), - (AllTypes(arg_optional_1=A(a=1.0)), {"arg_optional_1": A}, AllTypes(arg_optional_1=A())), - (AllTypes(arg_subgroups=A(a=1.0)), {"arg_subgroups": "a"}, AllTypes(arg_subgroups=A())), - (AllTypes(arg_subgroups=A(a=1.0)), None, AllTypes(arg_subgroups=A(a=1.0))), - (AllTypes(arg_subgroups=A(a=1.0)), {}, AllTypes(arg_subgroups=A(a=1.0))), - ( - NestedSubgroupsConfig(), - {"ab_or_cd": "cd", "ab_or_cd.c_or_d": "d"}, - NestedSubgroupsConfig(ab_or_cd=CD(c_or_d=D())), - ), - ], -) -def test_replace_union_dataclasses( - start: DataclassT, changes: dict[str, Key | DataclassT], expected: DataclassT -): - assert replace_selected_dataclass(start, changes) == expected - - -@pytest.mark.parametrize( - ("start", "changes", "selections", "exception_type", "match"), - [ - ( - AllTypes(arg_union_dataclass=A(a=1.0)), - {}, - {"arg_union_dataclass": "b"}, - ValueError, - "invalid selection key 'b' for field 'arg_union_dataclass'", - ), - ], -) -def test_replace_selection_invalid( - start: DataclassT, changes: dict, selections: dict, exception_type: type[Exception], match: str -): - with pytest.raises(exception_type, match=match): - replace_selections(start, changes, selections) diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py new file mode 100644 index 00000000..98873f47 --- /dev/null +++ b/test/test_replace_subgroups.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +from simple_parsing import replace_subgroups, subgroups + + +@dataclass +class A: + a: float = 0.0 + + +@dataclass +class B: + b: str = "bar" + + +@dataclass +class AorB: + a_or_b: A | B = subgroups( + {"a": A, "b": B}, default_factory=A + ) + + +@dataclass +class Config(): + subgroup: A | B = subgroups( + {"a": A, "b": B}, default_factory=A + ) + optional: A | None = None + implicit_optional: A = None + union: A | B = field(default_factory=A) + nested_subgroup: AorB = field(default_factory=AorB) + + +def test_replace_subgroups(): + c = Config() + assert replace_subgroups(c, {'subgroup': "b"}) == Config(subgroup=B()) + assert replace_subgroups(c, {'optional': A}) == Config(optional=A()) + assert replace_subgroups(c, {'implicit_optional': A}) == \ + Config(implicit_optional=A()) + assert replace_subgroups(c, {'union': B}) == Config(union=B()) + assert replace_subgroups(c, {'nested_subgroup.a_or_b': "b"}) == \ + Config(nested_subgroup=AorB(a_or_b=B())) + assert replace_subgroups(c, {'nested_subgroup': {"a_or_b": "b"}}) == \ + Config(nested_subgroup=AorB(a_or_b=B())) From 25eb328d964be94e1390ad9f6f36d322c5285148 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Mon, 27 Mar 2023 00:42:56 -0600 Subject: [PATCH 14/22] remove @overload for replace_subgroups --- simple_parsing/replace.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index ad475328..ea8e1929 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -129,22 +129,6 @@ def unflatten_selection_dict( return unflatten_split(dc) -@overload -def replace_subgroups( - obj: DataclassT, - changes_dict: dict[str, Any], - selections: dict[str, Key | DataclassT] | None = None, -) -> DataclassT: - ... - - -@overload -def replace_subgroups( - obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes -) -> DataclassT: - ... - - def replace_subgroups( obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None ): From dd3d2ed9717650776a405fb6a1d2f9c0c54e9130 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Mon, 27 Mar 2023 00:47:25 -0600 Subject: [PATCH 15/22] apply pre-commit check --- simple_parsing/replace.py | 23 ++++++++++------------- test/test_replace_subgroups.py | 29 +++++++++++++---------------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index ea8e1929..85714a50 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -1,26 +1,25 @@ from __future__ import annotations -import dataclasses -from typing import Any, overload, Mapping import copy +import dataclasses +from typing import Any, Mapping, overload from simple_parsing.annotation_utils.get_field_annotations import ( get_field_type_from_annotations, ) from simple_parsing.helpers.subgroups import Key -from simple_parsing.utils import DataclassT, is_dataclass_instance, unflatten_split from simple_parsing.utils import ( DataclassT, + PossiblyNestedDict, + V, contains_dataclass_type_arg, is_dataclass_instance, is_dataclass_type, is_optional, - is_union, - PossiblyNestedDict, - V, - unflatten_split + unflatten_split, ) + @overload def replace(obj: DataclassT, changes_dict: dict[str, Any]) -> DataclassT: ... @@ -129,9 +128,7 @@ def unflatten_selection_dict( return unflatten_split(dc) -def replace_subgroups( - obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None -): +def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None): """ This function replaces the dataclass of subgroups, union, and optional union. The `selections` dict can be in flat format or in nested format. @@ -152,7 +149,7 @@ def replace_subgroups( field_value = getattr(obj, field.name) t = get_field_type_from_annotations(obj.__class__, field.name) - + new_value = None # Replace subgroup is allowed when the type annotation contains dataclass if contains_dataclass_type_arg(t): @@ -178,9 +175,9 @@ def replace_subgroups( new_value = field_value else: raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") - + if not field.init: raise ValueError(f"Cannot replace value of non-init field {field.name}.") replace_kwargs[field.name] = new_value - return dataclasses.replace(obj, **replace_kwargs) \ No newline at end of file + return dataclasses.replace(obj, **replace_kwargs) diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py index 98873f47..b947f426 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_subgroups.py @@ -17,16 +17,12 @@ class B: @dataclass class AorB: - a_or_b: A | B = subgroups( - {"a": A, "b": B}, default_factory=A - ) + a_or_b: A | B = subgroups({"a": A, "b": B}, default_factory=A) @dataclass -class Config(): - subgroup: A | B = subgroups( - {"a": A, "b": B}, default_factory=A - ) +class Config: + subgroup: A | B = subgroups({"a": A, "b": B}, default_factory=A) optional: A | None = None implicit_optional: A = None union: A | B = field(default_factory=A) @@ -35,12 +31,13 @@ class Config(): def test_replace_subgroups(): c = Config() - assert replace_subgroups(c, {'subgroup': "b"}) == Config(subgroup=B()) - assert replace_subgroups(c, {'optional': A}) == Config(optional=A()) - assert replace_subgroups(c, {'implicit_optional': A}) == \ - Config(implicit_optional=A()) - assert replace_subgroups(c, {'union': B}) == Config(union=B()) - assert replace_subgroups(c, {'nested_subgroup.a_or_b': "b"}) == \ - Config(nested_subgroup=AorB(a_or_b=B())) - assert replace_subgroups(c, {'nested_subgroup': {"a_or_b": "b"}}) == \ - Config(nested_subgroup=AorB(a_or_b=B())) + assert replace_subgroups(c, {"subgroup": "b"}) == Config(subgroup=B()) + assert replace_subgroups(c, {"optional": A}) == Config(optional=A()) + assert replace_subgroups(c, {"implicit_optional": A}) == Config(implicit_optional=A()) + assert replace_subgroups(c, {"union": B}) == Config(union=B()) + assert replace_subgroups(c, {"nested_subgroup.a_or_b": "b"}) == Config( + nested_subgroup=AorB(a_or_b=B()) + ) + assert replace_subgroups(c, {"nested_subgroup": {"a_or_b": "b"}}) == Config( + nested_subgroup=AorB(a_or_b=B()) + ) From c71c6060f1d63c1c08185ae1e2a2c196217f0742 Mon Sep 17 00:00:00 2001 From: Bill Luo <50068224+zhiruiluo@users.noreply.github.com> Date: Tue, 4 Apr 2023 22:59:54 -0600 Subject: [PATCH 16/22] Update simple_parsing/replace.py Co-authored-by: Fabrice Normandin --- simple_parsing/replace.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index 85714a50..c7d24ca7 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -137,10 +137,9 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | """ keyword = "__key__" - if selections: - selections = unflatten_selection_dict(selections, keyword) - else: + if not selections: return obj + selections = unflatten_selection_dict(selections, keyword) replace_kwargs = {} for field in dataclasses.fields(obj): From 2477fec5f4c23b2af4d5935dcbad5ebbc47d583e Mon Sep 17 00:00:00 2001 From: Bill Luo <50068224+zhiruiluo@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:00:45 -0600 Subject: [PATCH 17/22] Update simple_parsing/replace.py Co-authored-by: Fabrice Normandin --- simple_parsing/replace.py | 43 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index c7d24ca7..645bf0fa 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -151,29 +151,28 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | new_value = None # Replace subgroup is allowed when the type annotation contains dataclass - if contains_dataclass_type_arg(t): - child_selections = selections.pop(field.name) - key = child_selections.pop(keyword, None) - - if is_dataclass_type(key): - field_value = key() - elif is_dataclass_instance(key): - field_value = copy.deepcopy(key) - elif field.metadata.get("subgroups", None): - field_value = field.metadata["subgroups"][key]() - elif is_optional(t) and key is None: - field_value = None - elif contains_dataclass_type_arg(t) and key is None: - field_value = field.default_factory() - else: - raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") - - if child_selections: - new_value = replace_subgroups(field_value, child_selections) - else: - new_value = field_value - else: + if not contains_dataclass_type_arg(t): raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") + child_selections = selections.pop(field.name) + key = child_selections.pop(keyword, None) + + if is_dataclass_type(key): + field_value = key() + elif is_dataclass_instance(key): + field_value = copy.deepcopy(key) + elif field.metadata.get("subgroups", None): + field_value = field.metadata["subgroups"][key]() + elif is_optional(t) and key is None: + field_value = None + elif contains_dataclass_type_arg(t) and key is None: + field_value = field.default_factory() + else: + raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") + + if child_selections: + new_value = replace_subgroups(field_value, child_selections) + else: + new_value = field_value if not field.init: raise ValueError(f"Cannot replace value of non-init field {field.name}.") From c9a68ebd5aae17685cd009a27bd526bb259defb7 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Tue, 4 Apr 2023 23:40:02 -0600 Subject: [PATCH 18/22] address some issues --- simple_parsing/replace.py | 39 +++++++-- simple_parsing/replace_subgroups.py | 126 ++++++++++++++++++++++++++++ test/test_replace_subgroups.py | 9 ++ 3 files changed, 166 insertions(+), 8 deletions(-) create mode 100644 simple_parsing/replace_subgroups.py diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index 645bf0fa..cb35daca 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -117,11 +117,22 @@ def unflatten_selection_dict( {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} >>> unflatten_selection_dict({"a": 1, "b": 2}) - {'a': {'__key__': 1}, 'b': {'__key__': 2}} + {'a': 1, 'b': 2} """ dc = {} + + existing_top_level_keys = set() + conflited_top_level_keys = set() for k, v in flattened.items(): - if keyword != k and sep not in k and not isinstance(v, dict): + top_level_key = k.split(sep)[0] + if top_level_key not in existing_top_level_keys: + existing_top_level_keys.add(top_level_key) + else: + conflited_top_level_keys.add(top_level_key) + + for k, v in flattened.items(): + # if keyword != k and sep not in k and not isinstance(v, dict): + if k in conflited_top_level_keys: dc[k + sep + keyword] = v else: dc[k] = v @@ -143,6 +154,9 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | replace_kwargs = {} for field in dataclasses.fields(obj): + if not field.init: + raise ValueError(f"Cannot replace value of non-init field {field.name}.") + if field.name not in selections: continue @@ -153,15 +167,27 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | # Replace subgroup is allowed when the type annotation contains dataclass if not contains_dataclass_type_arg(t): raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") - child_selections = selections.pop(field.name) - key = child_selections.pop(keyword, None) + + selection = selections.pop(field.name) + if isinstance(selection, dict): + key = selection.pop(keyword, None) + child_selections = selection + else: + key = selection + child_selections = None if is_dataclass_type(key): field_value = key() elif is_dataclass_instance(key): field_value = copy.deepcopy(key) elif field.metadata.get("subgroups", None): - field_value = field.metadata["subgroups"][key]() + subgroup_selection = field.metadata["subgroups"][key] + if is_dataclass_instance(subgroup_selection): + # when the subgroup selection is a frozen dataclass instance + field_value = subgroup_selection + else: + # when the subgroup selection is a dataclass type + field_value = field.metadata["subgroups"][key]() elif is_optional(t) and key is None: field_value = None elif contains_dataclass_type_arg(t) and key is None: @@ -174,8 +200,5 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | else: new_value = field_value - if not field.init: - raise ValueError(f"Cannot replace value of non-init field {field.name}.") - replace_kwargs[field.name] = new_value return dataclasses.replace(obj, **replace_kwargs) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py new file mode 100644 index 00000000..67a11566 --- /dev/null +++ b/simple_parsing/replace_subgroups.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import copy +import dataclasses +import logging +from typing import Any, overload, Mapping + +from simple_parsing.annotation_utils.get_field_annotations import ( + get_field_type_from_annotations, +) +from simple_parsing.helpers.subgroups import Key +from simple_parsing.utils import ( + DataclassT, + contains_dataclass_type_arg, + is_dataclass_instance, + is_dataclass_type, + is_optional, + is_union, + PossiblyNestedDict, + V, + unflatten_split +) + +logger = logging.getLogger(__name__) + + + +def unflatten_selection_dict( + flattened: Mapping[str, V], keyword: str = "__key__", sep="." +) -> PossiblyNestedDict[str, V]: + """ + This function convert a flattened dict into a nested dict + and it inserts the `keyword` as the selection into the nested dict. + + >>> unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) + {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} + + >>> unflatten_selection_dict({"a": 1, "b": 2}) + {'a': {'__key__': 1}, 'b': {'__key__': 2}} + """ + dc = {} + for k, v in flattened.items(): + if keyword != k and sep not in k and not isinstance(v, dict): + dc[k + sep + keyword] = v + else: + dc[k] = v + logger.debug(dc) + return unflatten_split(dc) + + +@overload +def replace_subgroups( + obj: DataclassT, + changes_dict: dict[str, Any], + selections: dict[str, Key | DataclassT] | None = None, +) -> DataclassT: + ... + + +@overload +def replace_subgroups( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes +) -> DataclassT: + ... + + +def replace_subgroups( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None +): + """ + This function replaces the dataclass of subgroups, union, and optional union. + The `selections` dict can be in flat format or in nested format. + + The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. + """ + keyword = "__key__" + + if selections: + selections = unflatten_selection_dict(selections, keyword) + else: + return obj + + replace_kwargs = {} + for field in dataclasses.fields(obj): + if field.name not in selections: + continue + + field_value = getattr(obj, field.name) + t = get_field_type_from_annotations(obj.__class__, field.name) + + new_value = None + # Replace subgroup is allowed when the type annotation contains dataclass + if contains_dataclass_type_arg(t): + child_selections = selections.pop(field.name) + key = child_selections.pop(keyword, None) + + if is_dataclass_type(key): + field_value = key() + logger.debug("is_dataclass_type") + elif is_dataclass_instance(key): + field_value = copy.deepcopy(key) + logger.debug("is_dataclass_instance") + elif field.metadata.get("subgroups", None): + field_value = field.metadata["subgroups"][key]() + logger.debug("is_subgroups") + elif is_optional(t) and key is None: + field_value = None + logger.debug("key is None") + elif contains_dataclass_type_arg(t) and key is None: + field_value = field.default_factory() + logger.debug(f"nested_dataclass") + else: + raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") + + if child_selections: + new_value = replace_subgroups(field_value, child_selections) + else: + new_value = field_value + else: + raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") + + if not field.init: + raise ValueError(f"Cannot replace value of non-init field {field.name}.") + + replace_kwargs[field.name] = new_value + return dataclasses.replace(obj, **replace_kwargs) \ No newline at end of file diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py index b947f426..9863a8bc 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_subgroups.py @@ -19,10 +19,18 @@ class B: class AorB: a_or_b: A | B = subgroups({"a": A, "b": B}, default_factory=A) +@dataclass(frozen=True) +class FrozenConfig: + a: int = 1 + b: str = "bob" + +odd = FrozenConfig(a=1, b="odd") +even = FrozenConfig(a=2, b="even") @dataclass class Config: subgroup: A | B = subgroups({"a": A, "b": B}, default_factory=A) + frozen_subgroup: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd) optional: A | None = None implicit_optional: A = None union: A | B = field(default_factory=A) @@ -32,6 +40,7 @@ class Config: def test_replace_subgroups(): c = Config() assert replace_subgroups(c, {"subgroup": "b"}) == Config(subgroup=B()) + assert replace_subgroups(c, {"frozen_subgroup": "odd"}) == Config(frozen_subgroup=odd) assert replace_subgroups(c, {"optional": A}) == Config(optional=A()) assert replace_subgroups(c, {"implicit_optional": A}) == Config(implicit_optional=A()) assert replace_subgroups(c, {"union": B}) == Config(union=B()) From 5973d054d4930132c211be06481e3f6779338bc1 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Tue, 4 Apr 2023 23:57:13 -0600 Subject: [PATCH 19/22] rename some variables --- simple_parsing/replace.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index cb35daca..71eb3e0a 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -161,39 +161,40 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | continue field_value = getattr(obj, field.name) - t = get_field_type_from_annotations(obj.__class__, field.name) + field_annotation = get_field_type_from_annotations(obj.__class__, field.name) new_value = None # Replace subgroup is allowed when the type annotation contains dataclass - if not contains_dataclass_type_arg(t): - raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") + if not contains_dataclass_type_arg(field_annotation): + raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {field_annotation}") selection = selections.pop(field.name) if isinstance(selection, dict): - key = selection.pop(keyword, None) + value_of_selection = selection.pop(keyword, None) child_selections = selection else: - key = selection + value_of_selection = selection child_selections = None - if is_dataclass_type(key): - field_value = key() - elif is_dataclass_instance(key): - field_value = copy.deepcopy(key) + if is_dataclass_type(value_of_selection): + field_value = value_of_selection() + elif is_dataclass_instance(value_of_selection): + field_value = copy.deepcopy(value_of_selection) elif field.metadata.get("subgroups", None): - subgroup_selection = field.metadata["subgroups"][key] + assert isinstance(value_of_selection, str) + subgroup_selection = field.metadata["subgroups"][value_of_selection] if is_dataclass_instance(subgroup_selection): # when the subgroup selection is a frozen dataclass instance field_value = subgroup_selection else: # when the subgroup selection is a dataclass type - field_value = field.metadata["subgroups"][key]() - elif is_optional(t) and key is None: + field_value = field.metadata["subgroups"][value_of_selection]() + elif is_optional(field_annotation) and value_of_selection is None: field_value = None - elif contains_dataclass_type_arg(t) and key is None: + elif contains_dataclass_type_arg(field_annotation) and value_of_selection is None: field_value = field.default_factory() else: - raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") + raise ValueError(f"invalid selection key '{value_of_selection}' for field '{field.name}'") if child_selections: new_value = replace_subgroups(field_value, child_selections) From 37814faef64590a29d9485e6e59be910bb53e862 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Wed, 5 Apr 2023 00:04:56 -0600 Subject: [PATCH 20/22] rm replace_subgroups.py --- simple_parsing/replace.py | 16 ++-- simple_parsing/replace_subgroups.py | 126 ---------------------------- test/test_replace_subgroups.py | 3 + 3 files changed, 13 insertions(+), 132 deletions(-) delete mode 100644 simple_parsing/replace_subgroups.py diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index 71eb3e0a..b70a57ea 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -120,7 +120,7 @@ def unflatten_selection_dict( {'a': 1, 'b': 2} """ dc = {} - + existing_top_level_keys = set() conflited_top_level_keys = set() for k, v in flattened.items(): @@ -129,7 +129,7 @@ def unflatten_selection_dict( existing_top_level_keys.add(top_level_key) else: conflited_top_level_keys.add(top_level_key) - + for k, v in flattened.items(): # if keyword != k and sep not in k and not isinstance(v, dict): if k in conflited_top_level_keys: @@ -156,7 +156,7 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | for field in dataclasses.fields(obj): if not field.init: raise ValueError(f"Cannot replace value of non-init field {field.name}.") - + if field.name not in selections: continue @@ -166,8 +166,10 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | new_value = None # Replace subgroup is allowed when the type annotation contains dataclass if not contains_dataclass_type_arg(field_annotation): - raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {field_annotation}") - + raise ValueError( + f"The replaced subgroups contains no dataclass in its annotation {field_annotation}" + ) + selection = selections.pop(field.name) if isinstance(selection, dict): value_of_selection = selection.pop(keyword, None) @@ -194,7 +196,9 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | elif contains_dataclass_type_arg(field_annotation) and value_of_selection is None: field_value = field.default_factory() else: - raise ValueError(f"invalid selection key '{value_of_selection}' for field '{field.name}'") + raise ValueError( + f"invalid selection key '{value_of_selection}' for field '{field.name}'" + ) if child_selections: new_value = replace_subgroups(field_value, child_selections) diff --git a/simple_parsing/replace_subgroups.py b/simple_parsing/replace_subgroups.py deleted file mode 100644 index 67a11566..00000000 --- a/simple_parsing/replace_subgroups.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations - -import copy -import dataclasses -import logging -from typing import Any, overload, Mapping - -from simple_parsing.annotation_utils.get_field_annotations import ( - get_field_type_from_annotations, -) -from simple_parsing.helpers.subgroups import Key -from simple_parsing.utils import ( - DataclassT, - contains_dataclass_type_arg, - is_dataclass_instance, - is_dataclass_type, - is_optional, - is_union, - PossiblyNestedDict, - V, - unflatten_split -) - -logger = logging.getLogger(__name__) - - - -def unflatten_selection_dict( - flattened: Mapping[str, V], keyword: str = "__key__", sep="." -) -> PossiblyNestedDict[str, V]: - """ - This function convert a flattened dict into a nested dict - and it inserts the `keyword` as the selection into the nested dict. - - >>> unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) - {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} - - >>> unflatten_selection_dict({"a": 1, "b": 2}) - {'a': {'__key__': 1}, 'b': {'__key__': 2}} - """ - dc = {} - for k, v in flattened.items(): - if keyword != k and sep not in k and not isinstance(v, dict): - dc[k + sep + keyword] = v - else: - dc[k] = v - logger.debug(dc) - return unflatten_split(dc) - - -@overload -def replace_subgroups( - obj: DataclassT, - changes_dict: dict[str, Any], - selections: dict[str, Key | DataclassT] | None = None, -) -> DataclassT: - ... - - -@overload -def replace_subgroups( - obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None, **changes -) -> DataclassT: - ... - - -def replace_subgroups( - obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None -): - """ - This function replaces the dataclass of subgroups, union, and optional union. - The `selections` dict can be in flat format or in nested format. - - The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. - """ - keyword = "__key__" - - if selections: - selections = unflatten_selection_dict(selections, keyword) - else: - return obj - - replace_kwargs = {} - for field in dataclasses.fields(obj): - if field.name not in selections: - continue - - field_value = getattr(obj, field.name) - t = get_field_type_from_annotations(obj.__class__, field.name) - - new_value = None - # Replace subgroup is allowed when the type annotation contains dataclass - if contains_dataclass_type_arg(t): - child_selections = selections.pop(field.name) - key = child_selections.pop(keyword, None) - - if is_dataclass_type(key): - field_value = key() - logger.debug("is_dataclass_type") - elif is_dataclass_instance(key): - field_value = copy.deepcopy(key) - logger.debug("is_dataclass_instance") - elif field.metadata.get("subgroups", None): - field_value = field.metadata["subgroups"][key]() - logger.debug("is_subgroups") - elif is_optional(t) and key is None: - field_value = None - logger.debug("key is None") - elif contains_dataclass_type_arg(t) and key is None: - field_value = field.default_factory() - logger.debug(f"nested_dataclass") - else: - raise ValueError(f"invalid selection key '{key}' for field '{field.name}'") - - if child_selections: - new_value = replace_subgroups(field_value, child_selections) - else: - new_value = field_value - else: - raise ValueError(f"The replaced subgroups contains no dataclass in its annotation {t}") - - if not field.init: - raise ValueError(f"Cannot replace value of non-init field {field.name}.") - - replace_kwargs[field.name] = new_value - return dataclasses.replace(obj, **replace_kwargs) \ No newline at end of file diff --git a/test/test_replace_subgroups.py b/test/test_replace_subgroups.py index 9863a8bc..6c64a917 100644 --- a/test/test_replace_subgroups.py +++ b/test/test_replace_subgroups.py @@ -19,14 +19,17 @@ class B: class AorB: a_or_b: A | B = subgroups({"a": A, "b": B}, default_factory=A) + @dataclass(frozen=True) class FrozenConfig: a: int = 1 b: str = "bob" + odd = FrozenConfig(a=1, b="odd") even = FrozenConfig(a=2, b="even") + @dataclass class Config: subgroup: A | B = subgroups({"a": A, "b": B}, default_factory=A) From 74b52a2b85d2d12bf5caa3924e9c69cc9607c5a3 Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Thu, 6 Apr 2023 09:46:11 -0600 Subject: [PATCH 21/22] Fix unflatten_selection_dict --- simple_parsing/replace.py | 48 +++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index b70a57ea..e482a639 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -18,7 +18,8 @@ is_optional, unflatten_split, ) - +import logging +logger = logging.getLogger(__name__) @overload def replace(obj: DataclassT, changes_dict: dict[str, Any]) -> DataclassT: @@ -107,7 +108,7 @@ def replace(obj: DataclassT, changes_dict: dict[str, Any] | None = None, **chang def unflatten_selection_dict( - flattened: Mapping[str, V], keyword: str = "__key__", sep="." + flattened: Mapping[str, V], keyword: str = "__key__", sep: str =".", recursive: bool = False ) -> PossiblyNestedDict[str, V]: """ This function convert a flattened dict into a nested dict @@ -115,31 +116,50 @@ def unflatten_selection_dict( >>> unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} + + >>> unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}) + {'lv1': {'__key__': 'a', 'lv2': 'b', 'lv2.lv3': 'c'}} + + >>> unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}, recursive=True) + {'lv1': {'__key__': 'a', 'lv2': {'__key__': 'b', 'lv3': 'c'}}} + + >>> unflatten_selection_dict({'ab_or_cd.c_or_d': 'd'}) + {'ab_or_cd': {'c_or_d': 'd'}} >>> unflatten_selection_dict({"a": 1, "b": 2}) {'a': 1, 'b': 2} """ dc = {} - existing_top_level_keys = set() - conflited_top_level_keys = set() + unflatten_those_top_level_keys = set() for k, v in flattened.items(): - top_level_key = k.split(sep)[0] - if top_level_key not in existing_top_level_keys: - existing_top_level_keys.add(top_level_key) - else: - conflited_top_level_keys.add(top_level_key) + splited_keys = k.split(sep) + if len(splited_keys) >= 2: + unflatten_those_top_level_keys.add(splited_keys[0]) for k, v in flattened.items(): - # if keyword != k and sep not in k and not isinstance(v, dict): - if k in conflited_top_level_keys: - dc[k + sep + keyword] = v + keys = k.split(sep) + top_level_key = keys[0] + rest_keys = keys[1:] + if top_level_key in unflatten_those_top_level_keys: + sub_dc = dc.get(top_level_key, {}) + if len(rest_keys) == 0: + sub_dc[keyword] = v + else: + sub_dc['.'.join(rest_keys)] = v + dc[top_level_key] = sub_dc else: dc[k] = v - return unflatten_split(dc) + + if recursive: + for k in unflatten_those_top_level_keys: + v = dc.pop(k) + unflatten_v = unflatten_selection_dict(v, recursive=recursive) + dc[k] = unflatten_v + return dc -def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None): +def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None) -> DataclassT: """ This function replaces the dataclass of subgroups, union, and optional union. The `selections` dict can be in flat format or in nested format. From 4b249833894ccec260a46f3ef410d6e1068526ca Mon Sep 17 00:00:00 2001 From: lzrpotato Date: Thu, 1 Jun 2023 13:33:09 -0600 Subject: [PATCH 22/22] change as per suggestion --- simple_parsing/replace.py | 114 ++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 55 deletions(-) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index e482a639..3735a336 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -2,6 +2,7 @@ import copy import dataclasses +import logging from typing import Any, Mapping, overload from simple_parsing.annotation_utils.get_field_annotations import ( @@ -18,9 +19,10 @@ is_optional, unflatten_split, ) -import logging + logger = logging.getLogger(__name__) + @overload def replace(obj: DataclassT, changes_dict: dict[str, Any]) -> DataclassT: ... @@ -107,59 +109,9 @@ def replace(obj: DataclassT, changes_dict: dict[str, Any] | None = None, **chang return dataclasses.replace(obj, **replace_kwargs) -def unflatten_selection_dict( - flattened: Mapping[str, V], keyword: str = "__key__", sep: str =".", recursive: bool = False -) -> PossiblyNestedDict[str, V]: - """ - This function convert a flattened dict into a nested dict - and it inserts the `keyword` as the selection into the nested dict. - - >>> unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) - {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} - - >>> unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}) - {'lv1': {'__key__': 'a', 'lv2': 'b', 'lv2.lv3': 'c'}} - - >>> unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}, recursive=True) - {'lv1': {'__key__': 'a', 'lv2': {'__key__': 'b', 'lv3': 'c'}}} - - >>> unflatten_selection_dict({'ab_or_cd.c_or_d': 'd'}) - {'ab_or_cd': {'c_or_d': 'd'}} - - >>> unflatten_selection_dict({"a": 1, "b": 2}) - {'a': 1, 'b': 2} - """ - dc = {} - - unflatten_those_top_level_keys = set() - for k, v in flattened.items(): - splited_keys = k.split(sep) - if len(splited_keys) >= 2: - unflatten_those_top_level_keys.add(splited_keys[0]) - - for k, v in flattened.items(): - keys = k.split(sep) - top_level_key = keys[0] - rest_keys = keys[1:] - if top_level_key in unflatten_those_top_level_keys: - sub_dc = dc.get(top_level_key, {}) - if len(rest_keys) == 0: - sub_dc[keyword] = v - else: - sub_dc['.'.join(rest_keys)] = v - dc[top_level_key] = sub_dc - else: - dc[k] = v - - if recursive: - for k in unflatten_those_top_level_keys: - v = dc.pop(k) - unflatten_v = unflatten_selection_dict(v, recursive=recursive) - dc[k] = unflatten_v - return dc - - -def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None) -> DataclassT: +def replace_subgroups( + obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None +) -> DataclassT: """ This function replaces the dataclass of subgroups, union, and optional union. The `selections` dict can be in flat format or in nested format. @@ -170,7 +122,7 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | if not selections: return obj - selections = unflatten_selection_dict(selections, keyword) + selections = _unflatten_selection_dict(selections, keyword, recursive=False) replace_kwargs = {} for field in dataclasses.fields(obj): @@ -227,3 +179,55 @@ def replace_subgroups(obj: DataclassT, selections: dict[str, Key | DataclassT] | replace_kwargs[field.name] = new_value return dataclasses.replace(obj, **replace_kwargs) + + +def _unflatten_selection_dict( + flattened: Mapping[str, V], keyword: str = "__key__", sep: str = ".", recursive: bool = True +) -> PossiblyNestedDict[str, V]: + """ + This function convert a flattened dict into a nested dict + and it inserts the `keyword` as the selection into the nested dict. + + >>> _unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) + {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} + + >>> _unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}) + {'lv1': {'__key__': 'a', 'lv2': {'__key__': 'b', 'lv3': 'c'}}} + + >>> _unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'}, recursive=False) + {'lv1': {'__key__': 'a', 'lv2': 'b', 'lv2.lv3': 'c'}} + + >>> _unflatten_selection_dict({'ab_or_cd.c_or_d': 'd'}) + {'ab_or_cd': {'c_or_d': 'd'}} + + >>> _unflatten_selection_dict({"a": 1, "b": 2}) + {'a': 1, 'b': 2} + """ + dc = {} + + unflatten_those_top_level_keys = set() + for k, v in flattened.items(): + splited_keys = k.split(sep) + if len(splited_keys) >= 2: + unflatten_those_top_level_keys.add(splited_keys[0]) + + for k, v in flattened.items(): + keys = k.split(sep) + top_level_key = keys[0] + rest_keys = keys[1:] + if top_level_key in unflatten_those_top_level_keys: + sub_dc = dc.get(top_level_key, {}) + if len(rest_keys) == 0: + sub_dc[keyword] = v + else: + sub_dc[".".join(rest_keys)] = v + dc[top_level_key] = sub_dc + else: + dc[k] = v + + if recursive: + for k in unflatten_those_top_level_keys: + v = dc.pop(k) + unflatten_v = _unflatten_selection_dict(v, recursive=recursive) + dc[k] = unflatten_v + return dc