From bcc026c9e805c227174bccca9725e9cfa240f28b Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 6 Nov 2023 08:59:10 +0200
Subject: [PATCH] Fixed ListConfig in pose estimation dataset classes (#1602)

* Fixed ListConfig in dataset class

* Ensure that config is validated before saving it

* Fixed _validate_checkpoint

* Fixed _validate_checkpoint

* Rename _validate_checkpoint to _sanitize_checkpoint, which now converts ListConfig and DictConfig to python primitive containers

---------

Co-authored-by: Ofri Masad
---
 .../common/sg_loggers/base_sg_logger.py     | 33 +++++++++++++++++--
 .../common/sg_loggers/clearml_sg_logger.py  |  2 ++
 .../common/sg_loggers/dagshub_sg_logger.py  |  1 +
 .../common/sg_loggers/wandb_sg_logger.py    |  1 +
 .../abstract_pose_estimation_dataset.py     | 19 +++++++++--
 .../base_keypoints.py                       | 19 +++++++++--
 6 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/src/super_gradients/common/sg_loggers/base_sg_logger.py b/src/super_gradients/common/sg_loggers/base_sg_logger.py
index e5210c528f..e5091979f4 100644
--- a/src/super_gradients/common/sg_loggers/base_sg_logger.py
+++ b/src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -1,5 +1,7 @@
+import collections
 import json
 import os
+import shutil
 import signal
 import time
 from typing import Union, Any
@@ -9,13 +11,14 @@
 import psutil
 import torch
 from PIL import Image
-import shutil
+from omegaconf import ListConfig, DictConfig, OmegaConf
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
 from super_gradients.common.auto_logging.console_logging import ConsoleSink
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.decorators.code_save_decorator import saved_codes
+from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.monitoring import SystemMonitor
 from super_gradients.common.registry.registry import register_sg_logger
@@ -23,7 +26,6 @@
 from super_gradients.common.sg_loggers.time_units import TimeUnit
 from super_gradients.training.params import TrainingParams
 from super_gradients.training.utils import sg_trainer_utils, get_param
-from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir
 
 
 logger = get_logger(__name__)
@@ -312,6 +314,7 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = None) ->
             name += ".pth"
         path = os.path.join(self._local_dir, name)
 
+        state_dict = self._sanitize_checkpoint(state_dict)
         self._save_checkpoint(path=path, state_dict=state_dict)
 
     @multi_process_safe
@@ -348,3 +351,29 @@ def _save_code(self):
                 self.add_file(name)
                 code = "\t" + code
                 self.add_text(name, code.replace("\n", " \n \t"))  # this replacement makes tb format the code as code
+
+    def _sanitize_checkpoint(self, state_dict: dict) -> dict:
+        """
+        Sanitize state dictionary to be saved in a checkpoint. Iterates recursively over the state_dict and converts
+        all instances of ListConfig and DictConfig to their native python counterparts.
+
+        :param state_dict: Checkpoint state_dict.
+        :return: Sanitized checkpoint state_dict.
+ """ + if isinstance(state_dict, (ListConfig, DictConfig)): + state_dict = OmegaConf.to_container(state_dict, resolve=True) + + if isinstance(state_dict, torch.Tensor): + pass + elif isinstance(state_dict, collections.OrderedDict): + state_dict = collections.OrderedDict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items()) + elif isinstance(state_dict, dict): + state_dict = dict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items()) + elif isinstance(state_dict, list): + state_dict = [self._sanitize_checkpoint(v) for v in state_dict] + elif isinstance(state_dict, tuple): + state_dict = tuple(self._sanitize_checkpoint(v) for v in state_dict) + else: + pass + + return state_dict diff --git a/src/super_gradients/common/sg_loggers/clearml_sg_logger.py b/src/super_gradients/common/sg_loggers/clearml_sg_logger.py index 9c36bbfa8f..e4df40f786 100644 --- a/src/super_gradients/common/sg_loggers/clearml_sg_logger.py +++ b/src/super_gradients/common/sg_loggers/clearml_sg_logger.py @@ -222,6 +222,8 @@ def upload(self): @multi_process_safe def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0): + state_dict = self._sanitize_checkpoint(state_dict) + name = f"ckpt_{global_step}.pth" if tag is None else tag if not name.endswith(".pth"): name += ".pth" diff --git a/src/super_gradients/common/sg_loggers/dagshub_sg_logger.py b/src/super_gradients/common/sg_loggers/dagshub_sg_logger.py index 4ddfdf6f11..8f0b35ffbb 100644 --- a/src/super_gradients/common/sg_loggers/dagshub_sg_logger.py +++ b/src/super_gradients/common/sg_loggers/dagshub_sg_logger.py @@ -256,6 +256,7 @@ def upload(self): @multi_process_safe def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0): + state_dict = self._sanitize_checkpoint(state_dict) name = f"ckpt_{global_step}.pth" if tag is None else tag if not name.endswith(".pth"): name += ".pth" diff --git a/src/super_gradients/common/sg_loggers/wandb_sg_logger.py b/src/super_gradients/common/sg_loggers/wandb_sg_logger.py index 6cd6c8c80e..4bc5683d0b 100644 --- a/src/super_gradients/common/sg_loggers/wandb_sg_logger.py +++ b/src/super_gradients/common/sg_loggers/wandb_sg_logger.py @@ -265,6 +265,7 @@ def _save_wandb_artifact(self, path): @multi_process_safe def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0): + state_dict = self._sanitize_checkpoint(state_dict) name = f"ckpt_{global_step}.pth" if tag is None else tag if not name.endswith(".pth"): name += ".pth" diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/abstract_pose_estimation_dataset.py b/src/super_gradients/training/datasets/pose_estimation_datasets/abstract_pose_estimation_dataset.py index ef557c639c..06d08243cd 100644 --- a/src/super_gradients/training/datasets/pose_estimation_datasets/abstract_pose_estimation_dataset.py +++ b/src/super_gradients/training/datasets/pose_estimation_datasets/abstract_pose_estimation_dataset.py @@ -3,6 +3,7 @@ from typing import Tuple, List, Union import numpy as np +from omegaconf import ListConfig from torch.utils.data.dataloader import Dataset from super_gradients.common.abstractions.abstract_logger import get_logger @@ -32,9 +33,9 @@ def __init__( self, transforms: List[AbstractKeypointTransform], num_joints: int, - edge_links: Union[List[Tuple[int, int]], np.ndarray], - edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None], - keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None], + edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray], + 
+        edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
+        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
     ):
         """
 
@@ -50,6 +51,18 @@ def __init__(
             load_sample_fn=self.load_random_sample,
         )
         self.num_joints = num_joints
+
+        # Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
+        # This is necessary to ensure ListConfig objects do not leak to these properties
+        # and from there to the checkpoint's state_dict.
+        # Otherwise, through ListConfig instances a whole configuration file will leak into the state_dict
+        # and torch.load will attempt to unpickle a lot of unnecessary classes.
+        edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
+        if edge_colors is not None:
+            edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
+        if keypoint_colors is not None:
+            keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]
+
         self.edge_links = edge_links
         self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
         self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py b/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py
index 8054641428..22f6e712a0 100644
--- a/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py
+++ b/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+from omegaconf import ListConfig
 from torch.utils.data.dataloader import default_collate, Dataset
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -28,9 +29,9 @@ def __init__(
         transforms: List[KeypointTransform],
         min_instance_area: float,
         num_joints: int,
-        edge_links: Union[List[Tuple[int, int]], np.ndarray],
-        edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
-        keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
+        edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
+        edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
+        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
     ):
         """
 
@@ -48,6 +49,18 @@ def __init__(
         self.transforms = KeypointsCompose(transforms)
         self.min_instance_area = min_instance_area
         self.num_joints = num_joints
+
+        # Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
+        # This is necessary to ensure ListConfig objects do not leak to these properties
+        # and from there to the checkpoint's state_dict.
+        # Otherwise, through ListConfig instances a whole configuration file will leak into the state_dict
+        # and torch.load will attempt to unpickle a lot of unnecessary classes.
+        edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
+        if edge_colors is not None:
+            edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
+        if keypoint_colors is not None:
+            keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]
+
         self.edge_links = edge_links
         self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
         self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
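
To see the motivation for the patch outside the diff itself, here is a minimal, self-contained sketch. It is not part of the change: the sanitize helper, the config keys and the output file name are illustrative stand-ins for BaseSGLogger._sanitize_checkpoint and a real training configuration. The idea, stated in the code comments above, is that a ListConfig or DictConfig node keeps a reference to its parent configuration, so pickling it into a checkpoint drags the whole config (and the omegaconf classes needed to unpickle it) along; converting to plain python containers before torch.save avoids that. This is only a sketch, assuming omegaconf and torch are installed.

# sanitize_sketch.py -- illustrative only, not part of the patch
import collections

import torch
from omegaconf import DictConfig, ListConfig, OmegaConf


def sanitize(obj):
    # Recursively replace ListConfig/DictConfig with plain python containers,
    # mirroring what the patch's _sanitize_checkpoint does.
    if isinstance(obj, (ListConfig, DictConfig)):
        obj = OmegaConf.to_container(obj, resolve=True)

    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, collections.OrderedDict):
        return collections.OrderedDict((k, sanitize(v)) for k, v in obj.items())
    if isinstance(obj, dict):
        return {k: sanitize(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize(v) for v in obj)
    return obj


# Hypothetical dataset params; edge_links becomes a ListConfig once the
# recipe YAML is loaded with OmegaConf.
params = OmegaConf.create({"num_joints": 3, "edge_links": [[0, 1], [1, 2]]})

state = {
    "net": collections.OrderedDict(w=torch.zeros(2)),
    "edge_links": params.edge_links,  # ListConfig that would leak into the checkpoint
}

clean = sanitize(state)
assert not isinstance(clean["edge_links"], ListConfig)  # plain list now
assert clean["edge_links"] == [[0, 1], [1, 2]]
torch.save(clean, "checkpoint.pth")  # can be loaded later without importing omegaconf

Running the sketch with and without the sanitize call makes the difference visible: the unsanitized checkpoint pickles omegaconf node classes, while the sanitized one contains only tensors and builtin containers, which is the behaviour the loggers' add_checkpoint methods rely on after this patch.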