From 5b49a5761379034f12755f981fed3a6e711d88c7 Mon Sep 17 00:00:00 2001 From: Vedanuj Goswami Date: Mon, 6 May 2019 11:42:14 -0700 Subject: [PATCH 1/2] Fix pth filepath and save state_dict for model only during finalize --- pythia/utils/checkpoint.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pythia/utils/checkpoint.py b/pythia/utils/checkpoint.py index a5ceb7545..26637d368 100644 --- a/pythia/utils/checkpoint.py +++ b/pythia/utils/checkpoint.py @@ -36,15 +36,14 @@ def __init__(self, trainer): self.config["log_foldername"] = self.ckpt_foldername self.ckpt_foldername = os.path.join(self.save_dir, self.ckpt_foldername) self.pth_filepath = os.path.join( - self.save_dir, self.ckpt_foldername, + self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_final.pth" ) self.params_filepath = os.path.join( - self.save_dir, self.ckpt_foldername, + self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_params.pth" ) - self.models_foldername = os.path.join(self.ckpt_foldername, "models") if not os.path.exists(self.models_foldername): os.makedirs(self.models_foldername) @@ -241,5 +240,4 @@ def restore(self): self.trainer.optimizer.load_state_dict(ckpt["optimizer"]) def finalize(self): - torch.save(self.trainer.model, self.pth_filepath) - torch.save(self.trainer.model.state_dict(), self.params_filepath) + torch.save(self.trainer.model.state_dict(), self.pth_filepath) From 1b526f69ad13605e394d35b1e4c18c9cc538e9e8 Mon Sep 17 00:00:00 2001 From: Vedanuj Goswami Date: Wed, 8 May 2019 11:14:33 -0700 Subject: [PATCH 2/2] Remove unnecessary variable and rebase --- pythia/utils/checkpoint.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pythia/utils/checkpoint.py b/pythia/utils/checkpoint.py index 26637d368..e1d92b0dc 100644 --- a/pythia/utils/checkpoint.py +++ b/pythia/utils/checkpoint.py @@ -36,12 +36,7 @@ def __init__(self, trainer): self.config["log_foldername"] = self.ckpt_foldername self.ckpt_foldername = os.path.join(self.save_dir, self.ckpt_foldername) self.pth_filepath = os.path.join( - self.ckpt_foldername, - self.ckpt_prefix + self.model_name + "_final.pth" - ) - self.params_filepath = os.path.join( - self.ckpt_foldername, - self.ckpt_prefix + self.model_name + "_params.pth" + self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_final.pth" ) self.models_foldername = os.path.join(self.ckpt_foldername, "models")