Skip to content

Commit

Permalink
Fixed ListConfig in pose estimation dataset classes (#1602)
Browse files Browse the repository at this point in the history
* Fixed ListConfig in dataset class

* Ensure that config is validated before saving it

* Fixed _validate_checkpoint

* Fixed _validate_checkpoint

* Rename _validate_checkpoint to _sanitimize_checkpoint that now converts ListConfig and DictConfig to python primitive containers

---------

Co-authored-by: Ofri Masad <[email protected]>
  • Loading branch information
BloodAxe and ofrimasad authored Nov 6, 2023
1 parent 29dea7a commit bcc026c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 8 deletions.
33 changes: 31 additions & 2 deletions src/super_gradients/common/sg_loggers/base_sg_logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import json
import os
import shutil
import signal
import time
from typing import Union, Any
Expand All @@ -9,21 +11,21 @@
import psutil
import torch
from PIL import Image
import shutil
from omegaconf import ListConfig, DictConfig, OmegaConf

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
from super_gradients.common.auto_logging.console_logging import ConsoleSink
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
from super_gradients.common.decorators.code_save_decorator import saved_codes
from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.monitoring import SystemMonitor
from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.common.sg_loggers.time_units import TimeUnit
from super_gradients.training.params import TrainingParams
from super_gradients.training.utils import sg_trainer_utils, get_param
from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir

logger = get_logger(__name__)

Expand Down Expand Up @@ -312,6 +314,7 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = None) ->
name += ".pth"
path = os.path.join(self._local_dir, name)

state_dict = self._sanitize_checkpoint(state_dict)
self._save_checkpoint(path=path, state_dict=state_dict)

@multi_process_safe
Expand Down Expand Up @@ -348,3 +351,29 @@ def _save_code(self):
self.add_file(name)
code = "\t" + code
self.add_text(name, code.replace("\n", " \n \t")) # this replacement makes tb format the code as code

def _sanitize_checkpoint(self, state_dict: dict) -> dict:
"""
Sanitize state dictionary to be saved in a checkpoint. Iterates recursively over the state_dict and converts
all instances of ListConfig and DictConfig to their native python counterparts.
:param state_dict: Checkpoint state_dict.
:return: Sanitized checkpoint state_dict.
"""
if isinstance(state_dict, (ListConfig, DictConfig)):
state_dict = OmegaConf.to_container(state_dict, resolve=True)

if isinstance(state_dict, torch.Tensor):
pass
elif isinstance(state_dict, collections.OrderedDict):
state_dict = collections.OrderedDict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items())
elif isinstance(state_dict, dict):
state_dict = dict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items())
elif isinstance(state_dict, list):
state_dict = [self._sanitize_checkpoint(v) for v in state_dict]
elif isinstance(state_dict, tuple):
state_dict = tuple(self._sanitize_checkpoint(v) for v in state_dict)
else:
pass

return state_dict
2 changes: 2 additions & 0 deletions src/super_gradients/common/sg_loggers/clearml_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def upload(self):

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
state_dict = self._sanitize_checkpoint(state_dict)

name = f"ckpt_{global_step}.pth" if tag is None else tag
if not name.endswith(".pth"):
name += ".pth"
Expand Down
1 change: 1 addition & 0 deletions src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def upload(self):

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
state_dict = self._sanitize_checkpoint(state_dict)
name = f"ckpt_{global_step}.pth" if tag is None else tag
if not name.endswith(".pth"):
name += ".pth"
Expand Down
1 change: 1 addition & 0 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def _save_wandb_artifact(self, path):

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
state_dict = self._sanitize_checkpoint(state_dict)
name = f"ckpt_{global_step}.pth" if tag is None else tag
if not name.endswith(".pth"):
name += ".pth"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple, List, Union

import numpy as np
from omegaconf import ListConfig
from torch.utils.data.dataloader import Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
Expand Down Expand Up @@ -32,9 +33,9 @@ def __init__(
self,
transforms: List[AbstractKeypointTransform],
num_joints: int,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
):
"""
Expand All @@ -50,6 +51,18 @@ def __init__(
load_sample_fn=self.load_random_sample,
)
self.num_joints = num_joints

# Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples
# This is necessary to ensure ListConfig objects do not leak to these properties
# and from there - to checkpoint's state_dict.
# Otherwise, through ListConfig instances a whole configuration file will leak to state_dict
# and torch.load will attempt to unpickle lot of unnecessary classes.
edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
if edge_colors is not None:
edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
if keypoint_colors is not None:
keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]

self.edge_links = edge_links
self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
from omegaconf import ListConfig
from torch.utils.data.dataloader import default_collate, Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
Expand All @@ -28,9 +29,9 @@ def __init__(
transforms: List[KeypointTransform],
min_instance_area: float,
num_joints: int,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
):
"""
Expand All @@ -48,6 +49,18 @@ def __init__(
self.transforms = KeypointsCompose(transforms)
self.min_instance_area = min_instance_area
self.num_joints = num_joints

# Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples
# This is necessary to ensure ListConfig objects do not leak to these properties
# and from there - to checkpoint's state_dict.
# Otherwise, through ListConfig instances a whole configuration file will leak to state_dict
# and torch.load will attempt to unpickle lot of unnecessary classes.
edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
if edge_colors is not None:
edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
if keypoint_colors is not None:
keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]

self.edge_links = edge_links
self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
Expand Down

0 comments on commit bcc026c

Please sign in to comment.