-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow for str versions of dicts based on typing #30227
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@amyeroberts figured out where to add a test, and verified that all 3 conditions raise their respective errors :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for enabling this!
src/transformers/training_args.py
Outdated
@@ -1380,6 +1390,14 @@ class TrainingArguments: | |||
) | |||
|
|||
def __post_init__(self): | |||
# Parse in args that could be `dict` sent in from the CLI as a string | |||
for field in VALID_DICT_FIELDS: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like just checking for a few possible args!
@@ -405,3 +406,58 @@ def test_parse_yaml(self): | |||
def test_integration_training_args(self): | |||
parser = HfArgumentParser(TrainingArguments) | |||
self.assertIsNotNone(parser) | |||
|
|||
def test_valid_dict_annotation(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know on offline discussion I said we probably don't need a test to check the parsing. Seeing the implementation, i.e. TrainingArguments
can be created with a field as a string representation of a dict, and not having to include the CLI, I think we can add a simple for at least one of the fields in VALID_DICT_FIELDS
e.g. something along the lines of:
def test_valid_dict_input_parsing(self):
args = TrainingArguments(
field_name='{"key": value}'
)
# Or however it's assigned in the args
self.assertEqual(args.field_name, {key: value})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. This also made me notice that we cast int and bools as str
still, so added a helper for this (and does so to avoid literal_eval
, which can be exploited)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @muellerzr for adding this super useful feature to be able to pass string representation of dict as cmd arguments! 🚀
This should be very useful for the CLI support that @younesbelkada worked wrt TRL.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Thanks for enabling this and adding these tests ❤️
elif isinstance(value, str): | ||
# First check for bool and convert | ||
if value.lower() in ("true", "false"): | ||
passed_value[key] = value.lower() == "true" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice :)
accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}', | ||
) | ||
self.assertEqual(args.accelerator_config.split_batches, True) | ||
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
What does this PR do?
This PR adds support for passing in a string dictionary to arguments that allow a dict in the argparser. For example:
(
test.py
is just a small script reading the args from the parsers):This also fixes issues with typing of not being able to state that
deepspeed
andfsdp_config
can't have typedict
. Turns out it's a (very) fun setting with the argparser wrapper we have for these. Types should be declared as (for this):This is very important to get working properly!
Fixes #30204
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@amyeroberts @pacman100
Technically everything will go brr if passes, since any call to the CLI/parser happens usually with examples etc. So new tests added, but did manually verify input/output from a basic script locally