
Commit

_launch refactor and types [1/n] (#7232)
carmocca authored Apr 28, 2021
1 parent 0c6c078 commit bdc4272
Showing 8 changed files with 86 additions and 69 deletions.
@@ -26,6 +26,7 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT

TBroadcast = TypeVar("T")

@@ -37,7 +38,7 @@ class TrainingTypePlugin(Plugin, ABC):

def __init__(self) -> None:
self._model = None
self._results = None
self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
self._call_configure_sharded_model_hook = True

def connect(self, model: Module) -> None:
@@ -124,12 +125,12 @@ def lightning_module(self) -> 'pl.LightningModule':
return unwrap_lightning_module(self._model)

@property
def results(self) -> Any:
def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
"""
The results of the last training/testing run will be cached here.
In distributed training, we make sure to transfer the results to the appropriate master process.
"""
# TODO: improve these docs
# TODO(@awaelchli): improve these docs
return self._results

@property
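Note: the `pytorch_lightning/utilities/types.py` module that defines these aliases is not shown in this excerpt. Judging from how they are used here (one metrics dict per evaluation dataloader, and per-batch predictions that may be nested per dataloader), their definitions are presumably along these lines (a sketch, not the verbatim file):

from typing import Any, Dict, List, Union

_EVALUATE_OUTPUT = List[Dict[str, float]]  # one metrics dict per DataLoader
_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]]  # per-batch outputs, nested per DataLoader when there are several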
@@ -27,6 +27,7 @@
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT


class LoggerConnector:
@@ -267,7 +268,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self):
def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:
if not self.trainer.sanity_checking:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
8 changes: 5 additions & 3 deletions pytorch_lightning/trainer/predict_loop.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union
from typing import Any, List, Optional

import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache


@@ -31,6 +32,7 @@ def __init__(self, trainer):
self.warning_cache = WarningCache()
self.batch_indices: Optional[List[int]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
self.predictions: Optional[List[List[Any]]] = None
# `DDPSpawnPlugin` and derived plugins don't support returning predictions.
self._return_predictions: Optional[bool] = None
self._previous_grad_status: Optional[bool] = None
@@ -138,10 +140,10 @@ def on_predict_start(self) -> None:
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")

def on_predict_epoch_end(self) -> Optional[Union[List[Any], List[List[Any]]]]:
def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
self.trainer.profiler.describe()

results: List[List[Any]] = self.predictions
results = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)

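To make the new `Optional[_PREDICT_OUTPUT]` return type concrete, a minimal usage sketch (the model and dataloader names here are hypothetical):

predictions = trainer.predict(model, dataloaders=test_loader)            # List[Any]: roughly one entry per batch
predictions = trainer.predict(model, dataloaders=[loader_a, loader_b])   # List[List[Any]]: one inner list per dataloader
predictions = trainer.predict(model, dataloaders=test_loader, return_predictions=False)  # None when predictions are not returned (e.g. with `DDPSpawnPlugin`)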
118 changes: 60 additions & 58 deletions pytorch_lightning/trainer/trainer.py
@@ -63,6 +63,7 @@
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT

log = logging.getLogger(__name__)
# warnings to ignore in trainer
@@ -408,36 +409,13 @@ def __init__(
# Callback system
self.on_init_end()

def fit(
def _launch(
self,
model: LightningModule,
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
r"""
Runs the full optimization routine.
Args:
datamodule: A instance of :class:`LightningDataModule`.
model: Model to fit.
train_dataloader: Either a single PyTorch DataLoader or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please
see this :ref:`page <multiple-training-dataloaders>`
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
"""
Trainer._log_api_event("fit")
# we reuse fit for other functions. When already set, it shouldn't be modified.
if not self.state.running:
self.state = TrainerState.FITTING
if self._running_stage is None or self.tuning:
self.training = True

) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]:
# set local properties on the model
self.model_connector.copy_trainer_model_properties(model)

@@ -545,18 +523,14 @@ def dispatch(self):
else:
self.accelerator.start_training(self)

def run_stage(self):
results = None

def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
self.profile_connector.setup()

if self.evaluating:
results = self.run_evaluate()
elif self.predicting:
results = self.run_predict()
else:
self.run_train()
return results
return self.run_evaluate()
if self.predicting:
return self.run_predict()
return self.run_train()

def _pre_training_routine(self):
# wait for all to join if on distributed
@@ -586,7 +560,6 @@ def _pre_training_routine(self):
ref_model.on_pretrain_routine_end()

def run_train(self) -> None:

self._pre_training_routine()

if not self.is_global_zero and self.progress_bar_callback is not None:
@@ -660,7 +633,7 @@ def run_train(self) -> None:
self._running_stage = None
raise

def run_evaluation(self, on_epoch=False):
def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
if not (self.evaluating or self.sanity_checking):
rank_zero_warn(
f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}."
@@ -777,7 +750,7 @@ def track_output_for_epoch_end(self, outputs, output):
outputs.append(output)
return outputs

def run_evaluate(self):
def run_evaluate(self) -> _EVALUATE_OUTPUT:
if not self.is_global_zero and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()

with self.profiler.profile(f"run_{self._running_stage}_evaluation"):
eval_loop_results = self.run_evaluation()

if len(eval_loop_results) == 0:
return 1

# remove the tensors from the eval results
for i, result in enumerate(eval_loop_results):
if isinstance(result, dict):

return eval_loop_results

def run_predict(self):
def run_predict(self) -> Optional[_PREDICT_OUTPUT]:
# prepare dataloaders
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

@@ -860,14 +830,50 @@ def run_sanity_check(self, ref_model):
# prevents sanity check to affect random sampling in training
reset_seed()

def fit(
self,
model: LightningModule,
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> Optional[int]:
r"""
Runs the full optimization routine.
Args:
model: Model to fit.
train_dataloader: Either a single PyTorch DataLoader or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please
see this :ref:`page <multiple-training-dataloaders>`
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
datamodule: An instance of :class:`LightningDataModule`.
"""
Trainer._log_api_event("fit")

self.state = TrainerState.FITTING
self.training = True

results = self._launch(
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
)

assert self.state.stopped
self.training = False

return results

def validate(
self,
model: Optional[LightningModule] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
) -> _EVALUATE_OUTPUT:
r"""
Perform one evaluation epoch over the validation set.
@@ -914,10 +920,10 @@ def validate(
self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)

if not model_provided:
self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)

# run validate
results = self.fit(model)
results = self._launch(model)

assert self.state.stopped
self.validating = False
@@ -931,7 +937,7 @@ def test(
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
) -> _EVALUATE_OUTPUT:
r"""
Perform one evaluation epoch over the test set. It's separated from
fit to make sure you never run on your test set until you want to.
@@ -975,21 +981,17 @@ def test(
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

if not model_provided:
self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)

# run test
results = self.fit(model)
results = self._launch(model)

assert self.state.stopped
self.testing = False

return results

def __load_ckpt_weights(
self,
model,
ckpt_path: Optional[str] = None,
) -> Optional[str]:
def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
if ckpt_path is None:
return

@@ -1031,15 +1033,17 @@ def predict(
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
):
) -> Optional[_PREDICT_OUTPUT]:
r"""
Separates from fit to make sure you never run on your predictions set until you want to.
This will call the model forward function to compute predictions.
Args:
model: The model to predict with.
dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples.
datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders.
return_predictions: Whether to return predictions.
@@ -1063,16 +1067,14 @@ def predict(
self.predicting = True

if dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)
raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`')

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)

results = self.fit(model)
results = self._launch(model)

assert self.state.stopped
self.predicting = False
@@ -1085,7 +1087,7 @@ def tune(
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
) -> None:
r"""
Runs routines to tune hyperparameters before training.
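Taken together, the hunks above converge `fit`, `validate`, `test`, and `predict` onto the same internal entry point. The shared shape, summarized from the diff (a simplified sketch, not the literal code of any one method):

def entry_point(self, model, *args, **kwargs):
    Trainer._log_api_event("fit")              # or "validate" / "test" / "predict"
    self.state = TrainerState.FITTING          # or VALIDATING / TESTING / PREDICTING
    self.training = True                       # or the matching stage flag
    results = self._launch(model, *args, **kwargs)
    assert self.state.stopped
    self.training = False
    return results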
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -189,7 +189,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model, **fit_kwargs)
trainer.tuner._launch(model, **fit_kwargs)
# Double in size
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
except RuntimeError as exception:
@@ -218,7 +218,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model, **fit_kwargs)
trainer.tuner._launch(model, **fit_kwargs)
count += 1
if count > max_trials:
break
4 changes: 3 additions & 1 deletion pytorch_lightning/tuner/lr_finder.py
@@ -176,7 +176,9 @@ def lr_find(
model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

# Fit, lr & loss logged in callback
trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)
trainer.tuner._launch(
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
)

# Prompt if we stopped early
if trainer.global_step != num_training:
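Both tuner helpers now route through `trainer.tuner._launch(...)` instead of the public `trainer.fit(...)`. The `Tuner` side of this change is not shown in this excerpt; presumably it gains a thin wrapper roughly like the following (hypothetical sketch; the actual method may differ, e.g. by also setting the tuning state):

class Tuner:
    def __init__(self, trainer):
        self.trainer = trainer

    def _launch(self, *args, **kwargs):
        # hypothetical forwarding to the Trainer's internal entry point
        return self.trainer._launch(*args, **kwargs)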