update hydra to 1.1
nmichlo committed Oct 23, 2021
1 parent 482400e commit 1361db7
Showing 9 changed files with 59 additions and 87 deletions.
12 changes: 7 additions & 5 deletions disent/frameworks/vae/experimental/_unsupervised__dorvae.py
@@ -76,10 +76,11 @@ def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
         # initialise
         if self.cfg.overlap_augment_mode != 'none':
             assert self.cfg.overlap_augment is not None, 'if cfg.overlap_augment_mode is not "none", then cfg.overlap_augment must be defined.'
-        if self.cfg.overlap_augment is not None:
-            # TODO: this should not reference experiments!
-            from experiment.util.hydra_utils import instantiate_object_if_needed
-            self._augment = instantiate_object_if_needed(self.cfg.overlap_augment)
+        # set augment and instantiate if needed
+        self._augment = self.cfg.overlap_augment
+        if isinstance(self._augment, dict):
+            import hydra
+            self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
         assert callable(self._augment), f'augment is not callable: {repr(self._augment)}'
         # get overlap loss
         overlap_loss = self.cfg.overlap_loss if (self.cfg.overlap_loss is not None) else self.cfg.recon_loss
@@ -152,7 +153,8 @@ def augment_triplet_targets(self, xs_targ):
         elif (self.cfg.overlap_augment_mode == 'augment') or (self.cfg.overlap_augment_mode == 'augment_each'):
             # recreate augment each time
             if self.cfg.overlap_augment_mode == 'augment_each':
-                self._augment = instantiate_recursive(self.cfg.overlap_augment)
+                import hydra
+                self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
             # augment on correct device
             aug_xs_targ = [self._augment(x_targ) for x_targ in xs_targ]
             # checks
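The hunks above replace the project-local `instantiate_object_if_needed` helper: when the configured augment is still a plain dict carrying a `_target_` key, hydra 1.1's `instantiate` can build it directly. A minimal runnable sketch of the same pattern, using a hypothetical `GaussianBlur` target in place of a real overlap-augment config:

```python
import hydra
import torch

# a config as it might arrive via `cfg.overlap_augment` (hypothetical target, not from this repo)
overlap_augment = {'_target_': 'torchvision.transforms.GaussianBlur', 'kernel_size': 3}

augment = overlap_augment
if isinstance(augment, dict):
    # hydra 1.1 accepts plain dicts with a `_target_` key, not only DictConfig objects
    augment = hydra.utils.instantiate(augment)

assert callable(augment), f'augment is not callable: {repr(augment)}'
print(augment(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 3, 32, 32])
```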
9 changes: 5 additions & 4 deletions disent/frameworks/vae/experimental/_unsupervised__dotvae.py
@@ -76,10 +76,11 @@ def init_data_overlap_mixin(self):
         # initialise
         if self.cfg.overlap_augment_mode != 'none':
             assert self.cfg.overlap_augment is not None, 'if cfg.overlap_augment_mode is not "none", then cfg.overlap_augment must be defined.'
-        if self.cfg.overlap_augment is not None:
-            # TODO: this should not reference experiments!
-            from experiment.util.hydra_utils import instantiate_object_if_needed
-            self._augment = instantiate_object_if_needed(self.cfg.overlap_augment)
+        # set augment and instantiate if needed
+        self._augment = self.cfg.overlap_augment
+        if isinstance(self._augment, dict):
+            import hydra
+            self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
         assert callable(self._augment), f'augment is not callable: {repr(self._augment)}'
         # get overlap loss
         overlap_loss = self.cfg.overlap_loss if (self.cfg.overlap_loss is not None) else self.cfg.recon_loss
@@ -50,16 +50,16 @@ class cfg(TripletVae.cfg):

     def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
         super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
-        # initialise & check augment
-        self._augment = None
-        if self.cfg.overlap_augment is not None:
-            # TODO: this should not reference experiments!
-            from experiment.util.hydra_utils import instantiate_object_if_needed
-            self._augment = instantiate_object_if_needed(self.cfg.overlap_augment)
-        assert callable(self._augment), f'augment is not callable: {repr(self._augment)}'
+        # set augment and instantiate if needed
+        self._augment = self.cfg.overlap_augment
+        if isinstance(self._augment, dict):
+            import hydra
+            self._augment = hydra.utils.instantiate(self._augment)
         # get default if needed
         if self._augment is None:
             self._augment = torch.nn.Identity()
             warnings.warn(f'{self.__class__.__name__}, no overlap_augment was specified, defaulting to nn.Identity which WILL break things!')
+        # checks!
+        assert callable(self._augment), f'augment is not callable: {repr(self._augment)}'

     def do_training_step(self, batch, batch_idx):
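For reference, the fallback branch added above in isolation, with the config value stubbed out (a sketch, not repo code): defaulting to `torch.nn.Identity` keeps the augment callable, but overlap targets then pass through unchanged, which is why the warning is so loud.

```python
import warnings
import torch

overlap_augment = None  # stand-in for cfg.overlap_augment being unset

augment = overlap_augment
if augment is None:
    augment = torch.nn.Identity()
    warnings.warn('no overlap_augment was specified, defaulting to nn.Identity which WILL break things!')

assert callable(augment)
x = torch.randn(4, 3, 64, 64)
assert torch.equal(augment(x), x)  # identity: targets are returned unchanged
```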
15 changes: 9 additions & 6 deletions experiment/config/config.yaml
@@ -12,19 +12,22 @@ defaults:
   # logs
   - metrics: all
   - run_callbacks: vis
-  - run_logging: wandb
-  - hydra/job_logging: colorlog
-  - hydra/hydra_logging: colorlog
+  - run_logging: none
   # runtime
   - run_location: stampede_shr
   - run_launcher: slurm
-  - hydra/launcher: submitit_slurm
   # action
   - action: train
+  # overrides
+  - override hydra/job_logging: colorlog
+  - override hydra/hydra_logging: colorlog
+  - override hydra/launcher: submitit_slurm
+  # so that the defaults list does not override entries in this file
+  - _self_

 job:
-  user: '${env:USER}'
-  project: 'test-project'
+  user: 'n_michlo'
+  project: 'DELETE'
   name: '${framework.name}:${framework.module.recon_loss}|${dataset.name}:${sampling.name}|${trainer.steps}'
   seed: NULL

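Hydra 1.1 merges the defaults list in order, with later entries winning, and the new `_self_` entry pins where the primary config's own values are merged in; the `override` keyword is likewise required in 1.1 when re-assigning a group that already has a default. A throwaway sketch (hypothetical `db` group, not part of this repo) that demonstrates why `_self_` is appended last here:

```python
import os
import tempfile

from hydra import compose, initialize_config_dir

# build a tiny config tree on disk (hypothetical `db` group, not part of this repo)
tmp = tempfile.mkdtemp()
os.makedirs(os.path.join(tmp, 'db'))
with open(os.path.join(tmp, 'config.yaml'), 'w') as f:
    f.write('defaults:\n  - db: sqlite\n  - _self_\n\ndb:\n  timeout: 99\n')
with open(os.path.join(tmp, 'db', 'sqlite.yaml'), 'w') as f:
    f.write('timeout: 10\ndriver: sqlite\n')

with initialize_config_dir(config_dir=tmp):
    cfg = compose(config_name='config')

print(cfg.db.timeout)  # 99 -- `_self_` comes last, so this file overrides the group default
print(cfg.db.driver)   # sqlite -- keys this file does not set are kept from the default
```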
9 changes: 6 additions & 3 deletions experiment/config/config_test.yaml
@@ -13,14 +13,17 @@ defaults:
   - metrics: test
   - run_callbacks: none
   - run_logging: none
-  - hydra/job_logging: colorlog
-  - hydra/hydra_logging: colorlog
   # runtime
   - run_location: local_cpu
   - run_launcher: local
-  - hydra/launcher: basic
   # action
   - action: train
+  # overrides
+  - hydra/job_logging: colorlog
+  - hydra/hydra_logging: colorlog
+  - hydra/launcher: basic
+  # so that the defaults list does not override entries in this file
+  - _self_

 job:
   user: invalid
37 changes: 18 additions & 19 deletions experiment/run.py
@@ -52,7 +52,6 @@
 from experiment.util.hydra_data import HydraDataModule
 from experiment.util.hydra_utils import make_non_strict
 from experiment.util.hydra_utils import merge_specializations
-from experiment.util.hydra_utils import instantiate_recursive
 from experiment.util.run_utils import log_error_and_exit
 from experiment.util.run_utils import safe_unset_debug_logger
 from experiment.util.run_utils import safe_unset_debug_trainer
@@ -233,17 +232,14 @@ def hydra_register_schedules(module: DisentFramework, cfg):
     if cfg.schedules:
         log.info(f'Registering Schedules:')
         for target, schedule in cfg.schedules.items():
-            module.register_schedule(target, instantiate_recursive(schedule), logging=True)
+            module.register_schedule(target, hydra.utils.instantiate(schedule), logging=True)


 def hydra_create_framework_config(cfg):
     # create framework config - this is also kinda hacky
-    # - we need instantiate_recursive because of optimizer_kwargs,
-    #   otherwise the dictionary is left as an OmegaConf dict
-    framework_cfg: DisentConfigurable.cfg = instantiate_recursive({
-        **cfg.framework.module,
-        **dict(_target_=cfg.framework.module._target_ + '.cfg')
-    })
+    framework_cfg: DisentConfigurable.cfg = hydra.utils.instantiate(cfg.framework.module)
     # warn if some of the cfg variables were not overridden
     missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.module.keys())))
     if missing_keys:
@@ -261,18 +257,21 @@ def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cfg):
     # - not supported normally, we need to instantiate to get the class (is there hydra support for this?)
     framework_cfg.optimizer = hydra.utils.instantiate(dict(_target_=framework_cfg.optimizer), [torch.Tensor()]).__class__
     framework_cfg.optimizer_kwargs = dict(framework_cfg.optimizer_kwargs)
-    # instantiate
-    return hydra.utils.instantiate(
-        dict(_target_=cfg.framework.module._target_),
-        model=init_model_weights(
-            AutoEncoder(
-                encoder=hydra.utils.instantiate(cfg.model.encoder),
-                decoder=hydra.utils.instantiate(cfg.model.decoder)
-            ), mode=cfg.model.weight_init
-        ),
-        # apply augmentations to batch on GPU which can be faster than via the dataloader
-        batch_augment=datamodule.batch_augment,
-        cfg=framework_cfg
-    )
+    # get framework path
+    assert str.endswith(cfg.framework.module._target_, '.cfg'), f'`cfg.framework.module._target_` does not end with ".cfg", got: {repr(cfg.framework.module._target_)}'
+    framework_cls = hydra.utils.get_class(cfg.framework.module._target_[:-len(".cfg")])
+    # create model
+    model = AutoEncoder(
+        encoder=hydra.utils.instantiate(cfg.model.encoder),
+        decoder=hydra.utils.instantiate(cfg.model.decoder),
+    )
+    # initialise the model
+    model = init_model_weights(model, mode=cfg.model.weight_init)
+    # create framework
+    return framework_cls(
+        model=model,
+        cfg=framework_cfg,
+        batch_augment=datamodule.batch_augment,  # apply augmentations to batch on GPU which can be faster than via the dataloader
+    )


@@ -431,7 +430,7 @@ class ConfigurationError(Exception):
     def _error_resolver(msg: str):
         raise ConfigurationError(msg)

-    OmegaConf.register_resolver('exit', _error_resolver)
+    OmegaConf.register_new_resolver('exit', _error_resolver)

     @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
     def hydra_main(cfg: DictConfig):
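Two smaller API swaps in this file are easy to miss: `hydra.utils.get_class` resolves the framework class from its `<Class>.cfg` target path, and omegaconf 2.1 renamed `register_resolver` to `register_new_resolver`. A standalone sketch of both, with `collections.OrderedDict.cfg` as a hypothetical stand-in for a real framework target:

```python
import hydra
from omegaconf import OmegaConf

# resolve the owning class from a `<Class>.cfg` style target path
target = 'collections.OrderedDict.cfg'  # hypothetical stand-in for a `<Framework>.cfg` path
assert target.endswith('.cfg'), f'`_target_` does not end with ".cfg", got: {repr(target)}'
cls = hydra.utils.get_class(target[:-len('.cfg')])
print(cls)  # <class 'collections.OrderedDict'>

# omegaconf 2.1: `register_resolver` is deprecated in favour of `register_new_resolver`
def _error_resolver(msg: str):
    raise RuntimeError(msg)

OmegaConf.register_new_resolver('exit', _error_resolver)
cfg = OmegaConf.create({'required': "${exit:'this value must be overridden'}"})
try:
    _ = cfg.required  # resolving the interpolation calls the resolver, which raises
except Exception as e:
    print(type(e).__name__)  # InterpolationResolutionError, wrapping the RuntimeError
```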
9 changes: 4 additions & 5 deletions experiment/util/hydra_data.py
@@ -30,7 +30,6 @@

 from disent.dataset import DisentDataset
 from disent.dataset.transform import DisentDatasetTransform
-from experiment.util.hydra_utils import instantiate_recursive


 log = logging.getLogger(__name__)
@@ -88,10 +87,10 @@ def __init__(self, hparams: DictConfig):
         else:
             self.hparams.update(hparams)
         # transform: prepares data from datasets
-        self.data_transform = instantiate_recursive(self.hparams.dataset.transform)
+        self.data_transform = hydra.utils.instantiate(self.hparams.dataset.transform)
         assert (self.data_transform is None) or callable(self.data_transform)
         # input_transform_aug: augment data for inputs, then apply input_transform
-        self.input_transform = instantiate_recursive(self.hparams.augment.transform)
+        self.input_transform = hydra.utils.instantiate(self.hparams.augment.transform)
         assert (self.input_transform is None) or callable(self.input_transform)
         # batch_augment: augments transformed data for inputs, should be applied across a batch
         # which version of the dataset we need to use if GPU augmentation is enabled or not.
@@ -117,12 +116,12 @@ def prepare_data(self) -> None:
         # things could go wrong. We try be efficient about it by removing the
         # in_memory argument if it exists.
         log.info(f'Data - Preparation & Downloading')
-        instantiate_recursive(data)
+        hydra.utils.instantiate(data)

     def setup(self, stage=None) -> None:
         # ground truth data
         log.info(f'Data - Instance')
-        data = instantiate_recursive(self.hparams.dataset.data)
+        data = hydra.utils.instantiate(self.hparams.dataset.data)
         # Wrap the data for the framework some datasets need triplets, pairs, etc.
         # Augmentation is done inside the frameworks so that it can be done on the GPU, otherwise things are very slow.
         self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.dataset.sampler.cls), transform=self.data_transform, augment=None)
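The deleted `instantiate_recursive` calls could be swapped one-for-one because hydra 1.1's `instantiate` recurses by default: nested dicts and lists carrying `_target_` keys are built bottom-up (pass `_recursive_=False` to opt out). A sketch with torchvision transforms standing in for the dataset and augment configs (hypothetical values, not repo configs):

```python
import hydra

# nested config: the inner `_target_` entries are instantiated before the outer Compose
config = {
    '_target_': 'torchvision.transforms.Compose',
    'transforms': [
        {'_target_': 'torchvision.transforms.CenterCrop', 'size': 64},
        {'_target_': 'torchvision.transforms.ToTensor'},
    ],
}

# equivalent to what the old `instantiate_recursive(config)` helper did
transform = hydra.utils.instantiate(config)
print(transform)  # Compose with CenterCrop(size=(64, 64)) and ToTensor() already built
```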
35 changes: 0 additions & 35 deletions experiment/util/hydra_utils.py
@@ -38,41 +38,6 @@
 log = logging.getLogger(__name__)


-# ========================================================================= #
-# Recursive Hydra Instantiation                                             #
-# TODO: use https://github.com/facebookresearch/hydra/pull/989              #
-#       I think this is quicker? Just doesn't perform checks...             #
-# ========================================================================= #
-
-
-@deprecated('replace with hydra 1.1')
-def call_recursive(config):
-    # recurse
-    def _call_recursive(config):
-        if isinstance(config, (dict, DictConfig)):
-            c = {k: _call_recursive(v) for k, v in config.items() if k != '_target_'}
-            if '_target_' in config:
-                config = hydra.utils.instantiate({'_target_': config['_target_']}, **c)
-        elif isinstance(config, (tuple, list, ListConfig)):
-            config = [_call_recursive(v) for v in config]
-        return config
-    return _call_recursive(config)
-
-
-# alias
-@deprecated('replace with hydra 1.1')
-def instantiate_recursive(config):
-    return call_recursive(config)
-
-
-@deprecated('replace with hydra 1.1')
-def instantiate_object_if_needed(config_or_object):
-    if isinstance(config_or_object, dict):
-        return instantiate_recursive(config_or_object)
-    else:
-        return config_or_object
-
-
 # ========================================================================= #
 # Better Specializations                                                    #
 # TODO: this might be replaced by recursive instantiation                   #
6 changes: 3 additions & 3 deletions requirements-experiment.txt
@@ -25,9 +25,9 @@ wandb>=0.10.32
 # UTILITY
 # =======
 omegaconf>=2.1.0                # only 2.1.0 supports nested variable interpolation eg. ${group.${group.key}}
-hydra-core==1.0.7               # needs omegaconf
-hydra-colorlog==1.0.1
-hydra-submitit-launcher==1.1.1
+hydra-core==1.1.1               # needs omegaconf
+hydra-colorlog==1.1.0
+hydra-submitit-launcher==1.1.6


 # MISSING DEPS - these are imported or referened (_target_) in /experiments, but not included here OR in requirements.txt
