From db72d1aed7288b368af209efae8875d7b2021597 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 10:37:16 -0400
Subject: [PATCH 01/10] Bookmark, initial implementation. Need to test

---
 src/transformers/training_args.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index cdf6325c4b4a..5438bb40a02d 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -22,7 +22,7 @@
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, get_args, get_origin
 
 from huggingface_hub import get_full_repo_name
 from packaging import version
@@ -1393,6 +1393,21 @@ def __post_init__(self):
         if self.disable_tqdm is None:
             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
 
+        # for any and all args that can use a `dict`, check if the value is a `str` and if so, load it as a `dict`
+        dict_fields = []
+        for name, field in self.__dataclass_fields__.items():
+            # `Optional` winds up being a `Union` when digging through the types
+            if get_origin(field.type) == Union:
+                # Check if raw `dict` types are in any of its values
+                if any(arg in (dict, Dict) for arg in get_args(field.type)):
+                    # If found, add to the list of fields that can be loaded as a `dict`
+                    dict_fields.append(name)
+
+        # Next parse in the `dict` fields
+        for name in dict_fields:
+            if isinstance(getattr(self, name), str) and getattr(self, name).startswith("{"):
+                setattr(self, name, json.loads(getattr(self, name)))
+
         if isinstance(self.evaluation_strategy, EvaluationStrategy):
             warnings.warn(
                 "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
@@ -1774,6 +1789,7 @@ def __post_init__(self):
         if not isinstance(self.accelerator_config, (AcceleratorConfig)):
             if self.accelerator_config is None:
                 self.accelerator_config = AcceleratorConfig()
+            # Case: passed in as a str, could be either a `path` or raw dict
             elif isinstance(self.accelerator_config, dict):
                 self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
             else:

From 479814906eace1e20aeabdef7c86ffa63b349031 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 13:03:26 -0400
Subject: [PATCH 02/10] Clean

---
 src/transformers/training_args.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5438bb40a02d..b5d283ca11a1 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1401,8 +1401,8 @@ def __post_init__(self):
                 # Check if raw `dict` types are in any of its values
                 if any(arg in (dict, Dict) for arg in get_args(field.type)):
                     # If found, add to the list of fields that can be loaded as a `dict`
-                    dict_fields.append(name)
-
+                    dict_fields.append(name)
+
         # Next parse in the `dict` fields
         for name in dict_fields:
             if isinstance(getattr(self, name), str) and getattr(self, name).startswith("{"):

From a8e132c15dc6b0846cb939a280ac9a114aa1de53 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 14:10:28 -0400
Subject: [PATCH 03/10] Working fully, woop woop

---
 src/transformers/training_args.py | 58 +++++++++++++++++----------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index b5d283ca11a1..aaf4fc6b52cc 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -22,7 +22,7 @@
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union, get_args, get_origin
+from typing import Any, Dict, List, Optional, Union
 
 from huggingface_hub import get_full_repo_name
 from packaging import version
@@ -173,6 +173,18 @@ class OptimizerNames(ExplicitEnum):
     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
 
 
+# Sometimes users will pass in a `str` repr of a dict in the CLI
+# We need to track what fields those can be. Each time a new arg
+# has a dict type, it must be added to this list
+VALID_DICT_FIELDS = [
+    "accelerator_config",
+    "fsdp_config",
+    "deepspeed",
+    "gradient_checkpointing_kwargs",
+    "lr_scheduler_kwargs",
+]
+
+
 # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
@@ -803,11 +815,11 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
-    lr_scheduler_kwargs: Optional[Dict] = field(
+    lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
         default_factory=dict,
         metadata={
             "help": (
-                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
+                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
             )
         },
     )
@@ -1118,9 +1130,8 @@ class TrainingArguments:
             )
         },
     )
-    # Do not touch this type annotation or it will stop working in CLI
     fsdp_config: Optional[Union[dict, str]] = field(
-        default=None,
+        default_factory=dict,
         metadata={
             "help": (
                 "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
@@ -1137,9 +1148,8 @@ class TrainingArguments:
             )
         },
     )
-    # Do not touch this type annotation or it will stop working in CLI
-    accelerator_config: Optional[str] = field(
-        default=None,
+    accelerator_config: Optional[Union[AcceleratorConfig, dict, str]] = field(
+        default_factory=dict,
         metadata={
             "help": (
                 "Config to be used with the internal Accelerator object initialization. The value is either a "
             )
         },
     )
-    # Do not touch this type annotation or it will stop working in CLI
-    deepspeed: Optional[str] = field(
-        default=None,
+    deepspeed: Optional[Union[dict, str]] = field(
+        default_factory=dict,
         metadata={
             "help": (
                 "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already"
@@ -1252,8 +1261,8 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
    )
-    gradient_checkpointing_kwargs: Optional[dict] = field(
-        default=None,
+    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
+        default_factory=dict,
         metadata={
             "help": "Gradient checkpointing keyword arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
         },
@@ -1380,6 +1389,13 @@ class TrainingArguments:
     )
 
     def __post_init__(self):
+        # Parse in args that could be `dict` sent in from the CLI as a string
+        for field in VALID_DICT_FIELDS:
+            # We only want to do this if the str starts with a bracket to indicate a `dict`
+            # else it's likely a filename if supported
+            if isinstance(getattr(self, field), str) and getattr(self, field).startswith("{"):
+                setattr(self, field, json.loads(getattr(self, field)))
+
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
         # see https://github.com/huggingface/transformers/issues/10628
@@ -1393,21 +1409,6 @@ def __post_init__(self):
         if self.disable_tqdm is None:
             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
 
-        # for any and all args that can use a `dict`, check if the value is a `str` and if so, load it as a `dict`
-        dict_fields = []
-        for name, field in self.__dataclass_fields__.items():
-            # `Optional` winds up being a `Union` when digging through the types
-            if get_origin(field.type) == Union:
-                # Check if raw `dict` types are in any of its values
-                if any(arg in (dict, Dict) for arg in get_args(field.type)):
-                    # If found, add to the list of fields that can be loaded as a `dict`
-                    dict_fields.append(name)
-
-        # Next parse in the `dict` fields
-        for name in dict_fields:
-            if isinstance(getattr(self, name), str) and getattr(self, name).startswith("{"):
-                setattr(self, name, json.loads(getattr(self, name)))
-
         if isinstance(self.evaluation_strategy, EvaluationStrategy):
             warnings.warn(
                 "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
@@ -1789,7 +1790,6 @@ def __post_init__(self):
         if not isinstance(self.accelerator_config, (AcceleratorConfig)):
             if self.accelerator_config is None:
                 self.accelerator_config = AcceleratorConfig()
-            # Case: passed in as a str, could be either a `path` or raw dict
             elif isinstance(self.accelerator_config, dict):
                 self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
             else:

From 5d9a39a8e0396edc4f50199f5190617f2ba56df2 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 14:21:11 -0400
Subject: [PATCH 04/10] I think working version now, testing

---
 src/transformers/training_args.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index aaf4fc6b52cc..c4577d14f6d8 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1131,7 +1131,7 @@ class TrainingArguments:
         },
     )
     fsdp_config: Optional[Union[dict, str]] = field(
-        default_factory=dict,
+        default=None,
         metadata={
             "help": (
                 "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
             )
         },
     )
     accelerator_config: Optional[Union[AcceleratorConfig, dict, str]] = field(
-        default_factory=dict,
+        default=None,
         metadata={
             "help": (
                 "Config to be used with the internal Accelerator object initialization. The value is either a "
             )
         },
     )
     deepspeed: Optional[Union[dict, str]] = field(
-        default_factory=dict,
+        default=None,
         metadata={
             "help": (
                 "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already"
@@ -1262,7 +1262,7 @@ class TrainingArguments:
         },
     )
     gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
-        default_factory=dict,
+        default=None,
         metadata={
             "help": "Gradient checkpointing keyword arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
         },
     )
@@ -1391,10 +1391,15 @@ def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in VALID_DICT_FIELDS:
+            passed_value = getattr(self, field)
             # We only want to do this if the str starts with a bracket to indicate a `dict`
             # else it's likely a filename if supported
-            if isinstance(getattr(self, field), str) and getattr(self, field).startswith("{"):
-                setattr(self, field, json.loads(getattr(self, field)))
+            if isinstance(passed_value, str) and passed_value.startswith("{"):
+                setattr(self, field, json.loads(passed_value))
+            # Since we default to a blank dict, set it to `None` after parsing
+            elif isinstance(passed_value, dict):
+                if passed_value == {}:
+                    setattr(self, field, None)
 
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home

From 16c3bf716be7a9e5de55911ae8913853315fb7e2 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 14:27:22 -0400
Subject: [PATCH 05/10] Fin!

---
 src/transformers/training_args.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index c4577d14f6d8..e203d6a3ff64 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -175,7 +175,8 @@ class OptimizerNames(ExplicitEnum):
 
 # Sometimes users will pass in a `str` repr of a dict in the CLI
 # We need to track what fields those can be. Each time a new arg
-# has a dict type, it must be added to this list
+# has a dict type, it must be added to this list.
+# Important: These should be typed with Optional[Union[dict,str,...]]
 VALID_DICT_FIELDS = [
     "accelerator_config",
     "fsdp_config",
@@ -1148,7 +1149,7 @@ class TrainingArguments:
             )
         },
     )
-    accelerator_config: Optional[Union[AcceleratorConfig, dict, str]] = field(
+    accelerator_config: Optional[Union[dict, str, AcceleratorConfig]] = field(
         default=None,
         metadata={
             "help": (

From a483b4bab2da5dd8b095ff20906ff9cc35aa6e1e Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 14:32:12 -0400
Subject: [PATCH 06/10] rm cast, could keep None

---
 src/transformers/training_args.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index e203d6a3ff64..adc5bae05e65 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1397,10 +1397,6 @@ def __post_init__(self):
             # else it's likely a filename if supported
             if isinstance(passed_value, str) and passed_value.startswith("{"):
                 setattr(self, field, json.loads(passed_value))
-            # Since we default to a blank dict, set it to `None` after parsing
-            elif isinstance(passed_value, dict):
-                if passed_value == {}:
-                    setattr(self, field, None)
 
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home

From 98ce2dc519dbb7ba51da268f62393df5c6c632ce Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 14:59:13 -0400
Subject: [PATCH 07/10] Fix typing issue

---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index adc5bae05e65..8b11cb864777 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1149,7 +1149,7 @@ class TrainingArguments:
             )
         },
     )
-    accelerator_config: Optional[Union[dict, str, AcceleratorConfig]] = field(
+    accelerator_config: Optional[Union[dict, str, "AcceleratorConfig"]] = field(
         default=None,
         metadata={
             "help": (

From 9c69c166ce486477bb4771d5f3d238778fad3220 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 12 Apr 2024 15:22:27 -0400
Subject: [PATCH 08/10] rm typehint

---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 8b11cb864777..e01d0c31c49c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1149,7 +1149,7 @@ class TrainingArguments:
             )
         },
     )
-    accelerator_config: Optional[Union[dict, str, "AcceleratorConfig"]] = field(
+    accelerator_config: Optional[Union[dict, str]] = field(
         default=None,
         metadata={
             "help": (

From 1c73ab10cee47d9d917fe0c0b8d2263f3b9ced52 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 15 Apr 2024 09:03:40 -0400
Subject: [PATCH 09/10] Add test

---
 tests/utils/test_hf_argparser.py | 58 +++++++++++++++++++++++++++++++-
 1 file changed, 57 insertions(+), 1 deletion(-)

diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py
index c0fa748cbfa4..2785a0bb617e 100644
--- a/tests/utils/test_hf_argparser.py
+++ b/tests/utils/test_hf_argparser.py
@@ -22,12 +22,13 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import List, Literal, Optional
+from typing import Dict, List, Literal, Optional, Union, get_args, get_origin
 
 import yaml
 
 from transformers import HfArgumentParser, TrainingArguments
 from transformers.hf_argparser import make_choice_type_function, string_to_bool
+from transformers.training_args import VALID_DICT_FIELDS
 
 
 # Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -405,3 +406,58 @@ def test_parse_yaml(self):
     def test_integration_training_args(self):
         parser = HfArgumentParser(TrainingArguments)
         self.assertIsNotNone(parser)
+
+    def test_valid_dict_annotation(self):
+        """
+        Tests to make sure that `dict` based annotations
+        are correctly made in the `TrainingArguments`.
+
+        If this fails, a type annotation change is
+        needed on a new input
+        """
+        base_list = VALID_DICT_FIELDS.copy()
+        args = TrainingArguments
+
+        # First find any annotations that contain `dict`
+        fields = args.__dataclass_fields__
+
+        raw_dict_fields = []
+        optional_dict_fields = []
+
+        for field in fields.values():
+            # First verify raw dict
+            if field.type in (dict, Dict):
+                raw_dict_fields.append(field)
+            # Next check for `Union` or `Optional`
+            elif get_origin(field.type) == Union:
+                if any(arg in (dict, Dict) for arg in get_args(field.type)):
+                    optional_dict_fields.append(field)
+
+        # First check: anything in `raw_dict_fields` is very bad
+        self.assertEqual(
+            len(raw_dict_fields),
+            0,
+            "Found invalid raw `dict` types in the `TrainingArguments` typings. "
+            "This leads to issues with the CLI. Please turn this into `typing.Optional[Union[dict, str]]`",
+        )
+
+        # Next check raw annotations
+        for field in optional_dict_fields:
+            args = get_args(field.type)
+            # These should be returned as `dict`, `str`, ...
+            # we only care about the first two
+            self.assertIn(args[0], (Dict, dict))
+            self.assertEqual(
+                str(args[1]),
+                "<class 'str'>",
+                f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
+                "but `str` not found. Please fix this.",
+            )
+
+        # Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
+        for field in optional_dict_fields:
+            self.assertIn(
+                field.name,
+                base_list,
+                f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args.VALID_DICT_FIELDS`",
+            )

From f81dcf00b50974550ee660ea411330b01dc31e2c Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Mon, 15 Apr 2024 15:50:19 -0400
Subject: [PATCH 10/10] Add tests and make more rigid

---
 src/transformers/training_args.py | 27 ++++++++++++++++++++++++---
 tests/utils/test_hf_argparser.py  | 17 ++++++++++++++---
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index e01d0c31c49c..e5ac449c6556 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -177,7 +177,7 @@ class OptimizerNames(ExplicitEnum):
 # We need to track what fields those can be. Each time a new arg
 # has a dict type, it must be added to this list.
 # Important: These should be typed with Optional[Union[dict,str,...]]
-VALID_DICT_FIELDS = [
+_VALID_DICT_FIELDS = [
     "accelerator_config",
     "fsdp_config",
     "deepspeed",
@@ -186,6 +186,24 @@ class OptimizerNames(ExplicitEnum):
     "lr_scheduler_kwargs",
 ]
 
 
+
+def _convert_str_dict(passed_value: dict):
+    "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
+    for key, value in passed_value.items():
+        if isinstance(value, dict):
+            passed_value[key] = _convert_str_dict(value)
+        elif isinstance(value, str):
+            # First check for bool and convert
+            if value.lower() in ("true", "false"):
+                passed_value[key] = value.lower() == "true"
+            # Check for digit
+            elif value.isdigit():
+                passed_value[key] = int(value)
+            elif value.replace(".", "", 1).isdigit():
+                passed_value[key] = float(value)
+
+    return passed_value
+
+
 # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
@@ -1391,12 +1409,15 @@ class TrainingArguments:
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
-        for field in VALID_DICT_FIELDS:
+        for field in _VALID_DICT_FIELDS:
             passed_value = getattr(self, field)
             # We only want to do this if the str starts with a bracket to indicate a `dict`
             # else it's likely a filename if supported
             if isinstance(passed_value, str) and passed_value.startswith("{"):
-                setattr(self, field, json.loads(passed_value))
+                loaded_dict = json.loads(passed_value)
+                # Convert str values to types if applicable
+                loaded_dict = _convert_str_dict(loaded_dict)
+                setattr(self, field, loaded_dict)
 
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py
index 2785a0bb617e..87d1858cc6b7 100644
--- a/tests/utils/test_hf_argparser.py
+++ b/tests/utils/test_hf_argparser.py
@@ -28,7 +28,8 @@
 
 from transformers import HfArgumentParser, TrainingArguments
 from transformers.hf_argparser import make_choice_type_function, string_to_bool
-from transformers.training_args import VALID_DICT_FIELDS
+from transformers.testing_utils import require_torch
+from transformers.training_args import _VALID_DICT_FIELDS
 
 
 # Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -415,7 +416,7 @@ def test_valid_dict_annotation(self):
         If this fails, a type annotation change is
         needed on a new input
         """
-        base_list = VALID_DICT_FIELDS.copy()
+        base_list = _VALID_DICT_FIELDS.copy()
         args = TrainingArguments
 
         # First find any annotations that contain `dict`
@@ -459,5 +460,15 @@ def test_valid_dict_annotation(self):
             self.assertIn(
                 field.name,
                 base_list,
-                f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args.VALID_DICT_FIELDS`",
+                f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
             )
+
+    @require_torch
+    def test_valid_dict_input_parsing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(
+                output_dir=tmp_dir,
+                accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
+            )
+            self.assertEqual(args.accelerator_config.split_batches, True)
+            self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
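
The following is an illustration, not part of the patch series: a minimal, runnable sketch of the end-to-end parsing behavior the patches arrive at. The standalone `coerce` helper is a hypothetical stand-in for the `_convert_str_dict` added in PATCH 10, written so the snippet runs outside `transformers`:

    import json

    def coerce(d: dict) -> dict:
        # Mirrors the `_convert_str_dict` logic: recurse into nested dicts and
        # convert "true"/"false", integer, and float strings into real types.
        for key, value in d.items():
            if isinstance(value, dict):
                d[key] = coerce(value)
            elif isinstance(value, str):
                if value.lower() in ("true", "false"):
                    d[key] = value.lower() == "true"
                elif value.isdigit():
                    d[key] = int(value)
                elif value.replace(".", "", 1).isdigit():
                    d[key] = float(value)
        return d

    # What a shell-quoted CLI value looks like by the time __post_init__ sees it:
    raw = '{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": "2"}}'
    if raw.startswith("{"):  # the same guard the patches use
        parsed = coerce(json.loads(raw))
        print(parsed)  # {'split_batches': True, 'gradient_accumulation_kwargs': {'num_steps': 2}}

A bare value such as `ds_config.json` fails the `startswith("{")` check and is left untouched, which is how path-style arguments like `--deepspeed ds_config.json` keep working unchanged.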
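
Likewise, a small sketch (illustration only) of why the `typing` introspection in the PATCH 09 test behaves as its assertions expect: `Optional[X]` is sugar for `Union[X, None]`, so a field annotated `Optional[Union[dict, str]]` collapses into a single `Union` whose arguments can be inspected positionally:

    from typing import Optional, Union, get_args, get_origin

    annotation = Optional[Union[dict, str]]  # how the dict-capable fields are typed

    print(get_origin(annotation))        # typing.Union
    print(get_args(annotation))          # (<class 'dict'>, <class 'str'>, <class 'NoneType'>)
    print(str(get_args(annotation)[1]))  # <class 'str'>

This is also why the test compares `str(args[1])` against the string repr of the `str` type rather than the type itself.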