
Commit

Enforce Lightning module as source of truth for automatic optimization (#7130)

* make lightning module source of truth for automatic optimization

* Update configuration_validator.py

* Update model_connector.py

* rm-references

* Update CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: jirka <[email protected]>
ananthsub and Borda authored Apr 26, 2021
1 parent 44d775f commit 68eac4d
Showing 8 changed files with 22 additions and 21 deletions.
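
For context on what this commit means for users: `automatic_optimization` is now read only from the LightningModule, so opting into manual optimization means setting the flag on the module itself. Below is a minimal sketch of that usage, assuming the standard manual-optimization API of this release; the module and layer names are illustrative, not part of the PR.

import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):  # illustrative example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # The LightningModule is now the single source of truth for this flag.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)  # runs backward because automatic_optimization is False
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
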
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -190,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))


- Removed evaluation loop legacy returns for `*_epoch_end` hooks ([#6973](https://github.com/PyTorchLightning/pytorch-lightning/pull/6973))


6 changes: 3 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -1250,7 +1250,7 @@ def backward(self, loss, optimizer, optimizer_idx):
loss.backward()
"""
- if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
+ if self.automatic_optimization or self._running_manual_backward:
loss.backward(*args, **kwargs)

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
@@ -1539,7 +1539,7 @@ def get_progress_bar_dict(self):
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
- elif self.trainer.train_loop.automatic_optimization:
+ elif self.automatic_optimization:
avg_training_loss = float('NaN')

tqdm_dict = {}
@@ -1558,7 +1558,7 @@ def get_progress_bar_dict(self):
return tqdm_dict

def _verify_is_manual_optimization(self, fn_name):
- if self.trainer.train_loop.automatic_optimization:
+ if self.automatic_optimization:
raise MisconfigurationException(
f'to use {fn_name}, please disable automatic optimization:'
' set model property `automatic_optimization` as False'
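
Because `backward` and `_verify_is_manual_optimization` now consult `self.automatic_optimization` directly, a module that leaves the flag at its default of True and still calls a manual-only helper fails fast with the MisconfigurationException shown above. A short sketch of that failure mode, assuming `manual_backward` routes through this check; the module is illustrative.

import torch
import pytorch_lightning as pl


class StillAutomatic(pl.LightningModule):  # illustrative example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # automatic_optimization defaults to True, so this call is expected to raise
        # MisconfigurationException asking to set `automatic_optimization` to False.
        self.manual_backward(loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
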
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/configuration_validator.py
@@ -78,7 +78,7 @@ def __verify_train_loop_configuration(self, model):

trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)
- automatic_optimization = trainer.train_loop.automatic_optimization
+ automatic_optimization = model.automatic_optimization
going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
(changed file, path not shown)
@@ -358,7 +358,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
return epoch_log_metrics, epoch_progress_bar_metrics

def log_train_step_metrics(self, batch_output):
- if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
+ if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
return
_, batch_log_metrics = self.cached_results.update_logger_connector()
# when metrics should be logged
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -27,9 +27,6 @@ def __init__(self, trainer):
def copy_trainer_model_properties(self, model):
ref_model = self.trainer.lightning_module or model

- automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
- self.trainer.train_loop.automatic_optimization = automatic_optimization

for m in [model, ref_model]:
m.trainer = proxy(self.trainer)
m._device_type = str(self.trainer._device_type)
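
The removed lines above used to combine the module flag with a mirrored flag on the training loop, so two places could disagree. With that merge gone, trainer-side code asks the attached module directly. A minimal sketch of the access pattern; the helper function is hypothetical, for illustration only.

from pytorch_lightning import Trainer


def uses_automatic_optimization(trainer: Trainer) -> bool:
    # After this change, the attached LightningModule is the only source of truth.
    return trainer.lightning_module.automatic_optimization
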
(changed file, path not shown)
@@ -32,7 +32,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
interval: either 'epoch' or 'step'.
monitor_metrics: dict of possible values to monitor
"""
- if not self.trainer.lr_schedulers or not self.trainer.train_loop.automatic_optimization:
+ if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
return

for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/optimizers.py
@@ -80,7 +80,7 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)

- is_manual_optimization = not self.train_loop.automatic_optimization
+ is_manual_optimization = not model.automatic_optimization
lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)

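
Since `is_manual_optimization` is now derived from the model, scheduler handling follows the same flag: `update_learning_rates` above returns early under manual optimization, so the module is expected to step its schedulers itself. A sketch under that assumption; it also assumes the `self.lr_schedulers()` helper is available in this version, and the module is illustrative.

import torch
import pytorch_lightning as pl


class ManualWithScheduler(pl.LightningModule):  # illustrative example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(self.layer(batch).sum())
        opt.step()
        # Not stepped automatically when automatic_optimization is False.
        self.lr_schedulers().step()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]
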
23 changes: 12 additions & 11 deletions pytorch_lightning/trainer/training_loop.py
@@ -42,7 +42,6 @@ def __init__(self, trainer, multiple_trainloader_mode: str):
self.warning_cache = WarningCache()
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)
- self.automatic_optimization = True
self._curr_step_result = None
self._cur_grad_norm_dict = None
self._multiple_trainloader_mode = multiple_trainloader_mode
@@ -255,7 +254,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

def _check_training_step_output(self, training_step_output):
- if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
+ if isinstance(training_step_output, torch.Tensor) and not self.trainer.lightning_module.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
@@ -290,7 +289,7 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
closure_loss = None
untouched_loss = None

- if self.automatic_optimization:
+ if self.trainer.lightning_module.automatic_optimization:
# accumulate loss. if accumulate_grad_batches==1, no effect
closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches

@@ -660,7 +659,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# gradient update with accumulated gradients

else:
- if self.automatic_optimization:
+ if self.trainer.lightning_module.automatic_optimization:

def train_step_and_backward_closure():
result = self.training_step_and_backward(
@@ -717,7 +716,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
"""
if (
isinstance(self.trainer.training_type_plugin, ParallelPlugin)
- and (self.automatic_optimization or should_block_sync)
+ and (self.trainer.lightning_module.automatic_optimization or should_block_sync)
):
with self.trainer.training_type_plugin.block_backward_sync():
yield None
@@ -740,7 +739,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)

- if self.automatic_optimization:
+ if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(opt_closure_result.loss)

@@ -755,7 +754,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
self._curr_step_result = result

- if not self._skip_backward and self.automatic_optimization:
+ if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

if is_first_batch_to_accumulate:
@@ -858,14 +857,16 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):

if len(self.trainer.optimizers) > 1:
if self.trainer.has_arg("training_step", "optimizer_idx"):
- if not self.automatic_optimization:
+ if not self.trainer.lightning_module.automatic_optimization:
self.warning_cache.warn(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5", DeprecationWarning
)
args.append(opt_idx)
- elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
+ elif not self.trainer.has_arg(
+     "training_step", "optimizer_idx"
+ ) and self.trainer.lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
' `training_step` is missing the `optimizer_idx` argument.'
@@ -886,7 +887,7 @@ def save_loggers_on_train_batch_end(self):
def prepare_optimizers(self):
# in manual optimization we loop over all optimizers at once
optimizers = self.get_optimizers_iterable()
- if not self.automatic_optimization:
+ if not self.trainer.lightning_module.automatic_optimization:
optimizers = [optimizers[0]]
return optimizers

@@ -896,7 +897,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):

# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
- if self.automatic_optimization and len(self.trainer.optimizers) > 1:
+ if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.toggle_optimizer(optimizer, opt_idx)

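
A quick end-to-end check that the trainer now honors the module flag: run a fast_dev_run with a manual-optimization module and read the flag back through the trainer. The sketch assumes the ManualOptModel defined in the first example and a toy dataloader of random tensors.

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Assumes ManualOptModel from the first sketch above.
model = ManualOptModel()
train_loader = DataLoader(torch.randn(64, 32), batch_size=4)

trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model, train_loader)

# The trainer reads the flag straight from the attached module.
assert trainer.lightning_module.automatic_optimization is False
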

0 comments on commit 68eac4d
