From 69e2d674f8f8d6c00561a16b369a4a4de3ad583c Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 18 Oct 2023 10:47:52 +0300 Subject: [PATCH 01/11] Convert from recipe --- src/super_gradients/convert_from_recipe.py | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/super_gradients/convert_from_recipe.py diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_from_recipe.py new file mode 100644 index 0000000000..dac73184e1 --- /dev/null +++ b/src/super_gradients/convert_from_recipe.py @@ -0,0 +1,32 @@ +""" +Entry point for converting recipe file to self-contained train.py file. + +General use: python -m super_gradients.convert_from_recipe --config-name="DESIRED_RECIPE". +For recipe's specific instructions and details refer to the recipe's configuration file in the recipes directory. +""" + +import hydra +from omegaconf import DictConfig + +from super_gradients import init_trainer + + +def convert_from_recipe(cfg: DictConfig, output_script_path: str): + content = [] + + with open(output_script_path, "w") as f: + f.writelines(content) + + +@hydra.main(config_path="recipes", version_base="1.2") +def _main(cfg: DictConfig) -> None: + return convert_from_recipe(cfg, output_script_path="exported_train.py") + + +def main() -> None: + init_trainer() # `init_trainer` needs to be called before `@hydra.main` + _main() + + +if __name__ == "__main__": + main() From 4a144a73e93e2e5f72a031d3f5b1d8db7d4f8898 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 24 Oct 2023 14:06:28 +0300 Subject: [PATCH 02/11] Draft implementation of convert_from_recipe.py --- src/super_gradients/convert_from_recipe.py | 166 +++++++++++++++++++-- 1 file changed, 156 insertions(+), 10 deletions(-) diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_from_recipe.py index dac73184e1..1829a9da21 100644 --- a/src/super_gradients/convert_from_recipe.py +++ b/src/super_gradients/convert_from_recipe.py @@ -4,28 +4,174 @@ General use: python -m super_gradients.convert_from_recipe --config-name="DESIRED_RECIPE". For recipe's specific instructions and details refer to the recipe's configuration file in the recipes directory. 
""" +import argparse +import collections +import os.path +from typing import Tuple, Mapping, Dict +import black import hydra -from omegaconf import DictConfig +import pkg_resources +from hydra.core.global_hydra import GlobalHydra +from omegaconf import DictConfig, OmegaConf, ListConfig -from super_gradients import init_trainer +from super_gradients import Trainer +from super_gradients.common import MultiGPUMode +from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers +from super_gradients.common.environment.path_utils import normalize_path +from super_gradients.training.utils import get_param + +logger = get_logger(__name__) + + +def recursively_walk_and_extract_hydra_targets(cfg, objects=None, prefix=None) -> Tuple[DictConfig, Dict[str, Mapping]]: + if objects is None: + objects = collections.OrderedDict() + if prefix is None: + prefix = "" + + if isinstance(cfg, DictConfig): + for key, value in cfg.items(): + value, objects = recursively_walk_and_extract_hydra_targets(value, objects, prefix=f"{prefix}_{key}") + cfg[key] = value + + if "_target_" in cfg: + target_class = cfg["_target_"] + target_params = dict([(k, v) for k, v in cfg.items() if k != "_target_"]) + object_name = f"{prefix}".replace(".", "_").lower() + objects[object_name] = (target_class, target_params) + cfg = object_name + + elif isinstance(cfg, ListConfig): + for index, item in enumerate(cfg): + item, objects = recursively_walk_and_extract_hydra_targets(item, objects, prefix=f"{prefix}_{index}") + cfg[index] = item + else: + print(f"Skipping {cfg}") + pass + return cfg, objects def convert_from_recipe(cfg: DictConfig, output_script_path: str): - content = [] + cfg = Trainer._trigger_cfg_modifying_callbacks(cfg) - with open(output_script_path, "w") as f: - f.writelines(content) + device = get_param(cfg, "device") + multi_gpu = get_param(cfg, "multi_gpu") + + if multi_gpu is False: + multi_gpu = MultiGPUMode.OFF + num_gpus = get_param(cfg, "num_gpus") + print(device, multi_gpu, num_gpus) + + train_dataloader = get_param(cfg, "train_dataloader") + train_dataset_params = OmegaConf.to_container(cfg.dataset_params.train_dataset_params, resolve=True) + train_dataloader_params = OmegaConf.to_container(cfg.dataset_params.train_dataloader_params, resolve=True) + + val_dataloader = get_param(cfg, "val_dataloader") + val_dataset_params = OmegaConf.to_container(cfg.dataset_params.val_dataset_params, resolve=True) + val_dataloader_params = OmegaConf.to_container(cfg.dataset_params.val_dataloader_params, resolve=True) + + num_classes = cfg.arch_params.num_classes + arch_params = OmegaConf.to_container(cfg.arch_params, resolve=True) + + training_hyperparams, hydra_instantiated_objects = recursively_walk_and_extract_hydra_targets(cfg.training_hyperparams) + print(hydra_instantiated_objects) + + checkpoint_num_classes = get_param(cfg.checkpoint_params, "checkpoint_num_classes") + content = f""" +import super_gradients +from super_gradients import init_trainer, Trainer +from super_gradients.training.utils.distributed_training_utils import setup_device +from super_gradients.training import models, dataloaders +from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad +import numpy as np + +def main(): + init_trainer() + setup_device(device={device}, multi_gpu="{multi_gpu}", num_gpus={num_gpus}) + + trainer = Trainer(experiment_name="{cfg.experiment_name}", ckpt_root_dir="{cfg.ckpt_root_dir}") + + num_classes = 
{num_classes} + arch_params = {arch_params} + model = models.get( + model_name="{cfg.architecture}", + num_classes=num_classes, + arch_params=arch_params, + strict_load={hydra.utils.instantiate(cfg.checkpoint_params.strict_load)}, + pretrained_weights={cfg.checkpoint_params.pretrained_weights}, + checkpoint_path={cfg.checkpoint_params.checkpoint_path}, + load_backbone={cfg.checkpoint_params.load_backbone}, + checkpoint_num_classes={checkpoint_num_classes}, + ) -@hydra.main(config_path="recipes", version_base="1.2") -def _main(cfg: DictConfig) -> None: - return convert_from_recipe(cfg, output_script_path="exported_train.py") + train_dataloader = dataloaders.get( + name="{train_dataloader}", + dataset_params={train_dataset_params}, + dataloader_params={train_dataloader_params}, + ) + + val_dataloader = dataloaders.get( + name="{val_dataloader}", + dataset_params={val_dataset_params}, + dataloader_params={val_dataloader_params}, + ) + +""" + for name, (class_name, class_params) in hydra_instantiated_objects.items(): + class_params_str = [] + for k, v in class_params.items(): + class_params_str.append(f"{k}={v}") + class_params_str = ",".join(class_params_str) + content += f" {name} = {class_name}({class_params_str})\n\n" + + content += f""" + + training_hyperparams = {training_hyperparams} + + # TRAIN + result = trainer.train( + model=model, + train_loader=train_dataloader, + valid_loader=val_dataloader, + training_params=training_hyperparams, + ) + + print(result) + +if __name__ == "__main__": + main() +""" + # Remove quotes from dict values to reference them as variables + for key in hydra_instantiated_objects.keys(): + key_to_search = f"'{key}'" + key_to_replace_with = f"{key}" + print(key_to_search, key_to_replace_with, key_to_search in content) + content = content.replace(key_to_search, key_to_replace_with) + + with open(output_script_path, "w") as f: + content = black.format_str(content, mode=black.FileMode(line_length=120)) + f.write(content) def main() -> None: - init_trainer() # `init_trainer` needs to be called before `@hydra.main` - _main() + parser = argparse.ArgumentParser() + parser.add_argument("config_name", type=str, help=".yaml filename") + parser.add_argument("--config_dir", type=str, default=pkg_resources.resource_filename("super_gradients.recipes", ""), help="The config directory path") + parser.add_argument("--save_path", type=str, default=None, help="Destination path to the output .py file") + args = parser.parse_args() + + save_path = args.save_path or os.path.splitext(os.path.basename(args.config_name))[0] + ".py" + logger.info(f"Saving recipe script to {save_path}") + + register_hydra_resolvers() + GlobalHydra.instance().clear() + with hydra.initialize_config_dir(config_dir=normalize_path(args.config_dir), version_base="1.2"): + cfg = hydra.compose(config_name=args.config_name) + + convert_from_recipe(cfg, save_path) if __name__ == "__main__": From 0ff6c310727df011a72f1e28dcc5d8061accfa4d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 24 Oct 2023 14:21:41 +0300 Subject: [PATCH 03/11] Added global config resolve call --- src/super_gradients/convert_from_recipe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_from_recipe.py index 1829a9da21..a3847b4061 100644 --- a/src/super_gradients/convert_from_recipe.py +++ b/src/super_gradients/convert_from_recipe.py @@ -55,6 +55,7 @@ def recursively_walk_and_extract_hydra_targets(cfg, objects=None, prefix=None) - def 
convert_from_recipe(cfg: DictConfig, output_script_path: str): cfg = Trainer._trigger_cfg_modifying_callbacks(cfg) + OmegaConf.resolve(cfg) device = get_param(cfg, "device") multi_gpu = get_param(cfg, "multi_gpu") From f94c9e3191d4dca3705dda3fb4f834dbb4004556 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 24 Oct 2023 14:35:29 +0300 Subject: [PATCH 04/11] Install black on demand --- src/super_gradients/convert_from_recipe.py | 29 +++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_from_recipe.py index a3847b4061..52a8b1e8ff 100644 --- a/src/super_gradients/convert_from_recipe.py +++ b/src/super_gradients/convert_from_recipe.py @@ -9,7 +9,6 @@ import os.path from typing import Tuple, Mapping, Dict -import black import hydra import pkg_resources from hydra.core.global_hydra import GlobalHydra @@ -25,6 +24,34 @@ logger = get_logger(__name__) +def import_black_or_fail_with_instructions(): + try: + import black + + return black + except ImportError: + raise ImportError( + "Black is not installed. Please install it with `pip install black` and try again. " + "If you are using a virtual environment, make sure it is activated." + ) + + +def import_black_or_install(): + try: + import black + + return black + except ImportError: + import pip + + pip.main(["install", "black==22.10.0"]) + + return import_black_or_fail_with_instructions() + + +black = import_black_or_install() + + def recursively_walk_and_extract_hydra_targets(cfg, objects=None, prefix=None) -> Tuple[DictConfig, Dict[str, Mapping]]: if objects is None: objects = collections.OrderedDict() From 420aef403ba86d0ad36bfb8ec152d1b718844dce Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 25 Oct 2023 11:46:39 +0300 Subject: [PATCH 05/11] Added unit test to test all recipes --- src/super_gradients/convert_from_recipe.py | 31 +++--- tests/deci_core_unit_test_suite_runner.py | 2 + tests/unit_tests/convert_from_recipe_tests.py | 100 ++++++++++++++++++ 3 files changed, 120 insertions(+), 13 deletions(-) create mode 100644 tests/unit_tests/convert_from_recipe_tests.py diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_from_recipe.py index 52a8b1e8ff..c0d97819b1 100644 --- a/src/super_gradients/convert_from_recipe.py +++ b/src/super_gradients/convert_from_recipe.py @@ -7,7 +7,8 @@ import argparse import collections import os.path -from typing import Tuple, Mapping, Dict +import pathlib +from typing import Tuple, Mapping, Dict, Union import hydra import pkg_resources @@ -75,12 +76,20 @@ def recursively_walk_and_extract_hydra_targets(cfg, objects=None, prefix=None) - item, objects = recursively_walk_and_extract_hydra_targets(item, objects, prefix=f"{prefix}_{index}") cfg[index] = item else: - print(f"Skipping {cfg}") pass return cfg, objects -def convert_from_recipe(cfg: DictConfig, output_script_path: str): +def convert_from_recipe(config_name: Union[str, pathlib.Path], config_dir: Union[str, pathlib.Path], output_script_path: Union[str, pathlib.Path]): + config_name = str(config_name) + config_dir = str(config_dir) + output_script_path = str(output_script_path) + + register_hydra_resolvers() + GlobalHydra.instance().clear() + with hydra.initialize_config_dir(config_dir=normalize_path(config_dir), version_base="1.2"): + cfg = hydra.compose(config_name=config_name) + cfg = Trainer._trigger_cfg_modifying_callbacks(cfg) OmegaConf.resolve(cfg) @@ -90,7 +99,6 @@ def convert_from_recipe(cfg: 
DictConfig, output_script_path: str): if multi_gpu is False: multi_gpu = MultiGPUMode.OFF num_gpus = get_param(cfg, "num_gpus") - print(device, multi_gpu, num_gpus) train_dataloader = get_param(cfg, "train_dataloader") train_dataset_params = OmegaConf.to_container(cfg.dataset_params.train_dataset_params, resolve=True) @@ -103,8 +111,11 @@ def convert_from_recipe(cfg: DictConfig, output_script_path: str): num_classes = cfg.arch_params.num_classes arch_params = OmegaConf.to_container(cfg.arch_params, resolve=True) + strict_load = cfg.checkpoint_params.strict_load + if isinstance(strict_load, Mapping) and "_target_" in strict_load: + strict_load = hydra.utils.instantiate(strict_load) + training_hyperparams, hydra_instantiated_objects = recursively_walk_and_extract_hydra_targets(cfg.training_hyperparams) - print(hydra_instantiated_objects) checkpoint_num_classes = get_param(cfg.checkpoint_params, "checkpoint_num_classes") content = f""" @@ -128,7 +139,7 @@ def main(): model_name="{cfg.architecture}", num_classes=num_classes, arch_params=arch_params, - strict_load={hydra.utils.instantiate(cfg.checkpoint_params.strict_load)}, + strict_load={strict_load}, pretrained_weights={cfg.checkpoint_params.pretrained_weights}, checkpoint_path={cfg.checkpoint_params.checkpoint_path}, load_backbone={cfg.checkpoint_params.load_backbone}, @@ -176,7 +187,6 @@ def main(): for key in hydra_instantiated_objects.keys(): key_to_search = f"'{key}'" key_to_replace_with = f"{key}" - print(key_to_search, key_to_replace_with, key_to_search in content) content = content.replace(key_to_search, key_to_replace_with) with open(output_script_path, "w") as f: @@ -194,12 +204,7 @@ def main() -> None: save_path = args.save_path or os.path.splitext(os.path.basename(args.config_name))[0] + ".py" logger.info(f"Saving recipe script to {save_path}") - register_hydra_resolvers() - GlobalHydra.instance().clear() - with hydra.initialize_config_dir(config_dir=normalize_path(args.config_dir), version_base="1.2"): - cfg = hydra.compose(config_name=args.config_name) - - convert_from_recipe(cfg, save_path) + convert_from_recipe(args.config_name, args.config_dir, save_path) if __name__ == "__main__": diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index b9abce904e..94df1a1484 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -26,6 +26,7 @@ TestDeprecationDecorator, ) from tests.end_to_end_tests import TestTrainer +from tests.unit_tests.convert_from_recipe_tests import TestConvertFromRecipe from tests.unit_tests.detection_utils_test import TestDetectionUtils from tests.unit_tests.detection_dataset_test import DetectionDatasetTest, TestParseYoloLabelFile from tests.unit_tests.export_detection_model_test import TestDetectionModelExport @@ -162,6 +163,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationModelExport)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(YoloNASPoseTests)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationSampleTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertFromRecipe)) def _add_modules_to_end_to_end_tests_suite(self): """ diff --git a/tests/unit_tests/convert_from_recipe_tests.py b/tests/unit_tests/convert_from_recipe_tests.py new file mode 100644 index 0000000000..7479e613e0 --- /dev/null +++ b/tests/unit_tests/convert_from_recipe_tests.py @@ 
-0,0 +1,100 @@ +import ast +import tempfile + +import pkg_resources +import unittest + +from super_gradients.convert_from_recipe import convert_from_recipe +from pathlib import Path + + +class TestConvertFromRecipe(unittest.TestCase): + def setUp(self) -> None: + self.recipes_dir: Path = Path(pkg_resources.resource_filename("super_gradients.recipes", "")) + self.recipes_that_should_work = [ + "cifar10_resnet.yaml", + "cityscapes_al_ddrnet.yaml", + "cityscapes_ddrnet.yaml", + "cityscapes_pplite_seg50.yaml", + "cityscapes_pplite_seg75.yaml", + "cityscapes_regseg48.yaml", + "cityscapes_segformer_b0.yaml", + "cityscapes_segformer_b1.yaml", + "cityscapes_segformer_b2.yaml", + "cityscapes_segformer_b3.yaml", + "cityscapes_segformer_b4.yaml", + "cityscapes_segformer_b5.yaml", + "cityscapes_stdc_base.yaml", + "cityscapes_stdc_seg50.yaml", + "cityscapes_stdc_seg75.yaml", + "coco2017_pose_dekr_rescoring.yaml", + "coco2017_pose_dekr_w32_no_dc.yaml", + "coco2017_ppyoloe_l.yaml", + "coco2017_ppyoloe_m.yaml", + "coco2017_ppyoloe_s.yaml", + "coco2017_ppyoloe_x.yaml", + "coco2017_ssd_lite_mobilenet_v2.yaml", + "coco2017_yolo_nas_s.yaml", + "coco2017_yolox.yaml", + "coco_segmentation_shelfnet_lw.yaml", + "imagenet_efficientnet.yaml", + "imagenet_mobilenetv2.yaml", + "imagenet_mobilenetv3_large.yaml", + "imagenet_mobilenetv3_small.yaml", + "imagenet_regnetY.yaml", + "imagenet_repvgg.yaml", + "imagenet_resnet50.yaml", + "imagenet_vit_base.yaml", + "imagenet_vit_large.yaml", + "supervisely_unet.yaml", + "user_recipe_mnist_as_external_dataset_example.yaml", + "user_recipe_mnist_example.yaml", + ] + + self.recipes_that_does_not_work = [ + "cityscapes_kd_base.yaml", # KD recipe not supported + "imagenet_resnet50_kd.yaml", # KD recipe not supported + "imagenet_mobilenetv3_base.yaml", # Base recipe (not complete) for other MobileNetV3 recipes + "cityscapes_segformer.yaml", # Base recipe (not complete) for other SegFormer recipes + "roboflow_ppyoloe.yaml", # Require explicit command line arguments + "roboflow_yolo_nas_m.yaml", # Require explicit command line arguments + "roboflow_yolo_nas_s.yaml", # Require explicit command line arguments + "roboflow_yolo_nas_s_qat.yaml", # Require explicit command line arguments + "roboflow_yolox.yaml", # Require explicit command line arguments + "variable_setup.yaml", # Not a recipe + "script_generate_rescoring_data_dekr_coco2017.yaml", # Not a recipe + ] + + def test_all_recipes_are_tested(self): + present_recipes = set(recipe.name for recipe in self.recipes_dir.glob("*.yaml")) + known_recipes = set(self.recipes_that_should_work + self.recipes_that_does_not_work) + new_recipes = present_recipes - known_recipes + removed_recipes = known_recipes - present_recipes + if len(new_recipes): + self.fail(f"New recipes found: {new_recipes}. Please add them to the list of recipes to test.") + if len(removed_recipes): + self.fail(f"Removed recipes found: {removed_recipes}. 
Please remove them from the list of recipes to test.")
+
+    def test_convert_recipes_that_should_work(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            for recipe in self.recipes_that_should_work:
+                with self.subTest(recipe=recipe):
+                    output_script_path = Path(temp_dir) / Path(recipe).name
+                    convert_from_recipe(recipe, self.recipes_dir, output_script_path)
+                    src = output_script_path.read_text()
+                    try:
+                        ast.parse(src, feature_version=(3, 9))
+                    except SyntaxError as e:
+                        self.fail(f"Recipe {recipe} failed to convert to python script: {e}")
+
+    def test_convert_recipes_that_are_expected_to_fail(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            for recipe in self.recipes_that_does_not_work:
+                with self.subTest(recipe=recipe):
+                    output_script_path = Path(temp_dir) / Path(recipe).name
+                    with self.assertRaises(Exception):
+                        convert_from_recipe(recipe, self.recipes_dir, output_script_path)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 2ede874d8ebc0edb12380521c068447a8e86a361 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 25 Oct 2023 13:57:42 +0300
Subject: [PATCH 06/11] Make black optional

---
 src/super_gradients/convert_from_recipe.py | 88 ++++++++++++++--------
 1 file changed, 58 insertions(+), 30 deletions(-)

diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_from_recipe.py
index c0d97819b1..034f251d0d 100644
--- a/src/super_gradients/convert_from_recipe.py
+++ b/src/super_gradients/convert_from_recipe.py
@@ -1,14 +1,21 @@
 """
 Entry point for converting recipe file to self-contained train.py file.
 
-General use: python -m super_gradients.convert_from_recipe --config-name="DESIRED_RECIPE".
-For recipe's specific instructions and details refer to the recipe's configuration file in the recipes directory.
+Convert a recipe YAML file to a self-contained training script that can be run directly with python.
+The generated file will contain all training hyperparameters from the input recipe file but will be self-contained (no dependencies on the original recipe).
+
+Limitations: Converting a recipe with command-line overrides of some parameters in this recipe is not supported.
+
+General use: python -m super_gradients.convert_from_recipe DESIRED_RECIPE OUTPUT_SCRIPT
+Example: python -m super_gradients.convert_from_recipe coco2017_yolo_nas_s train_coco2017_yolo_nas_s.py
+
+For recipe's specific instructions and details refer to the recipes' directory.
 """
 import argparse
 import collections
 import os.path
 import pathlib
-from typing import Tuple, Mapping, Dict, Union
+from typing import Tuple, Mapping, Dict, Union, Optional
 
 import hydra
 import pkg_resources
@@ -25,35 +32,43 @@
 logger = get_logger(__name__)
 
 
-def import_black_or_fail_with_instructions():
+def try_import_black():
+    """
+    Attempts to import the black code formatter.
+    If black is not installed, it will attempt to install it with pip.
+    If installation fails, it will return None.
+    """
     try:
         import black
 
         return black
     except ImportError:
-        raise ImportError(
-            "Black is not installed. Please install it with `pip install black` and try again. "
-            "If you are using a virtual environment, make sure it is activated."
-        )
-
-
-def import_black_or_install():
-    try:
-        import black
-
-        return black
-    except ImportError:
-        import pip
-
-        pip.main(["install", "black==22.10.0"])
-
-        return import_black_or_fail_with_instructions()
-
-
-black = import_black_or_install()
-
-
-def recursively_walk_and_extract_hydra_targets(cfg, objects=None, prefix=None) -> Tuple[DictConfig, Dict[str, Mapping]]:
+        logger.info("Trying to install black using pip to enable formatting of the generated script.")
+        try:
+            import pip
+
+            pip.main(["install", "black==22.10.0"])
+            import black
+
+            logger.info("Black installed via pip.")
+            return black
+        except Exception:
+            logger.info("Black installation failed. Formatting of the generated script will be disabled.")
+            return None
+
+
+def recursively_walk_and_extract_hydra_targets(
+    cfg: DictConfig, objects: Optional[Mapping] = None, prefix: Optional[str] = None
+) -> Tuple[DictConfig, Dict[str, Mapping]]:
+    """
+    Iterates over the input config, extracts all hydra targets present in it and replaces them with variable references.
+    Extracted hydra targets are stored in the objects dictionary (used to generate instantiations of the objects in the generated script).
+
+    :param cfg: Input config
+    :param objects: Dictionary of extracted hydra targets
+    :param prefix: A prefix variable to track the path to the current config (used to give variables meaningful names)
+    :return: A new config and the dictionary of objects that must be created in the generated script
+    """
     if objects is None:
         objects = collections.OrderedDict()
     if prefix is None:
@@ -80,7 +95,18 @@ def recursively_walk_and_extract_hydra_targets(cfg, objects=None, prefix=None) -
     return cfg, objects
 
 
-def convert_from_recipe(config_name: Union[str, pathlib.Path], config_dir: Union[str, pathlib.Path], output_script_path: Union[str, pathlib.Path]):
+def convert_from_recipe(config_name: Union[str, pathlib.Path], config_dir: Union[str, pathlib.Path], output_script_path: Union[str, pathlib.Path]) -> None:
+    """
+    Convert a recipe YAML file to a self-contained training script that can be run directly with python.
+    The generated file will contain all training hyperparameters from the input recipe file but will be self-contained (no dependencies on the original recipe).
+
+    Limitations: Converting a recipe with command-line overrides of some parameters in this recipe is not supported.
+
+    :param config_name: Name of the recipe file (can be with or without .yaml extension)
+    :param config_dir: Directory where the recipe file is located
+    :param output_script_path: Path to the output .py file
+    :return: None
+    """
     config_name = str(config_name)
     config_dir = str(config_dir)
     output_script_path = str(output_script_path)
@@ -190,15 +216,17 @@ def main():
         content = content.replace(key_to_search, key_to_replace_with)
 
     with open(output_script_path, "w") as f:
-        content = black.format_str(content, mode=black.FileMode(line_length=120))
+        black = try_import_black()
+        if black is not None:
+            content = black.format_str(content, mode=black.FileMode(line_length=160))
         f.write(content)
 
 
 def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument("config_name", type=str, help=".yaml filename")
+    parser.add_argument("save_path", type=str, nargs="?", default=None, help="Destination path to the output .py file")
     parser.add_argument("--config_dir", type=str, default=pkg_resources.resource_filename("super_gradients.recipes", ""), help="The config directory path")
-    parser.add_argument("--save_path", type=str, default=None, help="Destination path to the output .py file")
     args = parser.parse_args()
 
     save_path = args.save_path or os.path.splitext(os.path.basename(args.config_name))[0] + ".py"

From bff04ed98f9d10ef51681f171a01d9aa9753e5ac Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 30 Oct 2023 09:21:40 +0200
Subject: [PATCH 07/11] Rename script to convert_recipe_to_code

---
 .../{convert_from_recipe.py => convert_recipe_to_code.py}    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
 rename src/super_gradients/{convert_from_recipe.py => convert_recipe_to_code.py} (97%)

diff --git a/src/super_gradients/convert_from_recipe.py b/src/super_gradients/convert_recipe_to_code.py
similarity index 97%
rename from src/super_gradients/convert_from_recipe.py
rename to src/super_gradients/convert_recipe_to_code.py
index 034f251d0d..b4f1e35978 100644
--- a/src/super_gradients/convert_from_recipe.py
+++ b/src/super_gradients/convert_recipe_to_code.py
@@ -6,8 +6,8 @@
 
 Limitations: Converting a recipe with command-line overrides of some parameters in this recipe is not supported.
 
-General use: python -m super_gradients.convert_from_recipe DESIRED_RECIPE OUTPUT_SCRIPT
-Example: python -m super_gradients.convert_from_recipe coco2017_yolo_nas_s train_coco2017_yolo_nas_s.py
+General use: python -m super_gradients.convert_recipe_to_code DESIRED_RECIPE OUTPUT_SCRIPT
+Example: python -m super_gradients.convert_recipe_to_code coco2017_yolo_nas_s train_coco2017_yolo_nas_s.py
 
 For recipe's specific instructions and details refer to the recipes' directory.
""" From 5b8089d9a8819c45de5673ca9806833ebca06a8c Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 30 Oct 2023 09:24:51 +0200 Subject: [PATCH 08/11] Fixed case when train_dataloader/val_dataloader is None --- src/super_gradients/convert_recipe_to_code.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/super_gradients/convert_recipe_to_code.py b/src/super_gradients/convert_recipe_to_code.py index b4f1e35978..4612db3fae 100644 --- a/src/super_gradients/convert_recipe_to_code.py +++ b/src/super_gradients/convert_recipe_to_code.py @@ -173,13 +173,13 @@ def main(): ) train_dataloader = dataloaders.get( - name="{train_dataloader}", + name={train_dataloader}, dataset_params={train_dataset_params}, dataloader_params={train_dataloader_params}, ) val_dataloader = dataloaders.get( - name="{val_dataloader}", + name={val_dataloader}, dataset_params={val_dataset_params}, dataloader_params={val_dataloader_params}, ) From da4f9f580e7f36a26c14cb41df28ef4def1ebe7b Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 30 Oct 2023 10:02:51 +0200 Subject: [PATCH 09/11] Added step to export recipe to code and run it --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index c886288eb1..a8ad71e8fe 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -641,6 +641,7 @@ jobs: python3.8 src/super_gradients/train_from_recipe.py --config-name=imagenet_resnet50 batch_size=8 val_batch_size=16 epochs=1 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=100 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4 dataset_params.train_dataset_params.root=/data/Imagenet/train dataset_params.val_dataset_params.root=/data/Imagenet/val python3.8 src/super_gradients/train_from_recipe.py --config-name=imagenet_vit_base batch_size=8 val_batch_size=16 epochs=1 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=100 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4 dataset_params.train_dataset_params.root=/data/Imagenet/train dataset_params.val_dataset_params.root=/data/Imagenet/val python3.8 src/super_gradients/train_from_kd_recipe.py --config-name=imagenet_resnet50_kd batch_size=8 val_batch_size=8 epochs=1 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=100 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4 dataset_params.train_dataset_params.root=/data/Imagenet/train dataset_params.val_dataset_params.root=/data/Imagenet/val + python3.8 src/super_gradients/convert_recipe_to_code.py cifar10_resnet.yaml train_cifar10_resnet.py && python3.8 train_cifar10_resnet.py - run: name: Remove new environment when failed From 5f45aaa0a3984e9c60a94e53f30f06101ab8331f Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 30 Oct 2023 10:37:15 +0200 Subject: [PATCH 10/11] Rename import from test --- src/super_gradients/convert_recipe_to_code.py | 4 ++-- tests/unit_tests/convert_from_recipe_tests.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/super_gradients/convert_recipe_to_code.py b/src/super_gradients/convert_recipe_to_code.py index 4612db3fae..77baab3f34 100644 --- a/src/super_gradients/convert_recipe_to_code.py +++ b/src/super_gradients/convert_recipe_to_code.py @@ -95,7 +95,7 @@ def recursively_walk_and_extract_hydra_targets( return cfg, objects -def convert_from_recipe(config_name: Union[str, pathlib.Path], config_dir: Union[str, 
pathlib.Path], output_script_path: Union[str, pathlib.Path]) -> None:
+def convert_recipe_to_code(config_name: Union[str, pathlib.Path], config_dir: Union[str, pathlib.Path], output_script_path: Union[str, pathlib.Path]) -> None:
     """
     Convert a recipe YAML file to a self-contained training script that can be run directly with python.
     The generated file will contain all training hyperparameters from the input recipe file but will be self-contained (no dependencies on the original recipe).
 
     Limitations: Converting a recipe with command-line overrides of some parameters in this recipe is not supported.
@@ -232,7 +232,7 @@ def main() -> None:
 
     save_path = args.save_path or os.path.splitext(os.path.basename(args.config_name))[0] + ".py"
     logger.info(f"Saving recipe script to {save_path}")
 
-    convert_from_recipe(args.config_name, args.config_dir, save_path)
+    convert_recipe_to_code(args.config_name, args.config_dir, save_path)
 
 
 if __name__ == "__main__":
diff --git a/tests/unit_tests/convert_from_recipe_tests.py b/tests/unit_tests/convert_from_recipe_tests.py
index 7479e613e0..73506b96b2 100644
--- a/tests/unit_tests/convert_from_recipe_tests.py
+++ b/tests/unit_tests/convert_from_recipe_tests.py
@@ -4,7 +4,7 @@
 import pkg_resources
 import unittest
 
-from super_gradients.convert_from_recipe import convert_from_recipe
+from super_gradients.convert_recipe_to_code import convert_recipe_to_code
 from pathlib import Path
 
 
@@ -80,7 +80,7 @@ def test_convert_recipes_that_should_work(self):
         for recipe in self.recipes_that_should_work:
             with self.subTest(recipe=recipe):
                 output_script_path = Path(temp_dir) / Path(recipe).name
-                convert_from_recipe(recipe, self.recipes_dir, output_script_path)
+                convert_recipe_to_code(recipe, self.recipes_dir, output_script_path)
                 src = output_script_path.read_text()
                 try:
                     ast.parse(src, feature_version=(3, 9))
@@ -93,7 +93,7 @@ def test_convert_recipes_that_are_expected_to_fail(self):
             with self.subTest(recipe=recipe):
                 output_script_path = Path(temp_dir) / Path(recipe).name
                 with self.assertRaises(Exception):
-                    convert_from_recipe(recipe, self.recipes_dir, output_script_path)
+                    convert_recipe_to_code(recipe, self.recipes_dir, output_script_path)
 
 
 if __name__ == "__main__":

From 17acd2c2339b280bc90da8e7c728c6ad301debbf Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 30 Oct 2023 10:38:17 +0200
Subject: [PATCH 11/11] Rename tests to match the naming convention

---
 tests/deci_core_unit_test_suite_runner.py                    | 4 ++--
 ...rt_from_recipe_tests.py => test_convert_recipe_to_code.py} | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)
 rename tests/unit_tests/{convert_from_recipe_tests.py => test_convert_recipe_to_code.py} (98%)

diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py
index 0618a6e00b..c769271f4f 100644
--- a/tests/deci_core_unit_test_suite_runner.py
+++ b/tests/deci_core_unit_test_suite_runner.py
@@ -28,7 +28,7 @@
     TestMixedPrecisionDisabled,
 )
 from tests.end_to_end_tests import TestTrainer
-from tests.unit_tests.convert_from_recipe_tests import TestConvertFromRecipe
+from tests.unit_tests.test_convert_recipe_to_code import TestConvertRecipeToCode
 from tests.unit_tests.detection_utils_test import TestDetectionUtils
 from tests.unit_tests.detection_dataset_test import DetectionDatasetTest, TestParseYoloLabelFile
 from tests.unit_tests.export_detection_model_test import TestDetectionModelExport
@@ -167,7 +167,7 @@ def _add_modules_to_unit_tests_suite(self):
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationSampleTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMixedPrecisionDisabled))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DynamicModelTests)) - self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertFromRecipe)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertRecipeToCode)) def _add_modules_to_end_to_end_tests_suite(self): """ diff --git a/tests/unit_tests/convert_from_recipe_tests.py b/tests/unit_tests/test_convert_recipe_to_code.py similarity index 98% rename from tests/unit_tests/convert_from_recipe_tests.py rename to tests/unit_tests/test_convert_recipe_to_code.py index 73506b96b2..33bdb37922 100644 --- a/tests/unit_tests/convert_from_recipe_tests.py +++ b/tests/unit_tests/test_convert_recipe_to_code.py @@ -8,7 +8,7 @@ from pathlib import Path -class TestConvertFromRecipe(unittest.TestCase): +class TestConvertRecipeToCode(unittest.TestCase): def setUp(self) -> None: self.recipes_dir: Path = Path(pkg_resources.resource_filename("super_gradients.recipes", "")) self.recipes_that_should_work = [
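A minimal usage sketch of the converter introduced by this series, mirroring the CircleCI step and the unit tests above; the recipe name and output path are examples, and any recipe from the known-good list in the test should behave the same way:

    import ast
    from pathlib import Path

    import pkg_resources

    from super_gradients.convert_recipe_to_code import convert_recipe_to_code

    # Recipes shipped with the package -- the same default directory the CLI uses.
    recipes_dir = Path(pkg_resources.resource_filename("super_gradients.recipes", ""))
    output_path = Path("train_cifar10_resnet.py")

    convert_recipe_to_code("cifar10_resnet.yaml", recipes_dir, output_path)

    # The emitted script is self-contained; verify it parses, then run it with
    # `python train_cifar10_resnet.py` to reproduce the recipe training run.
    ast.parse(output_path.read_text())

And a toy illustration of how recursively_walk_and_extract_hydra_targets rewrites `_target_` nodes into variable references (the config fragment and the my_pkg.MyCallback target are hypothetical):

    from omegaconf import OmegaConf

    from super_gradients.convert_recipe_to_code import recursively_walk_and_extract_hydra_targets

    cfg = OmegaConf.create({"phase_callbacks": [{"_target_": "my_pkg.MyCallback", "phase": "train"}]})
    cfg, objects = recursively_walk_and_extract_hydra_targets(cfg)

    # The target node is replaced by a generated variable name...
    print(OmegaConf.to_container(cfg))  # {'phase_callbacks': ['_phase_callbacks_0']}
    # ...and the (class, params) pair is recorded for code generation.
    print(objects)  # OrderedDict([('_phase_callbacks_0', ('my_pkg.MyCallback', {'phase': 'train'}))])

The generated script then contains one instantiation line per extracted target (here `_phase_callbacks_0 = my_pkg.MyCallback(phase=train)`), and the quote-stripping pass at the end of the converter turns the quoted variable names inside training_hyperparams back into plain references.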