Allow for str versions of dicts based on typing #30227

Merged
muellerzr merged 10 commits into main from muellerzr-allow-dict-parsing
Apr 16, 2024


muellerzr
Contributor

What does this PR do?

This PR adds support for passing a dict as a string to arguments that accept a dict in the argparser. For example:

python test.py --output_dir testdir --accelerator_config='{"dispatch_batches":"False"}'

(test.py is just a small script reading the args from the parsers):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser((TrainingArguments))

training_args = parser.parse_args_into_dataclasses()

print(training_args)
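
For readers without transformers installed, here is a minimal sketch of the same idea using only the standard library. It is not the code from this PR: the `parse_cli` function is a hypothetical name, and the "decode only strings starting with `{`" rule is an assumption made for illustration.

```python
import argparse
import json


def parse_cli(argv):
    """Parse a CLI where one option may carry a dict written as a JSON string."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    # argparse only ever sees a plain string here; decoding happens afterwards.
    parser.add_argument("--accelerator_config", type=str, default=None)
    args = parser.parse_args(argv)

    # Assumption: a value starting with "{" is meant to be a dict.
    value = args.accelerator_config
    if isinstance(value, str) and value.startswith("{"):
        args.accelerator_config = json.loads(value)
    return args


args = parse_cli(
    ["--output_dir", "testdir", '--accelerator_config={"dispatch_batches":"False"}']
)
print(args.accelerator_config)  # {'dispatch_batches': 'False'}
```

Note that the inner value is still the string `"False"` after `json.loads`, because it was quoted on the command line.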

This also fixes a typing issue: previously we could not declare that deepspeed and fsdp_config can have type dict. Turns out it's a (very) fun setting with the argparser wrapper we have for these. For this to work, types should be declared as:

Optional[Union[dict,str,...]]

This is very important to get working properly!

Fixes #30204

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts @pacman100

Technically everything will go brr if this passes, since calls to the CLI/parser usually happen through the examples etc. So no new tests added, but I did manually verify input/output from a basic script locally.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr
Contributor Author

@amyeroberts figured out where to add a test, and verified that all 3 conditions raise their respective errors :)

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for enabling this!

src/transformers/training_args.py (outdated)
@@ -1380,6 +1390,14 @@ class TrainingArguments:
    )

    def __post_init__(self):
        # Parse in args that could be `dict` sent in from the CLI as a string
        for field in VALID_DICT_FIELDS:
Collaborator

I like just checking for a few possible args!
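
A minimal sketch of the "check only a few known fields" approach, written against a toy dataclass rather than `TrainingArguments`. `DEMO_DICT_FIELDS` and `DemoArguments` are hypothetical stand-ins, and leaving non-`{` strings untouched (for instance, a value that might be a path to a config file) is an assumption, not something confirmed by the hunk above.

```python
import json
from dataclasses import dataclass
from typing import Optional, Union

# Hypothetical stand-in for VALID_DICT_FIELDS: the only fields we try to decode.
DEMO_DICT_FIELDS = ("accelerator_config", "fsdp_config")


@dataclass
class DemoArguments:
    output_dir: str
    # Declared as dict-or-str so that a string coming from the CLI is accepted.
    accelerator_config: Optional[Union[dict, str]] = None
    fsdp_config: Optional[Union[dict, str]] = None

    def __post_init__(self):
        for name in DEMO_DICT_FIELDS:
            value = getattr(self, name)
            # Assumption: only strings that look like a JSON object are decoded;
            # any other string is passed through unchanged.
            if isinstance(value, str) and value.startswith("{"):
                setattr(self, name, json.loads(value))


demo = DemoArguments(
    output_dir="testdir",
    accelerator_config='{"split_batches": true}',
    fsdp_config="fsdp.json",
)
```

Restricting the loop to an explicit tuple of names keeps ordinary string fields such as `output_dir` from ever being run through the JSON decoder.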

@@ -405,3 +406,58 @@ def test_parse_yaml(self):
    def test_integration_training_args(self):
        parser = HfArgumentParser(TrainingArguments)
        self.assertIsNotNone(parser)

    def test_valid_dict_annotation(self):
Collaborator

I know in our offline discussion I said we probably don't need a test to check the parsing. Seeing the implementation, i.e. TrainingArguments can be created with a field given as a string representation of a dict, without having to involve the CLI, I think we can add a simple test for at least one of the fields in VALID_DICT_FIELDS

e.g. something along the lines of:

def test_valid_dict_input_parsing(self):
    args = TrainingArguments(
        field_name='{"key": value}'
    )
    # Or however it's assigned in the args
    self.assertEqual(args.field_name, {key: value})

Contributor Author

Done. This also made me notice that we still cast ints and bools as str, so I added a helper for this (it does the conversion without literal_eval, which can be exploited)
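
A rough sketch of what such a helper could look like. The function name `convert_str_values` is hypothetical and this is not the helper merged in the PR; the bool branch follows the "true"/"false" check quoted later in this review, while the integer handling via `str.isdecimal()` is an assumption.

```python
def convert_str_values(passed_value: dict) -> dict:
    """Recursively turn "true"/"false" and digit-only strings into bool/int.

    Only plain string comparisons and int() are used, so no input is ever
    evaluated as code.
    """
    for key, value in passed_value.items():
        if isinstance(value, dict):
            # Nested dicts are converted in place as well.
            passed_value[key] = convert_str_values(value)
        elif isinstance(value, str):
            if value.lower() in ("true", "false"):
                passed_value[key] = value.lower() == "true"
            elif value.isdecimal():
                # Assumption: only non-negative integers are recognised here.
                passed_value[key] = int(value)
    return passed_value


converted = convert_str_values(
    {
        "split_batches": "True",
        "gradient_accumulation_kwargs": {"num_steps": "2"},
        "name": "abc",
    }
)
```

Strings that are neither booleans nor digits, such as `"abc"` above, are left as they are; negative numbers and floats would need extra branches.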

Contributor

@pacman100 pacman100 left a comment

Thank you @muellerzr for adding this super useful feature: being able to pass the string representation of a dict as a command-line argument! 🚀

This should be very useful for the CLI support that @younesbelkada worked on for TRL.

Collaborator

@amyeroberts amyeroberts left a comment

Great! Thanks for enabling this and adding these tests ❤️

elif isinstance(value, str):
    # First check for bool and convert
    if value.lower() in ("true", "false"):
        passed_value[key] = value.lower() == "true"
Collaborator

Nice :)

    accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
)
self.assertEqual(args.accelerator_config.split_batches, True)
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
Collaborator

Nice!

@muellerzr muellerzr merged commit 487505f into main Apr 16, 2024
21 checks passed
@muellerzr muellerzr deleted the muellerzr-allow-dict-parsing branch April 16, 2024 12:15

Successfully merging this pull request may close these issues.

Enhance HfArgumentParser with Dict command-line parser
4 participants