Enforce Lightning module as source of truth for automatic optimization #7130

Merged (10 commits) on Apr 26, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -190,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))


- Removed evaluation loop legacy returns for `*_epoch_end` hooks ([#6973](https://github.com/PyTorchLightning/pytorch-lightning/pull/6973))


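For context, not part of the diff: a minimal sketch of the user-facing contract this entry describes. The flag is declared on the `LightningModule` itself, and manual optimization is driven from `training_step`. The model, shapes, and loss below are made up for illustration; `automatic_optimization`, `self.optimizers()`, and `manual_backward()` are existing LightningModule API.

```python
import torch
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):
    """Minimal sketch: the module, not the trainer, declares the optimization mode."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        # Single source of truth enforced by this PR.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        # manual_backward() routes through LightningModule.backward(), which now
        # checks self.automatic_optimization instead of the training loop's copy.
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```
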
6 changes: 3 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -1250,7 +1250,7 @@ def backward(self, loss, optimizer, optimizer_idx):
loss.backward()

"""
-if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
+if self.automatic_optimization or self._running_manual_backward:
loss.backward(*args, **kwargs)

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
@@ -1539,7 +1539,7 @@ def get_progress_bar_dict(self):
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
-elif self.trainer.train_loop.automatic_optimization:
+elif self.automatic_optimization:
avg_training_loss = float('NaN')

tqdm_dict = {}
@@ -1558,7 +1558,7 @@ def get_progress_bar_dict(self):
return tqdm_dict

def _verify_is_manual_optimization(self, fn_name):
-if self.trainer.train_loop.automatic_optimization:
+if self.automatic_optimization:
raise MisconfigurationException(
f'to use {fn_name}, please disable automatic optimization:'
' set model property `automatic_optimization` as False'
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/configuration_validator.py
@@ -78,7 +78,7 @@ def __verify_train_loop_configuration(self, model):

trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)
-automatic_optimization = trainer.train_loop.automatic_optimization
+automatic_optimization = model.automatic_optimization
going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
@@ -358,7 +358,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
return epoch_log_metrics, epoch_progress_bar_metrics

def log_train_step_metrics(self, batch_output):
-if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
+if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
return
_, batch_log_metrics = self.cached_results.update_logger_connector()
# when metrics should be logged
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -27,9 +27,6 @@ def __init__(self, trainer):
def copy_trainer_model_properties(self, model):
ref_model = self.trainer.lightning_module or model

-automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
-self.trainer.train_loop.automatic_optimization = automatic_optimization

for m in [model, ref_model]:
m.trainer = proxy(self.trainer)
m._device_type = str(self.trainer._device_type)
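
With the trainer-side mirroring above removed, code outside the module should read the flag from `trainer.lightning_module`. A hypothetical callback, purely illustrative, showing where the value now lives:

```python
import pytorch_lightning as pl


class OptimizationModeLogger(pl.Callback):
    """Hypothetical callback: reads the flag from the module, the single source of truth."""

    def on_train_start(self, trainer, pl_module):
        # Previously this was also mirrored onto trainer.train_loop; after this PR
        # only the LightningModule (trainer.lightning_module) carries it.
        mode = "automatic" if pl_module.automatic_optimization else "manual"
        print(f"Training with {mode} optimization")
```
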
@@ -32,7 +32,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
interval: either 'epoch' or 'step'.
monitor_metrics: dict of possible values to monitor
"""
-if not self.trainer.lr_schedulers or not self.trainer.train_loop.automatic_optimization:
+if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
return

for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/optimizers.py
@@ -80,7 +80,7 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)

-is_manual_optimization = not self.train_loop.automatic_optimization
+is_manual_optimization = not model.automatic_optimization
lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)

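Because `is_manual_optimization` is now derived from the model, schedulers returned by `configure_optimizers` are still collected, but, as the `update_learning_rates` change above shows, they are not stepped automatically in manual mode. A sketch, under the assumption that the `lr_schedulers()` helper is available in this release:

```python
import torch
import pytorch_lightning as pl


class ManualSchedulerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self.automatic_optimization = False

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
        return [opt], [sch]

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # update_learning_rates() returns early in manual mode, so the user decides
        # when (and how often) to step the scheduler.
        self.lr_schedulers().step()
```
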
23 changes: 12 additions & 11 deletions pytorch_lightning/trainer/training_loop.py
@@ -42,7 +42,6 @@ def __init__(self, trainer, multiple_trainloader_mode: str):
self.warning_cache = WarningCache()
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)
-self.automatic_optimization = True
self._curr_step_result = None
self._cur_grad_norm_dict = None
self._multiple_trainloader_mode = multiple_trainloader_mode
@@ -255,7 +254,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

def _check_training_step_output(self, training_step_output):
-if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
+if isinstance(training_step_output, torch.Tensor) and not self.trainer.lightning_module.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
@@ -290,7 +289,7 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
closure_loss = None
untouched_loss = None

-if self.automatic_optimization:
+if self.trainer.lightning_module.automatic_optimization:
# accumulate loss. if accumulate_grad_batches==1, no effect
closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches

@@ -660,7 +659,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# gradient update with accumulated gradients

else:
-if self.automatic_optimization:
+if self.trainer.lightning_module.automatic_optimization:

def train_step_and_backward_closure():
result = self.training_step_and_backward(
@@ -717,7 +716,7 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
"""
if (
isinstance(self.trainer.training_type_plugin, ParallelPlugin)
-and (self.automatic_optimization or should_block_sync)
+and (self.trainer.lightning_module.automatic_optimization or should_block_sync)
):
with self.trainer.training_type_plugin.block_backward_sync():
yield None
@@ -740,7 +739,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)

-if self.automatic_optimization:
+if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(opt_closure_result.loss)

@@ -755,7 +754,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
self._curr_step_result = result

-if not self._skip_backward and self.automatic_optimization:
+if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

if is_first_batch_to_accumulate:
@@ -858,14 +857,16 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):

if len(self.trainer.optimizers) > 1:
if self.trainer.has_arg("training_step", "optimizer_idx"):
-if not self.automatic_optimization:
+if not self.trainer.lightning_module.automatic_optimization:
self.warning_cache.warn(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5", DeprecationWarning
)
args.append(opt_idx)
-elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
+elif not self.trainer.has_arg(
+    "training_step", "optimizer_idx"
+) and self.trainer.lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
' `training_step` is missing the `optimizer_idx` argument.'
@@ -886,7 +887,7 @@ def save_loggers_on_train_batch_end(self):
def prepare_optimizers(self):
# in manual optimization we loop over all optimizers at once
optimizers = self.get_optimizers_iterable()
-if not self.automatic_optimization:
+if not self.trainer.lightning_module.automatic_optimization:
optimizers = [optimizers[0]]
return optimizers

@@ -896,7 +897,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):

# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
-if self.automatic_optimization and len(self.trainer.optimizers) > 1:
+if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.toggle_optimizer(optimizer, opt_idx)

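A short usage sketch tying the changes together, reusing the hypothetical `ManualOptimModel` from the sketch near the CHANGELOG entry above; the assertions reflect this PR's intent that only the module carries the flag:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical end-to-end check under this PR's behavior.
model = ManualOptimModel()  # defined in the sketch near the CHANGELOG entry
train_data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)

trainer = pl.Trainer(max_epochs=1, limit_train_batches=4)
trainer.fit(model, train_data)

# The LightningModule is the only carrier of the flag.
assert trainer.lightning_module.automatic_optimization is False
# The training loop no longer defines its own copy (removed in this PR).
assert not hasattr(trainer.train_loop, "automatic_optimization")
```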