
_launch refactor and types [1/n] #7232

Merged 9 commits on Apr 28, 2021. Changes shown are from 7 of the 9 commits.
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -26,6 +26,7 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT

TBroadcast = TypeVar("T")

@@ -37,7 +38,7 @@ class TrainingTypePlugin(Plugin, ABC):

def __init__(self) -> None:
self._model = None
self._results = None
self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
self._call_configure_sharded_model_hook = True

def connect(self, model: Module) -> None:
@@ -124,12 +125,12 @@ def lightning_module(self) -> 'pl.LightningModule':
return unwrap_lightning_module(self._model)

@property
def results(self) -> Any:
def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
"""
The results of the last training/testing run will be cached here.
In distributed training, we make sure to transfer the results to the appropriate master process.
"""
# TODO: improve these docs
# TODO(@awaelchli): improve these docs
return self._results

@property
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -27,6 +27,7 @@
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT


class LoggerConnector:
@@ -267,7 +268,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self):
def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:
if not self.trainer.sanity_checking:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
8 changes: 5 additions & 3 deletions pytorch_lightning/trainer/predict_loop.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union
from typing import Any, List, Optional

import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache


@@ -31,6 +32,7 @@ def __init__(self, trainer):
self.warning_cache = WarningCache()
self.batch_indices: Optional[List[int]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
self.predictions: Optional[List[List[Any]]] = None
# `DDPSpawnPlugin` plugins and derivatives don't support returning predictions.
self._return_predictions: Optional[bool] = None
self._previous_grad_status: Optional[bool] = None
@@ -133,10 +135,10 @@ def on_predict_start(self) -> None:
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")

def on_predict_epoch_end(self) -> Optional[Union[List[Any], List[List[Any]]]]:
def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
self.trainer.profiler.describe()

results: List[List[Any]] = self.predictions
results = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)

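The `Optional` in the new `on_predict_epoch_end` return type is worth a note: nothing is returned when predictions are disabled, e.g. under `DDPSpawnPlugin`, which per the comment above does not support returning predictions. A minimal sketch of the shape logic behind `_PREDICT_OUTPUT` follows; `collect_predictions` and its parameters are illustrative names, not the PR's API:

from typing import Any, List, Optional, Union

_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]]  # alias introduced by this PR


def collect_predictions(
    predictions: List[List[Any]],  # one inner list of batch outputs per dataloader
    return_predictions: bool,
    num_dataloaders: int,
) -> Optional[_PREDICT_OUTPUT]:
    # None when predictions are not returned (hence the Optional)
    if not return_predictions:
        return None
    # a single dataloader drops the extra nesting (List[Any]);
    # multiple dataloaders keep it (List[List[Any]])
    return predictions[0] if num_dataloaders == 1 else predictions


print(collect_predictions([["a", "b"]], True, 1))    # ['a', 'b']
print(collect_predictions([["a"], ["b"]], True, 2))  # [['a'], ['b']]
print(collect_predictions([["a"]], False, 1))        # None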
112 changes: 57 additions & 55 deletions pytorch_lightning/trainer/trainer.py
@@ -63,6 +63,7 @@
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT

log = logging.getLogger(__name__)
# warnings to ignore in trainer
@@ -408,36 +409,13 @@ def __init__(
# Callback system
self.on_init_end()

def fit(
def _fit_impl(
self,
model: LightningModule,
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
r"""
Runs the full optimization routine.

Args:
datamodule: A instance of :class:`LightningDataModule`.

model: Model to fit.

train_dataloader: Either a single PyTorch DataLoader or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please
see this :ref:`page <multiple-training-dataloaders>`

val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

"""
Trainer._log_api_event("fit")
# we reuse fit for other functions. When already set, it shouldn't be modified.
if not self.state.running:
self.state = TrainerState.FITTING
if self._running_stage is None or self.tuning:
self.training = True

) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]:
# set local properties on the model
self.model_connector.copy_trainer_model_properties(model)

@@ -545,18 +523,14 @@ def dispatch(self):
else:
self.accelerator.start_training(self)

def run_stage(self):
results = None

def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
self.profile_connector.setup()

if self.evaluating:
results = self.run_evaluate()
elif self.predicting:
results = self.run_predict()
else:
self.run_train()
return results
return self.run_evaluate()
if self.predicting:
return self.run_predict()
return self.run_train()

def _pre_training_routine(self):
# wait for all to join if on distributed
@@ -586,7 +560,6 @@ def _pre_training_routine(self):
ref_model.on_pretrain_routine_end()

def run_train(self) -> None:

self._pre_training_routine()

if not self.is_global_zero and self.progress_bar_callback is not None:
@@ -660,7 +633,7 @@ def run_train(self) -> None:
self._running_stage = None
raise

def run_evaluation(self, on_epoch=False):
def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
if not (self.evaluating or self.sanity_checking):
rank_zero_warn(
f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}."
@@ -777,7 +750,7 @@ def track_output_for_epoch_end(self, outputs, output):
outputs.append(output)
return outputs

def run_evaluate(self):
def run_evaluate(self) -> _EVALUATE_OUTPUT:
if not self.is_global_zero and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()

@@ -786,9 +759,6 @@ def run_evaluate(self):
with self.profiler.profile(f"run_{self._running_stage}_evaluation"):
eval_loop_results = self.run_evaluation()

if len(eval_loop_results) == 0:
return 1

# remove the tensors from the eval results
for i, result in enumerate(eval_loop_results):
if isinstance(result, dict):
@@ -798,7 +768,7 @@ def run_evaluate(self):

return eval_loop_results

def run_predict(self):
def run_predict(self) -> Optional[_PREDICT_OUTPUT]:
# prepare dataloaders
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

@@ -860,14 +830,50 @@ def run_sanity_check(self, ref_model):
# prevents sanity check to affect random sampling in training
reset_seed()

[Inline review thread on `def fit(`]
Contributor: Don't we want to move this up? It would be nice if the entry points fit, validate, test came first.
Contributor: Agree. Let's try to organise it so it is simple to follow.
Author: I'll do this in a separate PR. Keeping them where they are for now so they are together.

def fit(
self,
model: LightningModule,
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> Optional[int]:
r"""
Runs the full optimization routine.

Args:
model: Model to fit.

train_dataloader: Either a single PyTorch DataLoader or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please
see this :ref:`page <multiple-training-dataloaders>`

val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped.

datamodule: An instance of :class:`LightningDataModule`.
"""
Trainer._log_api_event("fit")

self.state = TrainerState.FITTING
self.training = True

results = self._fit_impl(
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
)

assert self.state.stopped
self.training = False

return results

def validate(
self,
model: Optional[LightningModule] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
) -> _EVALUATE_OUTPUT:
r"""
Perform one evaluation epoch over the validation set.

@@ -914,10 +920,10 @@ def validate(
self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)

if not model_provided:
self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)

# run validate
results = self.fit(model)
results = self._fit_impl(model)

assert self.state.stopped
self.validating = False
@@ -931,7 +937,7 @@ def test(
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
) -> _EVALUATE_OUTPUT:
r"""
Perform one evaluation epoch over the test set. It's separated from
fit to make sure you never run on your test set until you want to.
@@ -975,21 +981,17 @@ def test(
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

if not model_provided:
self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)

# run test
results = self.fit(model)
results = self._fit_impl(model)

assert self.state.stopped
self.testing = False

return results

def __load_ckpt_weights(
self,
model,
ckpt_path: Optional[str] = None,
) -> Optional[str]:
def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
if ckpt_path is None:
return

@@ -1031,7 +1033,7 @@ def predict(
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
):
) -> Optional[_PREDICT_OUTPUT]:
r"""

Separated from fit to make sure you never run on your prediction set until you want to.
@@ -1072,7 +1074,7 @@ def predict(
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)

results = self.fit(model)
results = self._fit_impl(model)

assert self.state.stopped
self.predicting = False
@@ -1085,7 +1087,7 @@ def tune(
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
) -> None:
r"""
Runs routines to tune hyperparameters before training.

4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -189,7 +189,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model, **fit_kwargs)
trainer.tuner._fit(model, **fit_kwargs)
# Double in size
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
except RuntimeError as exception:
@@ -218,7 +218,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model, **fit_kwargs)
trainer.tuner._fit(model, **fit_kwargs)
count += 1
if count > max_trials:
break
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
@@ -176,7 +176,7 @@ def lr_find(
model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

# Fit, lr & loss logged in callback
trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)
trainer.tuner._fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)

# Prompt if we stopped early
if trainer.global_step != num_training:
9 changes: 8 additions & 1 deletion pytorch_lightning/tuner/tuning.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union
from typing import Any, List, Optional, Union

from torch.utils.data import DataLoader

@@ -71,6 +71,13 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):

self.trainer.state = TrainerState.FINISHED

def _fit(self, *args: Any, **kwargs: Any) -> None:
"""`_fit_impl` wrapper to set the proper state during tuning, as this can be called multiple times"""
self.trainer.state = TrainerState.TUNING # last `_fit_impl` call might have set it to `FINISHED`
self.trainer.training = True
self.trainer._fit_impl(*args, **kwargs)
self.trainer.tuning = True

def scale_batch_size(
self,
model,
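The docstring above is terse, so here is a toy sketch of the wrapper pattern `Tuner._fit` implements (a re-implementation under assumed semantics, not the PR's code): every inner fit run stops in the FINISHED state, so the tuner resets the state before each trial and restores the tuning flag after it.

from enum import Enum


class TrainerState(Enum):
    TUNING = "tuning"
    FINISHED = "finished"


class ToyTrainer:
    def __init__(self) -> None:
        self.state = TrainerState.TUNING
        self.tuning = True

    def _fit_impl(self) -> None:
        self.state = TrainerState.FINISHED  # every run stops in FINISHED


def tuner_fit(trainer: ToyTrainer) -> None:
    trainer.state = TrainerState.TUNING  # the last run may have set FINISHED
    trainer._fit_impl()
    trainer.tuning = True  # keep reporting the tuning stage


trainer = ToyTrainer()
for _ in range(3):  # e.g. three batch-size scaling trials
    tuner_fit(trainer)
    assert trainer.tuning  # the next trial resets the state again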
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/types.py
@@ -10,4 +10,6 @@
_METRIC = Union[Metric, torch.Tensor, int, float]
STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
EPOCH_OUTPUT = List[STEP_OUTPUT]
_EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader
_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]]
_PARAMETERS = Iterator[torch.nn.Parameter]
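
To close the loop on what these aliases promise callers of the newly typed entry points, here is a small sketch; the metric names and values are invented for illustration, and the predict shapes follow the `Union` above (flat for one DataLoader, nested for several):

from typing import Any, Dict, List, Union

_EVALUATE_OUTPUT = List[Dict[str, float]]            # 1 dict per DataLoader
_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]]

# what e.g. `results = trainer.test(model)` is now annotated to return
test_results: _EVALUATE_OUTPUT = [
    {"test_loss": 0.31, "test_acc": 0.91},  # dataloader 0
    {"test_loss": 0.42, "test_acc": 0.88},  # dataloader 1
]

# what `trainer.predict(model)` may return
single_dl: _PREDICT_OUTPUT = ["batch0_out", "batch1_out"]       # one DataLoader
multi_dl: _PREDICT_OUTPUT = [["dl0_b0", "dl0_b1"], ["dl1_b0"]]  # two DataLoaders

print(len(test_results), len(single_dl), len(multi_dl))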