Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replace_subgroups function #215

Merged
merged 31 commits into from
Jul 10, 2023
Merged

Conversation

zhiruiluo
Copy link
Contributor

@zhiruiluo zhiruiluo commented Feb 4, 2023

This PR is to add the replace_selections function that replaces some values in a dataclass and replaces dataclass type in nested union of dataclasses or subgroups.

simple_parsing.replace_selections function support replacingeplaces dataclass type in nested union of dataclasses or subgroups in addition to simple_parsing.replace #212.

  • signature: replace_selections(obj: DataclassT, changes_dict: dict[str, Any] | None = None, selections: dict[str, Key | DataclassT] | None = None, **changes) -> DataclassT:
    • Compared to simple_replace.replace, this calls replace_selected_dataclass before calling simple_parsing.replace.
  • helper function replace_selected_dataclass
    • This function replaces the dataclass of subgroups, union, and optional union. The selections is in flat format, e.g. {"ab_or_cd": 'cd', "ab_or_cd.c_or_d": 'd'}
    • The values of selections can be Key of subgroups, dataclass type, and dataclass instance.

Open question:

  • Do we combine the simple_parsing.replace_selections and simple_parsing.replace?

Examples

    >>> import dataclasses
    >>> from simple_parsing import replace_selections, subgroups
    >>> from typing import Union
    >>> @dataclasses.dataclass
    ... class A:
    ...     a: int = 0
    >>> @dataclasses.dataclass
    ... class B:
    ...     b: str = "b"
    >>> @dataclasses.dataclass
    ... class Config:
    ...     a_or_b: Union[A, B] = subgroups({'a': A, 'b': B}, default_factory=A)
    ...     a_or_b_union: Union[A, B] = dataclasses.field(default_factory=A)
    ...     a_optional: Union[A, None] = None
    
    >>> base_config = Config(a_or_b=A(a=1))
    # Replace subgroups field by subgroup `Key`, dataclass type, or dataclass instance
    >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": "b"})
    Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None)
    >>> replace_selections(base_config, {"a_or_b.b": "bob"}, {"a_or_b": B})
    Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None)
    >>> replace_selections(base_config, {}, {"a_or_b": B(b="bob")})
    Config(a_or_b=B(b='bob'), a_or_b_union=A(a=0), a_optional=None)
    
    # Replace union of dataclasses and optional dataclass
    >>> replace_selections(base_config, {"a_or_b_union.b": "bob"}, {"a_or_b_union": B})
    Config(a_or_b=A(a=1), a_or_b_union=B(b='bob'), a_optional=None)
    >>> replace_selections(base_config, {"a_optional.a": 10}, {"a_optional": A})
    Config(a_or_b=A(a=1), a_or_b_union=A(a=0), a_optional=A(a=10))

@zhiruiluo zhiruiluo marked this pull request as draft February 4, 2023 18:21
@zhiruiluo
Copy link
Contributor Author

zhiruiluo commented Feb 4, 2023

WIP

@zhiruiluo zhiruiluo changed the title Add replace_subgroups function Add replace_selections function Feb 5, 2023
@zhiruiluo zhiruiluo marked this pull request as ready for review February 5, 2023 06:02
Copy link
Owner

@lebrice lebrice left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zhiruiluo, interesting PR.

Left some comments. let me know what you think.

simple_parsing/replace_selections.py Outdated Show resolved Hide resolved
test/test_replace_selections.py Outdated Show resolved Hide resolved
simple_parsing/utils.py Outdated Show resolved Hide resolved
simple_parsing/replace_selections.py Outdated Show resolved Hide resolved
@zhiruiluo
Copy link
Contributor Author

Hi @lebrice,
Sorry for letting you wait too long.

Here is the updating:

  • keep only replace_subgroups function
  • simplify test cases
  • move replace_subgroups function into replace.py file
  • move the unflatten function out of utils.py

@zhiruiluo zhiruiluo requested a review from lebrice March 27, 2023 06:59
Copy link
Owner

@lebrice lebrice left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zhiruiluo , this is a lot clearer!

There are just a few small things to tweak, but this looks pretty good!

simple_parsing/replace.py Outdated Show resolved Hide resolved
simple_parsing/replace.py Outdated Show resolved Hide resolved
simple_parsing/replace.py Outdated Show resolved Hide resolved
simple_parsing/replace.py Outdated Show resolved Hide resolved
simple_parsing/replace.py Outdated Show resolved Hide resolved
@zhiruiluo zhiruiluo requested a review from lebrice April 5, 2023 05:57
@zhiruiluo zhiruiluo changed the title Add replace_selections function Add replace_subgroups function Apr 5, 2023
@zhiruiluo
Copy link
Contributor Author

Hi @lebrice, Here are some updates so far:

  • Fixed the bug in unflatten_selection_dict when no overlapping of keys and added a recursive flag.
  • reordered some logics for readability
  • add test case for the frozen dataclass to work

I would appreciate any further feedback or suggestions for improvement that you may have. Thank you.

Copy link
Owner

@lebrice lebrice left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zhiruiluo , and sorry it took so long to review this.

This looks good to me. I only have a tiny comment on the naming of one of the functions, but appart from that, this is good to merge!

Thanks again!

Comment on lines 120 to 121
>>> unflatten_selection_dict({'lv1': 'a', 'lv1.lv2': 'b', 'lv1.lv2.lv3': 'c'})
{'lv1': {'__key__': 'a', 'lv2': 'b', 'lv2.lv3': 'c'}}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest making recursive=True by default. The non-recursive case seems a bit strange to me.
Also, is this function expected to be used by users? If not, I'd prefix it with an underscore (_unflatten_selection_dict) to signal that it is an internal function only used in this module.
I would also place it below other "public" functions like replace and replace_subgroups.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @lebrice, I have made changes for all your suggestions.
But the test failed due to the Literal issue #259. Once #260 merged, this PR will pass the test.

@lebrice lebrice merged commit d3b704a into lebrice:master Jul 10, 2023
@zhiruiluo zhiruiluo deleted the add_replace_subgroups branch July 10, 2023 22:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants