From 8c5ae2f3dfc1fb2d4c4b7eb12162b2894751db66 Mon Sep 17 00:00:00 2001
From: Jeremy Jordan
Date: Sat, 15 Feb 2020 13:50:08 -0500
Subject: [PATCH] extract training teardown into method, catch KeyboardInterrupt

---
 pytorch_lightning/trainer/trainer.py       |   3 -
 pytorch_lightning/trainer/training_loop.py | 178 +++++++++++----------
 2 files changed, 95 insertions(+), 86 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fd29ccf65d8fa..c2c4a8ef08c43 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -903,9 +903,6 @@ def run_pretrain_routine(self, model):
 
         # CORE TRAINING LOOP
         self.train()
 
-        # summarize profile results
-        self.profiler.describe()
-
     def test(self, model=None):
         r"""

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 3e9a906016b39..845ae809e06ca 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -155,6 +155,7 @@ def training_step(self, batch, batch_idx):
 import copy
 import warnings
 from abc import ABC, abstractmethod
+import logging as log
 
 import numpy as np
 
@@ -285,90 +286,87 @@ def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
         model = self.get_model()
-        # run all epochs
-        for epoch in range(self.current_epoch, self.max_epochs):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
-                self.get_train_dataloader().sampler.set_epoch(epoch)
-            # get model
-            model = self.get_model()
-
-            # update training progress in trainer and model
-            model.current_epoch = epoch
-            self.current_epoch = epoch
-
-            total_val_batches = 0
-            is_val_epoch = False
-            if not self.disable_validation:
-                # val can be checked multiple times in epoch
-                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
-                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
-                total_val_batches = self.num_val_batches * val_checks_per_epoch
-
-            # total batches includes multiple val checks
-            self.total_batches = self.num_training_batches + total_val_batches
-            self.batch_loss_value = 0  # accumulated grads
-
-            if self.fast_dev_run:
-                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
-                num_iterations = 2
-            elif self.is_iterable_train_dataloader:
-                # for iterable train loader, the progress bar never ends
-                num_iterations = None
-            else:
-                num_iterations = self.total_batches
-
-            # reset progress bar
-            # .reset() doesn't work on disabled progress bar so we should check
-            if not self.main_progress_bar.disable:
-                self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
-            self.main_progress_bar.set_description(desc)
-
-            # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin(epoch, self)
-
-            # -----------------
-            # RUN TNG EPOCH
-            # -----------------
-            self.run_training_epoch()
-
-            # update LR schedulers
-            if self.lr_schedulers is not None:
-                for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step(epoch=self.current_epoch)
-            if self.reduce_lr_on_plateau_scheduler is not None:
-                val_loss = self.callback_metrics.get('val_loss')
-                if val_loss is None:
-                    avail_metrics = ','.join(list(self.callback_metrics.keys()))
-                    m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
-                        f'which is not available. Available metrics are: {avail_metrics}'
-                    raise MisconfigurationException(m)
-                self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)
-
-            # early stopping
-            met_min_epochs = epoch >= self.min_epochs - 1
-            if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                    (met_min_epochs or self.fast_dev_run)):
-                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
-                                                                    logs=self.callback_metrics)
-                # stop training
-                stop = should_stop and met_min_epochs
-                if stop:
-                    self.main_progress_bar.close()
-                    with self.profiler.profile('on_train_end'):
-                        model.on_train_end()
-                    return
-
-        self.main_progress_bar.close()
-
-        with self.profiler.profile('on_train_end'):
-            model.on_train_end()
-
-        if self.logger is not None:
-            self.logger.finalize("success")
+        try:
+            # run all epochs
+            for epoch in range(self.current_epoch, self.max_epochs):
+                # set seed for distributed sampler (enables shuffling for each epoch)
+                if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+                    self.get_train_dataloader().sampler.set_epoch(epoch)
+
+                # get model
+                model = self.get_model()
+
+                # update training progress in trainer and model
+                model.current_epoch = epoch
+                self.current_epoch = epoch
+
+                total_val_batches = 0
+                is_val_epoch = False
+                if not self.disable_validation:
+                    # val can be checked multiple times in epoch
+                    is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                    total_val_batches = self.num_val_batches * val_checks_per_epoch
+
+                # total batches includes multiple val checks
+                self.total_batches = self.num_training_batches + total_val_batches
+                self.batch_loss_value = 0  # accumulated grads
+
+                if self.fast_dev_run:
+                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                    num_iterations = 2
+                elif self.is_iterable_train_dataloader:
+                    # for iterable train loader, the progress bar never ends
+                    num_iterations = None
+                else:
+                    num_iterations = self.total_batches
+
+                # reset progress bar
+                # .reset() doesn't work on disabled progress bar so we should check
+                if not self.main_progress_bar.disable:
+                    self.main_progress_bar.reset(num_iterations)
+                desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+                self.main_progress_bar.set_description(desc)
+
+                # changing gradient according accumulation_scheduler
+                self.accumulation_scheduler.on_epoch_begin(epoch, self)
+
+                # -----------------
+                # RUN TNG EPOCH
+                # -----------------
+                self.run_training_epoch()
+
+                # update LR schedulers
+                if self.lr_schedulers is not None:
+                    for lr_scheduler in self.lr_schedulers:
+                        lr_scheduler.step(epoch=self.current_epoch)
+                if self.reduce_lr_on_plateau_scheduler is not None:
+                    val_loss = self.callback_metrics.get('val_loss')
+                    if val_loss is None:
+                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
+                        m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
+                            f'which is not available. Available metrics are: {avail_metrics}'
+                        raise MisconfigurationException(m)
+                    self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)
+
+                # early stopping
+                met_min_epochs = epoch >= self.min_epochs - 1
+                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
+                        (met_min_epochs or self.fast_dev_run)):
+                    should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
+                                                                        logs=self.callback_metrics)
+                    # stop training
+                    stop = should_stop and met_min_epochs
+                    if stop:
+                        self.run_training_teardown()
+                        return
+
+            self.run_training_teardown()
+        except KeyboardInterrupt:
+            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
+            self.run_training_teardown()
 
     def run_training_epoch(self):
         # before epoch hook
@@ -574,6 +572,20 @@ def optimizer_closure():
 
         return 0, grad_norm_dic, all_log_metrics
 
+    def run_training_teardown(self):
+        model = self.get_model()
+
+        self.main_progress_bar.close()
+
+        with self.profiler.profile('on_train_end'):
+            model.on_train_end()
+
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
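
Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained illustration of the control flow this change introduces, where every exit path (early stop, normal completion, Ctrl-C) funnels through one teardown method so the progress bar close, on_train_end hook, logger finalization and profiler summary run exactly once. MiniTrainer, run_epoch, run_teardown and should_stop_early are hypothetical stand-ins for illustration only, not the pytorch_lightning API.

    import logging as log


    class MiniTrainer:
        """Toy trainer showing the teardown-on-exit pattern used in this patch."""

        def __init__(self, max_epochs=3):
            self.max_epochs = max_epochs
            self.teardown_ran = False

        def run_epoch(self, epoch):
            # placeholder for one epoch of optimization
            log.info('epoch %d finished', epoch)

        def should_stop_early(self, epoch):
            # placeholder early-stopping decision
            return False

        def run_teardown(self):
            # single place to close progress bars, call on_train_end,
            # finalize loggers and summarize the profiler
            self.teardown_ran = True
            log.info('teardown complete')

        def train(self):
            try:
                for epoch in range(self.max_epochs):
                    self.run_epoch(epoch)
                    if self.should_stop_early(epoch):
                        self.run_teardown()
                        return
                self.run_teardown()
            except KeyboardInterrupt:
                # Ctrl-C no longer aborts training without cleanup
                log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
                self.run_teardown()


    if __name__ == '__main__':
        log.basicConfig(level=log.INFO)
        MiniTrainer(max_epochs=2).train()

Catching KeyboardInterrupt at the level of train() rather than inside the per-epoch loop keeps the handler in one place, which is why the patch can delete the duplicated close/on_train_end/finalize code from the early-stop branch and the end of train().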