[feat] 1/2 Add trainer.predict #5579

Merged: 60 commits from feat_predict into release/1.2-dev, Jan 27, 2021
Changes from 39 commits

Commits (60)
9babf3f
start adding predict
tchaton Jan 19, 2021
f6261fa
add predict
tchaton Jan 19, 2021
86aa7d4
resolve test
tchaton Jan 19, 2021
4fb75d7
add predict
tchaton Jan 20, 2021
d3c9130
Merge branch 'feat_predict' of https://github.com/PyTorchLightning/py…
tchaton Jan 20, 2021
2d7ee29
remove limit_predict
tchaton Jan 20, 2021
8b8d974
update
tchaton Jan 20, 2021
a8415c5
add test for predict
tchaton Jan 20, 2021
c59a17b
typo
tchaton Jan 20, 2021
742e01a
Merge branch 'release/1.2-dev' into feat_predict
tchaton Jan 20, 2021
4bfbef1
update on comments
tchaton Jan 20, 2021
5184e56
remove predict_step
tchaton Jan 20, 2021
4d5f57d
update ddp_shareded
tchaton Jan 20, 2021
7fa90c8
check ddp_sharded
tchaton Jan 20, 2021
392e04f
resolve on comments
tchaton Jan 20, 2021
036e24d
resolve isort
tchaton Jan 20, 2021
4b5525b
update dp
tchaton Jan 20, 2021
97fa5b3
add test dp 1 gpu
tchaton Jan 20, 2021
b6ed163
made default forward
tchaton Jan 20, 2021
bb68b9e
resolve path
tchaton Jan 20, 2021
dc6735d
resolve bug
tchaton Jan 20, 2021
cf38143
update on comments
tchaton Jan 20, 2021
6de15f3
resolve doc
tchaton Jan 20, 2021
147bdd7
resolve bug
tchaton Jan 20, 2021
895a4ba
update
tchaton Jan 20, 2021
45ebba9
resolve bug
tchaton Jan 20, 2021
e58c4e7
update on comments
tchaton Jan 20, 2021
563467f
Merge branch 'feat_predict' of https://github.com/PyTorchLightning/py…
tchaton Jan 20, 2021
6b4f76f
resolve pep8
tchaton Jan 20, 2021
f0bdbd3
update test doc
tchaton Jan 20, 2021
0a2efb2
update on comments
tchaton Jan 20, 2021
992f360
solve special tests
tchaton Jan 21, 2021
a366c2c
resolve bug
tchaton Jan 21, 2021
0efa4df
Merge branch 'release/1.2-dev' into feat_predict
tchaton Jan 26, 2021
7b2a4e7
Merge branch 'release/1.2-dev' into feat_predict
tchaton Jan 26, 2021
9137b16
resolve flake8
tchaton Jan 26, 2021
b4b860f
Update pytorch_lightning/callbacks/progress.py
tchaton Jan 26, 2021
7526efe
Update pytorch_lightning/trainer/trainer.py
tchaton Jan 26, 2021
3fbf983
Merge branch 'release/1.2-dev' into feat_predict
tchaton Jan 26, 2021
dca009d
add predict to LightningModule
tchaton Jan 26, 2021
1b1a8aa
missing predict
tchaton Jan 26, 2021
21710f9
typo
tchaton Jan 26, 2021
4752bd7
rename is_prediction to _predicting
tchaton Jan 26, 2021
43442a0
add
tchaton Jan 26, 2021
a448657
Merge branch 'release/1.2-dev' into feat_predict
tchaton Jan 26, 2021
e7fd0d6
update
tchaton Jan 26, 2021
4a1ce07
Merge branch 'feat_predict' of https://github.com/PyTorchLightning/py…
tchaton Jan 26, 2021
d3bf981
Merge branch 'release/1.2-dev' into feat_predict
mergify[bot] Jan 26, 2021
8765a8c
Merge branch 'release/1.2-dev' into feat_predict
mergify[bot] Jan 26, 2021
62500e0
Merge branch 'release/1.2-dev' into feat_predict
mergify[bot] Jan 26, 2021
4813470
Merge branch 'feat_predict' of https://github.com/PyTorchLightning/py…
tchaton Jan 26, 2021
5a3a110
update
tchaton Jan 26, 2021
bd5c4c5
update doc
tchaton Jan 26, 2021
c59578b
Merge branch 'release/1.2-dev' into feat_predict
mergify[bot] Jan 26, 2021
56ab0c7
Merge branch 'release/0.2-dev' of https://github.com/PyTorchLightning…
tchaton Jan 27, 2021
5aa890b
Merge branch 'release/1.2-dev' into feat_predict
mergify[bot] Jan 27, 2021
a290719
chlog
Borda Jan 27, 2021
30da4da
Apply suggestions from code review
Borda Jan 27, 2021
1cce4f9
Branch was auto-updated.
github-actions[bot] Jan 27, 2021
07d111d
Merge branch 'release/1.2-dev' of https://github.com/PyTorchLightning…
tchaton Jan 27, 2021
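For orientation before the diffs: a minimal sketch of how the prediction entry point added in this PR might be used. The exact trainer.predict signature and return value are assumptions inferred from the changes below (batches are routed to LightningModule.forward), not the final documented API.

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, batch):
        # during prediction, each batch is dispatched to forward()
        return self.layer(batch)


# a plain tensor works as a map-style dataset for this sketch
dataloader = DataLoader(torch.randn(64, 32), batch_size=8)

trainer = pl.Trainer()
# assumed call shape: predict(model, dataloaders), returning per-batch outputs
predictions = trainer.predict(LitModel(), dataloader)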
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -79,6 +79,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def forward(self, args):
return self._step(self.trainer.model.forward, args)

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -66,6 +66,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def forward(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -164,6 +164,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def forward(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -180,6 +180,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def forward(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -83,6 +83,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def forward(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -214,6 +214,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def forward(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -134,6 +134,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def forward(self, args):
return self._step(args)

def training_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -87,6 +87,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def forward(self, args):
return self._step(self.trainer.model.forward, args)

def to_device(self, batch):
gpu_id = 0
if isinstance(self.trainer.data_parallel_device_ids, list):
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -136,6 +136,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def forward(self, args):
return self._step(self.trainer.model.forward, args)

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
optimizer.synchronize()
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -159,6 +159,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def forward(self, args):
return self._step(self.trainer.model.forward, args)

def process_dataloader(self, dataloader):
device = xm.xla_device(self.trainer.tpu_id)
dataloader = xla_pl.ParallelLoader(dataloader, [device])
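Every accelerator above receives the same three-line hook: a forward(self, args) method that reuses the accelerator's existing step plumbing but dispatches to the LightningModule's forward instead of test_step. A condensed sketch of the shared pattern (device placement, AMP contexts and DDP's on_before_forward handling are omitted here, and the real _step implementations differ per accelerator):

class SketchAccelerator:
    """Illustrative skeleton of the pattern repeated across the accelerators above."""

    def __init__(self, trainer):
        self.trainer = trainer

    def _step(self, model_step, args):
        # real accelerators move args to the right device and wrap this in precision contexts
        return model_step(*args)

    def test_step(self, args):
        return self._step(self.trainer.model.test_step, args)

    def forward(self, args):
        # new in this PR: prediction calls the module's forward through the same plumbing
        return self._step(self.trainer.model.forward, args)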
9 changes: 6 additions & 3 deletions pytorch_lightning/callbacks/progress.py
@@ -291,10 +291,13 @@ def init_validation_tqdm(self) -> tqdm:
)
return bar

def init_test_tqdm(self) -> tqdm:
def init_test_tqdm(self, trainer=None) -> tqdm:
""" Override this to customize the tqdm bar for testing. """
desc = "Testing"
if trainer is not None and getattr(trainer, "is_predicting", False):
desc = "Predicting"
bar = tqdm(
desc='Testing',
desc=desc,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
@@ -361,7 +364,7 @@ def on_train_end(self, trainer, pl_module):

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
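Since prediction reuses the test loop, init_test_tqdm now receives the trainer so the bar can relabel itself as "Predicting". A sketch of a custom progress bar built on the new signature; the is_predicting flag name follows the diff above and the override is illustrative only:

from tqdm import tqdm

from pytorch_lightning.callbacks import ProgressBar


class PredictAwareBar(ProgressBar):
    def init_test_tqdm(self, trainer=None) -> tqdm:
        # mirror the new default: label the bar by the trainer's current mode
        predicting = trainer is not None and getattr(trainer, "is_predicting", False)
        return tqdm(
            desc="Predicting" if predicting else "Testing",
            position=2 * self.process_position,
            disable=self.is_disabled,
            leave=True,
        )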
10 changes: 6 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -14,16 +14,16 @@

"""nn.Module with additional great features."""

from abc import ABC
from argparse import Namespace
import collections
import copy
from functools import partial
import inspect
import os
from pathlib import Path
import re
import tempfile
from abc import ABC
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -66,6 +66,7 @@ class LightningModule(
"on_gpu",
"current_epoch",
"global_step",
"running_stage",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
@@ -102,6 +103,7 @@ def __init__(self, *args, **kwargs):
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None
self.running_stage = None

def optimizers(self):
opts = self.trainer.optimizers
44 changes: 34 additions & 10 deletions pytorch_lightning/overrides/data_parallel.py
@@ -28,6 +28,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warnings import WarningCache


@@ -78,14 +79,22 @@ def forward(self, *inputs, **kwargs):
"them on device: {}".format(self.src_device_obj, t.device))

inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

if len(self.device_ids) == 1:
# lightning
if self.module.training:

running_stage = self.module.running_stage

if running_stage == RunningStage.TRAINING:
return self.module.training_step(*inputs[0], **kwargs[0])
if self.module.testing:

elif running_stage == RunningStage.TESTING:
return self.module.test_step(*inputs[0], **kwargs[0])

return self.module.validation_step(*inputs[0], **kwargs[0])
elif running_stage == RunningStage.EVALUATING:
return self.module.validation_step(*inputs[0], **kwargs[0])

else:
return self.module(*inputs[0], **kwargs[0])

replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
@@ -187,15 +196,24 @@ def __init__(self, pl_module: LightningModule):
self.module = pl_module

def forward(self, *inputs, **kwargs):
if self.module.training:

running_stage = self.module.running_stage

if running_stage == RunningStage.TRAINING:
output = self.module.training_step(*inputs, **kwargs)
warn_if_output_is_none(output, "training_step")
elif self.module.testing:

elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")
else:

elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")

else:
output = self.module(*inputs, **kwargs)

return output


@@ -276,16 +294,22 @@ def _worker(i, module, input, kwargs, device=None):

# ---------------
# CHANGE
if module.training:
if module.running_stage == RunningStage.TRAINING:
output = module.training_step(*input, **kwargs)
fx_called = 'training_step'
elif module.testing:

elif module.running_stage == RunningStage.TESTING:
output = module.test_step(*input, **kwargs)
fx_called = 'test_step'
else:

elif module.running_stage == RunningStage.EVALUATING:
output = module.validation_step(*input, **kwargs)
fx_called = 'validation_step'

else:
output = module(*input, **kwargs)
fx_called = 'forward'

if output is None:
warn_missing_output(fx_called)

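The data-parallel overrides above (and the fairscale wrapper below) now dispatch on a single running_stage enum instead of the module.training / module.testing booleans, with a new fall-through to plain forward for prediction. A stripped-down sketch of that dispatch (RunningStage member names follow the diff; missing-output warnings and DP-specific scatter/gather are omitted):

from pytorch_lightning.trainer.states import RunningStage


def run_step(module, *inputs, **kwargs):
    # route the wrapped call according to the LightningModule's current stage
    stage = module.running_stage
    if stage == RunningStage.TRAINING:
        return module.training_step(*inputs, **kwargs)
    if stage == RunningStage.TESTING:
        return module.test_step(*inputs, **kwargs)
    if stage == RunningStage.EVALUATING:
        return module.validation_step(*inputs, **kwargs)
    # no training/evaluation stage set: prediction falls back to the module's own forward
    return module(*inputs, **kwargs)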
15 changes: 12 additions & 3 deletions pytorch_lightning/overrides/fairscale.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE

LightningShardedDataParallel = None
@@ -23,10 +24,18 @@ def forward(self, *inputs, **kwargs):
if self.enable_broadcast_buffers:
self.sync_buffers()

if self.module.training:
running_stage = self.module.running_stage

if running_stage == RunningStage.TRAINING:
outputs = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:

elif running_stage == RunningStage.TESTING:
outputs = self.module.test_step(*inputs, **kwargs)
else:

elif running_stage == RunningStage.EVALUATING:
outputs = self.module.validation_step(*inputs, **kwargs)

else:
outputs = self.module(*inputs, **kwargs)

return outputs
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -52,7 +52,7 @@ def __verify_train_loop_configuration(self, model):
# verify model has a train dataloader
# -----------------------------------
has_train_dataloader = is_overridden('train_dataloader', model)
if not has_train_dataloader:
if not has_train_dataloader and not self.trainer.is_predicting:
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
@@ -62,7 +62,7 @@ def __verify_train_loop_configuration(self, model):
# verify model has optimizer
# -----------------------------------
has_optimizers = is_overridden('configure_optimizers', model)
if not has_optimizers:
if not has_optimizers and not self.trainer.is_predicting:
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
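With these two checks gated on trainer.is_predicting, a module used only for inference no longer needs train_dataloader() or configure_optimizers(). A minimal example of the kind of module this relaxation is meant to allow (calling fit() on it would still raise the MisconfigurationException above):

import torch
import pytorch_lightning as pl


class InferenceOnlyModel(pl.LightningModule):
    """Defines only forward(); enough for predict, but not for fit()."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(16, 4)

    def forward(self, batch):
        return self.net(batch)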
pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DistributedType, LightningEnum


class LoggerStages(LightningEnum):
class LoggerStages(str, Enum):
""" Train/validation/test phase in each training step.

>>> # you can math the type with string
>>> LoggerStages.TRAIN == 'train'
True
@@ -371,7 +372,7 @@ def update_logger_connector(self) -> None:
callback_metrics = {}
batch_pbar_metrics = {}
batch_log_metrics = {}
is_train = self._stage in LoggerStages.TRAIN.value
is_train = self._stage in RunningStage.TRAINING

if not self._has_batch_loop_finished:
# get pbar
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -22,8 +22,9 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DeviceType, flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -37,9 +38,9 @@ def __init__(self, trainer):
self._logged_metrics = MetricsHolder()
self._progress_bar_metrics = MetricsHolder()
self.eval_loop_results = []
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages}
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
self._cached_results[None] = EpochResultStore(trainer, None)
self._callback_hook_validator = CallbackHookNameValidator()
self._current_stage = None

@property
def callback_metrics(self) -> Dict:
@@ -75,7 +76,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:

@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self._current_stage) # type: ignore
return self._cached_results.get(self.trainer._running_stage) # type: ignore

def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
@@ -90,10 +91,8 @@ def set_metrics(self, key: str, val: Any) -> None:
metrics_holder = getattr(self, f"_{key}", None)
metrics_holder.reset(val)

def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
self._current_stage = LoggerStages.determine_stage(stage_or_testing)
if reset:
self.cached_results.reset()
def reset(self) -> None:
self.cached_results.reset()

def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
self._callback_hook_validator.check_logging_in_callbacks(
@@ -119,8 +118,7 @@ def on_train_batch_end(self) -> None:
self.cached_results._batch_size = None

def cache_logged_metrics(self):
if self._current_stage is not None:
self._cached_results[self._current_stage].cache_result()
self._cached_results[self.trainer._running_stage].cache_result()

def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
# logging