From 4bbcfa04a3f9238cc8f9bb63ba041a1bc478ed6d Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 9 Jul 2020 11:36:21 -0400
Subject: [PATCH] .fit() returns last not best weights in ddp_spawn (#2565)

* added base tests for tpu

* added base tests for tpu

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint
---
 .../trainer/distrib_data_parallel.py | 28 +++++++++++++++++--
 pytorch_lightning/trainer/trainer.py | 13 ++++++---
 tests/models/test_test_loop.py       | 21 ++++++++++++++
 3 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 5b8d79e47564f..3eef40e1ab02a 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
     num_nodes: int
     node_rank: int
     tpu_cores: int
+    testing: bool

     @property
     @abstractmethod
@@ -555,15 +556,35 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         # continue training routine
         results = self.run_pretrain_routine(model)

+        # persist info in ddp_spawn
+        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
+
         # clean up memory
         torch.cuda.empty_cache()

+        if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return results
+
+    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
+        if self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.checkpoint_callback is not None:
+            best_model_path = self.checkpoint_callback.best_model_path
+
         if self.global_rank == 0 and q is not None:
-            q.put(self.checkpoint_callback.best_model_path)
+            rank_zero_warn('cleaning up ddp environment...')
+            q.put(best_model_path)
             q.put(results)

-        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
-            return results
+            # save the last weights
+            last_path = None
+            if not self.testing:
+                last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
+                torch.save(model.state_dict(), last_path)
+            q.put(last_path)

     def save_spawn_weights(self, model):
         """
@@ -574,6 +595,7 @@ def save_spawn_weights(self, model):
         if self.is_global_zero:
             path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
             self.save_checkpoint(path)
+            return path

     def load_spawn_weights(self, original_model):
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index eec21752912b0..c104ab8b2f78d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
 import warnings

-# warnings to ignore
+# warnings to ignore in trainer
 warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
                                           'please use torch.distributed.ReduceOp instead')

@@ -1063,9 +1063,14 @@ def __run_ddp_spawn(self, model, nprocs):
         # restore main state with best weights
         best_path = q.get()
         results = q.get()
-        if best_path is not None and len(best_path) > 0:
-            self.checkpoint_callback.best_model_path = best_path
-            model.load_from_checkpoint(best_path)
+        last_path = q.get()
+
+        # transfer back the best path to the trainer
+        self.checkpoint_callback.best_model_path = best_path
+
+        # load last weights
+        if last_path is not None and not self.testing:
+            torch.load(last_path, map_location=lambda storage, loc: storage)

         self.model = model
         return results
diff --git a/tests/models/test_test_loop.py b/tests/models/test_test_loop.py
index 141567e465b44..89103116bd8f3 100644
--- a/tests/models/test_test_loop.py
+++ b/tests/models/test_test_loop.py
@@ -23,9 +23,16 @@ def test_single_gpu_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results

+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results

+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_dp_test(tmpdir):
@@ -45,9 +52,16 @@ def test_dp_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results

+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results

+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_spawn_test(tmpdir):
@@ -67,5 +81,12 @@ def test_ddp_spawn_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results

+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
+
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
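
The following is a simplified, self-contained sketch (not part of the patch and not Lightning internals) of the queue-based hand-off this change builds into ddp_spawn: rank 0 saves the last weights to a temporary checkpoint and pushes the best checkpoint path, the results, and the last-weights path through a queue, and the parent process restores those last weights into its own copy of the model, which is what the commit title ".fit() returns last not best weights" is after. The worker body, the temp-dir handling, and the explicit load_state_dict call here are illustrative stand-ins.

# hand_off_sketch.py -- runs on CPU with plain PyTorch
import os
import tempfile

import torch
import torch.multiprocessing as mp


def worker(rank, q, model, tmpdir):
    # stand-in for the real training loop executed inside each spawned process
    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)

    if rank == 0 and q is not None:
        best_path = None          # would come from the checkpoint callback
        results = {'loss': 0.0}   # would come from the training routine
        last_path = os.path.join(tmpdir, '__temp_weight_ddp_end.ckpt')
        torch.save(model.state_dict(), last_path)
        q.put(best_path)
        q.put(results)
        q.put(last_path)


if __name__ == '__main__':
    model = torch.nn.Linear(2, 2)
    tmpdir = tempfile.mkdtemp()

    smp = mp.get_context('spawn')
    q = smp.SimpleQueue()
    mp.spawn(worker, nprocs=2, args=(q, model, tmpdir))

    best_path = q.get()
    results = q.get()
    last_path = q.get()

    # the parent's model never trained (spawned processes work on their own copies),
    # so reload the last weights that rank 0 saved before the processes exited
    if last_path is not None:
        ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt)

After the hand-off, the parent can keep using the last-epoch weights in memory while the best checkpoint path (here best_path, in the patch trainer.checkpoint_callback.best_model_path) remains available for explicitly reloading the best weights when those are wanted instead.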