Commit

more refac

rohitgr7 committed Jan 8, 2021
1 parent b8554a1 commit 469fd00

Showing 3 changed files with 6 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -44,7 +44,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False

-    def restore_weights(self, model: LightningModule) -> None:
+    def restore_weights(self) -> None:
         """
         Attempt to restore a checkpoint (e.g. weights) in this priority:
         1. from HPC weights
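
With the `model` argument gone, the connector presumably looks the model up through the trainer reference it already holds (it is constructed with one, per `__init__(self, trainer)` above). A minimal sketch of the new shape, assuming a `trainer.get_model()` accessor; the toy classes here are illustrative stand-ins, not the real pytorch_lightning implementations:

    from torch import nn


    class Trainer:
        # Hypothetical stub; the real pytorch_lightning.Trainer is far larger.
        def __init__(self, model: nn.Module):
            self.model = model

        def get_model(self) -> nn.Module:
            return self.model


    class CheckpointConnector:
        def __init__(self, trainer: Trainer):
            self.trainer = trainer

        def restore_weights(self) -> None:
            # After this commit, the connector fetches the model itself
            # instead of receiving it as a parameter.
            model = self.trainer.get_model()
            print(f"would restore weights into {type(model).__name__}")


    CheckpointConnector(Trainer(nn.Linear(4, 2))).restore_weights()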
6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/trainer.py

@@ -421,16 +421,14 @@ def setup_trainer(self, model: LightningModule):
         # --------------------------
         # Setup??
         # --------------------------
-        ref_model = model
-        if self.data_parallel:
-            ref_model = model.module
+        ref_model = self.get_model()

         # set the ranks and devices
         self.accelerator_backend.dist.rank = self.global_rank
         self.accelerator_backend.dist.device = ref_model.device

         # set local properties on the model
-        self.model_connector.copy_trainer_model_properties(ref_model)
+        self.model_connector.copy_trainer_model_properties(model)

         # init amp. Must be done here instead of __init__ to allow ddp to work
         if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
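
The deleted branch shows exactly what `get_model()` centralizes: peeling the LightningModule out of a data-parallel wrapper. A minimal sketch of that unwrapping rule, assuming the standard torch wrappers; `unwrap_model` is a hypothetical name, and the real body of `Trainer.get_model()` is not shown in this diff:

    from torch import nn
    from torch.nn.parallel import DataParallel, DistributedDataParallel


    def unwrap_model(model: nn.Module) -> nn.Module:
        # Mirrors the deleted `if self.data_parallel: ref_model = model.module`:
        # data-parallel wrappers keep the original module at `.module`.
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return model.module
        return model

Keeping the unwrapping rule behind one helper means a new wrapper type later touches a single place instead of every call site, which is presumably the point of the refactor.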
9 changes: 3 additions & 6 deletions pytorch_lightning/trainer/training_loop.py

@@ -128,14 +128,11 @@ def setup_training(self):
         """
         Sanity check a few things before starting actual training.
         """
-        model = self.trainer.model
-        ref_model = model
-        if self.trainer.data_parallel:
-            ref_model = model.module
-
         # --------------------------
         # Pre-train
         # --------------------------
+        ref_model = self.trainer.get_model()
+
         # on pretrain routine start
         self.trainer.on_pretrain_routine_start(ref_model)
         if self.trainer.is_function_implemented("on_pretrain_routine_start"):
@@ -146,7 +143,7 @@ def setup_training(self):
             ref_model.summarize(mode=self.trainer.weights_summary)

         # restore training state and model weights before hpc is called
-        self.trainer.checkpoint_connector.restore_weights(model)
+        self.trainer.checkpoint_connector.restore_weights()

         # on pretrain routine end
         self.trainer.on_pretrain_routine_end(ref_model)
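
Put together, the call-site pattern this commit removes and the one it converges on look roughly like this; `ToyTrainer` is a hypothetical stand-in that mimics only the attributes the diff uses:

    from torch import nn
    from torch.nn.parallel import DataParallel


    class ToyTrainer:
        # Hypothetical stand-in mirroring just the relevant attributes.
        def __init__(self, model: nn.Module):
            self.model = model

        @property
        def data_parallel(self) -> bool:
            return isinstance(self.model, DataParallel)

        def get_model(self) -> nn.Module:
            return self.model.module if self.data_parallel else self.model


    trainer = ToyTrainer(DataParallel(nn.Linear(4, 2)))

    # Before: every caller unwrapped the model by hand.
    model = trainer.model
    ref_model = model.module if trainer.data_parallel else model

    # After: a single helper call, and restore_weights() takes no model at all.
    assert trainer.get_model() is ref_model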
