Add PredictLoop #5752

Merged: 387 commits merged into master from introduce_predict_loop_1 on Feb 16, 2021
Changes from 250 commits
Commits (387)
3792b72
integrate distrib_type
awaelchli Jan 31, 2021
ef85b81
sync changes
awaelchli Jan 31, 2021
9d9a940
sync
awaelchli Feb 1, 2021
f017a39
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
a190a56
fixes
awaelchli Feb 1, 2021
73bb607
add forgotten generators
awaelchli Feb 1, 2021
c8c74f3
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
ae71997
add missing logic
awaelchli Feb 1, 2021
d89847b
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
0e686c3
update
awaelchli Feb 1, 2021
d6a43ea
import
awaelchli Feb 1, 2021
ceb8f75
missed imports
awaelchli Feb 1, 2021
fbb7c20
import fixes
awaelchli Feb 1, 2021
b610999
isort
awaelchli Feb 1, 2021
9b79924
mv f
awaelchli Feb 1, 2021
9afe54d
changelog
awaelchli Feb 1, 2021
3b63e82
Merge branch 'release/1.2-dev' into ref/update-plugins
awaelchli Feb 1, 2021
ca8cb68
format
awaelchli Feb 1, 2021
0633745
move helper to parallel plugin
awaelchli Feb 1, 2021
a622e0b
d
awaelchli Feb 1, 2021
18c682f
Merge branch 'ref/update-plugins' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
f275803
add world size
awaelchli Feb 1, 2021
4ae008b
clean up
awaelchli Feb 1, 2021
3b3918b
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
d4c6308
duplicate
awaelchli Feb 1, 2021
7eef4a0
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 2, 2021
9949164
activate ddp_sharded and tpu
awaelchli Feb 2, 2021
6d47357
set nvidia flags
awaelchli Feb 2, 2021
a6864ec
remove unused colab var
awaelchli Feb 2, 2021
b4b9724
use_tpu <-> on_tpu attrs
awaelchli Feb 2, 2021
81001e3
make some ddp_cpu and clusterplugin tests pass
awaelchli Feb 2, 2021
cea000d
Ref/accelerator connector (#5742)
justusschock Feb 2, 2021
933e2a1
plugins
awaelchli Feb 2, 2021
a97afb4
add predict_loop
tchaton Feb 2, 2021
ad451d8
manual optimization
justusschock Feb 2, 2021
c956c54
clean predictloop
tchaton Feb 2, 2021
a30a3cf
update optimizer routing
justusschock Feb 2, 2021
0ecb3f7
add predict loop on new accelerator
tchaton Feb 2, 2021
bbb8416
resolve a bug
tchaton Feb 2, 2021
a05b291
add rank to torchelastic
justusschock Feb 2, 2021
53efe55
add predict_loop
tchaton Feb 2, 2021
1c9d57e
add predict loop on new accelerator
tchaton Feb 2, 2021
154dae2
resolve a bug
tchaton Feb 2, 2021
4388e73
fix memory mixed precision
awaelchli Feb 2, 2021
872af55
Merge branch 'accelerator-refactor-sharded' into introduce_predict_lo…
tchaton Feb 2, 2021
8df9893
update
tchaton Feb 2, 2021
be9d029
setstate on trainer for pickling in ddp spawn
awaelchli Feb 2, 2021
8369fb2
add predict_loop
tchaton Feb 2, 2021
4c908f7
clean predictloop
tchaton Feb 2, 2021
c5b942f
add predict loop on new accelerator
tchaton Feb 2, 2021
421279d
resolve a bug
tchaton Feb 2, 2021
05e60ed
add predict_loop
tchaton Feb 2, 2021
2226640
add predict loop on new accelerator
tchaton Feb 2, 2021
4a29bb2
resolve a bug
tchaton Feb 2, 2021
8d63beb
add predict_loop
tchaton Feb 2, 2021
2a24e41
add predict loop on new accelerator
tchaton Feb 2, 2021
976bca9
resolve a bug
tchaton Feb 2, 2021
9dc38b3
add predict_loop
tchaton Feb 2, 2021
cf685d9
add predict loop on new accelerator
tchaton Feb 2, 2021
0c438f5
resolve a bug
tchaton Feb 2, 2021
ccf499c
add predict_loop
tchaton Feb 2, 2021
3fb57f3
clean predictloop
tchaton Feb 2, 2021
7e527f4
add predict loop on new accelerator
tchaton Feb 2, 2021
3575436
resolve a bug
tchaton Feb 2, 2021
c03e5bb
add predict_loop
tchaton Feb 2, 2021
a2f9f3c
add predict loop on new accelerator
tchaton Feb 2, 2021
4d24eff
resolve a bug
tchaton Feb 2, 2021
b155aab
resolve tests
tchaton Feb 2, 2021
a90a160
add predict method
awaelchli Feb 2, 2021
767bee0
add back commented accelerator code
awaelchli Feb 2, 2021
f771a7f
adapt test for sync_batch_norm to new plugin
awaelchli Feb 3, 2021
1a3b04e
fix deprecated tests
awaelchli Feb 3, 2021
a1f4938
fix ddp cpu choice when no num_processes are given
awaelchli Feb 3, 2021
38bc8b7
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 3, 2021
ce6b6de
yapf format
awaelchli Feb 3, 2021
3b7c20b
skip a memory test that cannot pass anymore
awaelchli Feb 3, 2021
609c848
remove sanetize
tchaton Feb 3, 2021
9a26178
rename train to run_train
tchaton Feb 3, 2021
f21780a
remove useless hooks
tchaton Feb 3, 2021
52b8db1
add misconfigurationException
tchaton Feb 3, 2021
b00c77c
remove wrong naming
tchaton Feb 3, 2021
864780f
resolve some legacy
tchaton Feb 3, 2021
ff7c50c
Merge branch 'accelerator-refactor-sharded' into introduce_predict_lo…
tchaton Feb 3, 2021
64e61a5
udpate docstring
tchaton Feb 3, 2021
9b8eed0
Merge branch 'introduce_predict_loop_1' of https://github.com/PyTorch…
tchaton Feb 3, 2021
f538c75
fix pickle error in spawn plugin
awaelchli Feb 3, 2021
b44d82e
x
awaelchli Feb 3, 2021
3820e77
avoid
awaelchli Feb 3, 2021
08ae327
x
awaelchli Feb 3, 2021
7d0e094
avoid tons of warnings from importing deprecated modules
awaelchli Feb 3, 2021
1028011
fix cyclic import in docs build
awaelchli Feb 3, 2021
11bd0d6
add support for sharded
justusschock Feb 4, 2021
6bf0b60
update typing
justusschock Feb 4, 2021
f94082b
add sharded and sharded_spawn to distributed types
justusschock Feb 4, 2021
7939b99
make unwrap model default
justusschock Feb 4, 2021
9131ffb
refactor LightningShardedDataParallel similar to LightningDistributed…
justusschock Feb 4, 2021
ed7425c
update sharded spawn to reflect changes
justusschock Feb 4, 2021
209a164
update sharded to reflect changes
justusschock Feb 4, 2021
837a070
Merge 1.1.5 changes
awaelchli Feb 4, 2021
136b321
fix merge
awaelchli Feb 4, 2021
ffcb535
fix merge
awaelchli Feb 4, 2021
1edfa73
yapf isort
awaelchli Feb 4, 2021
a689b81
merge 1.1.6
awaelchli Feb 4, 2021
330b14c
fix merge
awaelchli Feb 4, 2021
ef258d5
yapf isort
awaelchli Feb 4, 2021
c85000d
fix indentation in test
awaelchli Feb 4, 2021
5f3a35e
copy over reinit scheduler implementation from dev1.2
awaelchli Feb 4, 2021
fa1c9b7
fix apex tracking calls with dev_debugger
awaelchli Feb 5, 2021
e330a11
reduce diff to dev1.2, clean up
awaelchli Feb 5, 2021
994ac82
fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu
awaelchli Feb 5, 2021
1a78601
sort plugin tests legacy/new
awaelchli Feb 6, 2021
4b76448
fix error handling for amp on cpu
awaelchli Feb 6, 2021
bfd54ab
Merge branch 'release/1.2-dev' into patch117
awaelchli Feb 6, 2021
0574d22
fix merge
awaelchli Feb 6, 2021
6ef6637
Merge branch 'patch117' into accelerator-refactor-sharded
awaelchli Feb 6, 2021
9feda39
[Feat] Resolve manual_backward (#5837)
tchaton Feb 6, 2021
7bb9d9f
fix tests/accelerator tests on cpu
awaelchli Feb 6, 2021
13ae1ff
[BugFix] Resolve manual optimization (#5852)
tchaton Feb 6, 2021
fc3b4db
Merge formatting changes from 1.2 branch
awaelchli Feb 6, 2021
b437642
Remove copy trainer parameters to happen earlier within the loop and …
SeanNaren Feb 7, 2021
8c6aa83
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
Feb 7, 2021
beb980a
resovle a bug
Feb 7, 2021
7a0fd27
Accelerator refactor sharded rpc (#5854)
justusschock Feb 7, 2021
0d0ced5
resolve bug
Feb 7, 2021
1f3ab76
fix assert in rpc test
awaelchli Feb 7, 2021
f1b1121
resolve a test
Feb 7, 2021
cd31fa1
fix docs compilation
awaelchli Feb 8, 2021
f48793e
accelerator refactor - fix for sharded parity test (#5866)
awaelchli Feb 8, 2021
81ff6ea
Remove DDP2 as this does not apply
Feb 8, 2021
20deb46
Add missing pre optimizer hook to ensure lambda closure is called
Feb 8, 2021
be4d1a2
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
Feb 8, 2021
0ac5fc4
fix apex docstring
awaelchli Feb 8, 2021
07fdd95
[accelerator][BugFix] Resolve some test for 1 gpu (#5863)
tchaton Feb 8, 2021
384b791
yapf isort
awaelchli Feb 8, 2021
b1a84b8
resolve flake8
tchaton Feb 8, 2021
a157a29
fix apex doctests
awaelchli Feb 8, 2021
08cfc65
fix apex doctests 2
awaelchli Feb 8, 2021
7888bfd
resolve docs
tchaton Feb 8, 2021
b5b4243
update drone
tchaton Feb 8, 2021
93ceb4c
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
tchaton Feb 8, 2021
d001bcf
clean env
Feb 8, 2021
ad47f47
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
tchaton Feb 8, 2021
60bfb1a
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
tchaton Feb 8, 2021
0608a41
update
Feb 8, 2021
f0120b5
update
Feb 8, 2021
bf8874e
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 8, 2021
baf7d7f
update
tchaton Feb 8, 2021
9360aad
update
tchaton Feb 8, 2021
b814cdc
merge
justusschock Feb 9, 2021
0d3ea37
Merge branch 'accelerator-refactor-sharded' of github.com:PytorchLigh…
justusschock Feb 9, 2021
f1f90c2
Fix RPC related tests, clean out old API, update for new accelerator …
SeanNaren Feb 9, 2021
6d05881
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
justusschock Feb 10, 2021
d86fdff
Update test_remove_1-4.py
justusschock Feb 10, 2021
5fbc1cf
Expose properties for tpu cores/gpus/num_gpus
Feb 10, 2021
aa9aea0
Add root GPU property
Feb 10, 2021
c35baf1
Move properties to properties.py
Feb 10, 2021
a9c6e21
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 10, 2021
8f3947b
move tests that were previously in drone
awaelchli Feb 10, 2021
50ecc4a
Fix root GPU property (#5908)
SeanNaren Feb 10, 2021
c7d0075
fix best model path transfer when no checkpoint callback available
awaelchli Feb 10, 2021
3f61d15
Merge remote-tracking branch 'original/accelerator-refactor-sharded' …
awaelchli Feb 10, 2021
061ea46
Fix setup hook order [wip] (#5858)
SeanNaren Feb 10, 2021
1fe1f91
rename ddp sequential -> rpc sequential for special test
awaelchli Feb 10, 2021
3683f5a
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 10, 2021
1f01b81
revert
awaelchli Feb 10, 2021
135c236
fix stupid merge problem
awaelchli Feb 10, 2021
222653d
Use property in connector for sampler (#5913)
SeanNaren Feb 10, 2021
f4311cd
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 11, 2021
b210dee
merge the import conflicts
awaelchli Feb 11, 2021
236009e
fix spawning of processes in slurm
awaelchli Feb 11, 2021
aace276
[wip] Fix some bugs for TPU [skip ci] (#5878)
tchaton Feb 11, 2021
68273f5
resolve some tests
Feb 11, 2021
ca77fa4
update
Feb 11, 2021
c35edfd
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
justusschock Feb 11, 2021
8cacef7
fix imports
justusschock Feb 11, 2021
f7bbe48
update
Feb 11, 2021
30d9800
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 11, 2021
25f7f13
resolve flake8
tchaton Feb 11, 2021
fa28c41
update azure pipeline
tchaton Feb 11, 2021
51c27e6
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
tchaton Feb 11, 2021
b888d68
skip a sharded test on cpu that requires a gpu
awaelchli Feb 11, 2021
01ca4cd
resolve tpus
Feb 11, 2021
181d143
Merge branch 'master' into accelerator-refactor-sharded
justusschock Feb 11, 2021
946a1e9
resolve bug
Feb 11, 2021
2ad1a6e
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 11, 2021
6e0aff0
resolve flake8
tchaton Feb 11, 2021
a931791
update
Feb 11, 2021
319d034
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 11, 2021
4117bec
updat utils
Feb 11, 2021
8d000f7
Merge branch 'master' into accelerator-refactor-sharded
tchaton Feb 11, 2021
0b1ba67
revert permission change on files
awaelchli Feb 11, 2021
cc385b4
suggestions from carlos
awaelchli Feb 11, 2021
e9eb318
remove unrelated formatting changes
awaelchli Feb 11, 2021
7c08400
remove incomplete comment
awaelchli Feb 11, 2021
7c3d184
Update pytorch_lightning/accelerators/__init__.py
awaelchli Feb 11, 2021
503426e
remove unrelated formatting change
awaelchli Feb 11, 2021
c0fbf7a
add types
awaelchli Feb 11, 2021
23a9a10
warn 1.7 ddp manual backward only if ddp kwarg unset
awaelchli Feb 11, 2021
a70ee4a
yapf + isort
awaelchli Feb 11, 2021
b0621c4
pep8 unused imports
awaelchli Feb 11, 2021
18bfe70
Merge branch 'master' into accelerator-refactor-sharded
awaelchli Feb 11, 2021
7b0515d
fix cyclic import in docs
awaelchli Feb 12, 2021
d966057
Apply suggestions from code review
Borda Feb 12, 2021
f636d9d
typer in accelerator.py
Borda Feb 12, 2021
5579ea7
typo
tchaton Feb 12, 2021
dea37de
Merge branch 'accelerator-refactor-sharded' into introduce_predict_lo…
tchaton Feb 12, 2021
57fe3dd
resolve flake8
tchaton Feb 12, 2021
ae0587f
update code
tchaton Feb 12, 2021
4583b5f
update
tchaton Feb 12, 2021
0db6c67
Update pytorch_lightning/trainer/predict_loop.py
tchaton Feb 12, 2021
1264dbd
Update pytorch_lightning/trainer/predict_loop.py
tchaton Feb 12, 2021
0f3c33e
Merge branch 'master' into introduce_predict_loop_1
awaelchli Feb 13, 2021
57eb883
fix merge
awaelchli Feb 13, 2021
990b307
fix merge
awaelchli Feb 13, 2021
24dcbea
reset legacy accelerator
awaelchli Feb 13, 2021
4a7953e
add missing rename dispatch
awaelchli Feb 13, 2021
cc35b09
rename post traning
awaelchli Feb 13, 2021
483b61e
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 13, 2021
5618318
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 13, 2021
5b18bfb
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tchaton Feb 13, 2021
deac085
Merge branch 'predict_loop_2' of https://github.com/PyTorchLightning/…
tchaton Feb 13, 2021
cd9fffc
update code
tchaton Feb 13, 2021
2de8fce
Merge branch 'predict_loop_2' of https://github.com/PyTorchLightning/…
tchaton Feb 13, 2021
6cca568
resolved comments
tchaton Feb 13, 2021
10eda2f
Merge branch 'predict_loop_2' into introduce_predict_loop_1
tchaton Feb 13, 2021
f4977aa
typo
tchaton Feb 13, 2021
52038c1
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 13, 2021
a77c4e8
typo
tchaton Feb 13, 2021
cdfa212
Merge branch 'introduce_predict_loop_1' of https://github.com/PyTorch…
tchaton Feb 13, 2021
b4a3884
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 13, 2021
00f9b4e
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 13, 2021
5c3a87d
add flow description
tchaton Feb 13, 2021
33654be
Merge branch 'introduce_predict_loop_1' of https://github.com/PyTorch…
tchaton Feb 13, 2021
73bfc4c
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 14, 2021
3ee31e1
resolve comments
tchaton Feb 15, 2021
3984dbc
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 15, 2021
25a7e6f
update on comments
tchaton Feb 15, 2021
35f90b2
update flow
tchaton Feb 15, 2021
b477665
add backticks
tchaton Feb 15, 2021
8a3533c
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 15, 2021
9210535
resolve tpu
tchaton Feb 15, 2021
1f9c9e5
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 15, 2021
e870159
Merge branch 'master' into introduce_predict_loop_1
tchaton Feb 15, 2021
5e06d12
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 15, 2021
5bcff93
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 15, 2021
7c331c0
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 16, 2021
8cc0bf3
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 16, 2021
4a3e277
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 16, 2021
0cba060
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 16, 2021
fd08085
Merge branch 'master' into introduce_predict_loop_1
mergify[bot] Feb 16, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))


- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752))


- Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706))


33 changes: 24 additions & 9 deletions pytorch_lightning/accelerators/accelerator.py
@@ -140,9 +140,8 @@ def training_step(self, args):

         args[0] = batch

-        with self.precision_plugin.train_step_context():
-            with self.training_type_plugin.train_step_context():
-                return self.training_type_plugin.training_step(*args)
+        with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
+            return self.training_type_plugin.training_step(*args)

def post_training_step(self):
self.training_type_plugin.post_training_step()
@@ -162,9 +161,8 @@ def validation_step(self, args):

         args[0] = batch

-        with self.precision_plugin.val_step_context():
-            with self.training_type_plugin.val_step_context():
-                return self.training_type_plugin.validation_step(*args)
+        with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
+            return self.training_type_plugin.validation_step(*args)

def test_step(self, args):
"""The actual test step.
@@ -181,9 +179,26 @@ def test_step(self, args):

         args[0] = batch

-        with self.precision_plugin.test_step_context():
-            with self.training_type_plugin.test_step_context():
-                return self.training_type_plugin.test_step(*args)
+        with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
+            return self.training_type_plugin.test_step(*args)

def predict(self, args):
"""The actual predict step.

Args:
            args: the arguments for the model's predict step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): The index of this batch.
                dataloader_idx (int): The index of the dataloader that produced this batch
                    (only if multiple predict dataloaders used).
"""
batch = self.to_device(args[0])

args[0] = batch

with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
return self.training_type_plugin.predict(*args)

def training_step_end(self, output):
"""A hook to do something at the end of the training step
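The collapse of the nested with blocks above relies on a standard Python equivalence: a single with statement with comma-separated context managers enters them left to right, exactly like nesting. A minimal, self-contained sketch of that behavior (the context managers here are illustrative stand-ins, not Lightning's real plugin contexts):

import contextlib

@contextlib.contextmanager
def precision_context():
    # Stand-in for precision_plugin.train_step_context()
    print("precision: enter")
    try:
        yield
    finally:
        print("precision: exit")

@contextlib.contextmanager
def plugin_context():
    # Stand-in for training_type_plugin.train_step_context()
    print("plugin: enter")
    try:
        yield
    finally:
        print("plugin: exit")

# Equivalent to nesting: with precision_context(): with plugin_context(): ...
with precision_context(), plugin_context():
    print("running the step")

Both forms print the enter/exit lines in the same order, so the one-line version changes nothing about how the four step methods are wrapped.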
63 changes: 56 additions & 7 deletions pytorch_lightning/callbacks/progress.py
@@ -67,6 +67,7 @@ def __init__(self):
self._train_batch_idx = 0
self._val_batch_idx = 0
self._test_batch_idx = 0
self._predict_batch_idx = 0

@property
def trainer(self):
@@ -96,6 +97,14 @@ def test_batch_idx(self) -> int:
"""
return self._test_batch_idx

@property
def predict_batch_idx(self) -> int:
"""
The current batch index being processed during predicting.
Use this to update your progress bar.
"""
return self._predict_batch_idx

@property
def total_train_batches(self) -> int:
"""
@@ -108,7 +117,7 @@ def total_train_batches(self) -> int:
@property
def total_val_batches(self) -> int:
"""
-        The total number of training batches during validation, which may change from epoch to epoch.
+        The total number of validation batches during validation, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
validation dataloader is of infinite size.
"""
@@ -121,12 +130,21 @@ def total_val_batches(self) -> int:
@property
def total_test_batches(self) -> int:
"""
-        The total number of training batches during testing, which may change from epoch to epoch.
+        The total number of testing batches during testing, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
test dataloader is of infinite size.
"""
return sum(self.trainer.num_test_batches)

@property
def total_predict_batches(self) -> int:
"""
The total number of prediction batches during predicting, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
predict dataloader is of infinite size.
"""
return sum(self.trainer.num_predict_batches)

def disable(self):
"""
You should provide a way to disable the progress bar.
@@ -168,6 +186,12 @@ def on_test_start(self, trainer, pl_module):
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._test_batch_idx += 1

def on_predict_start(self, trainer, pl_module):
self._predict_batch_idx = 0

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._predict_batch_idx += 1


class ProgressBar(ProgressBarBase):
r"""
@@ -282,6 +306,20 @@ def init_train_tqdm(self) -> tqdm:
)
return bar

def init_predict_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for predicting. """
bar = tqdm(
desc='Predicting',
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
dynamic_ncols=True,
file=sys.stdout,
smoothing=0,
)
return bar

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
bar = tqdm(
@@ -294,12 +332,10 @@ def init_validation_tqdm(self) -> tqdm:
)
return bar

-    def init_test_tqdm(self, trainer=None) -> tqdm:
+    def init_test_tqdm(self) -> tqdm:
         """ Override this to customize the tqdm bar for testing. """
-        desc = "Testing"
-        desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
         bar = tqdm(
-            desc=desc,
+            desc="Testing",
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
@@ -365,7 +401,7 @@ def on_train_end(self, trainer, pl_module):

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
-        self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
+        self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -377,6 +413,19 @@ def on_test_end(self, trainer, pl_module):
super().on_test_end(trainer, pl_module)
self.test_progress_bar.close()

def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
self.predict_progress_bar = self.init_predict_tqdm()
self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
self._update_bar(self.predict_progress_bar)

def on_predict_end(self, trainer, pl_module):
self.predict_progress_bar.close()

def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

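Because prediction now has its own batch counters (predict_batch_idx, total_predict_batches) and its own init_predict_tqdm hook, the prediction bar can be customized exactly like the training and test bars. A hedged sketch of a user-side override (the class name and description string are invented for the example):

from tqdm import tqdm
from pytorch_lightning.callbacks import ProgressBar

class InferenceProgressBar(ProgressBar):
    """Progress bar that relabels the prediction phase."""

    def init_predict_tqdm(self) -> tqdm:
        # Reuse the default prediction bar and only change its description.
        bar = super().init_predict_tqdm()
        bar.set_description("Running inference")
        return bar

Passing an instance via Trainer(callbacks=[InferenceProgressBar()]) replaces the default bar, assuming the callback mechanism of this Lightning version.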
4 changes: 4 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -260,6 +260,10 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
pass
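With predict_dataloader added alongside the other dataloader hooks, a datamodule can serve data to the new loop. A minimal sketch under the assumption that only the prediction hook is needed (dataset, sizes, and class name are invented; depending on the exact version, the remaining abstract hooks may need stubs as well):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule

class RandomPredictDataModule(LightningDataModule):
    """Toy datamodule that only provides data for prediction."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

    def predict_dataloader(self) -> DataLoader:
        # 256 feature vectors of dimension 16, batched for inference.
        dataset = TensorDataset(torch.randn(256, 16))
        return DataLoader(dataset, batch_size=self.batch_size)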
37 changes: 34 additions & 3 deletions pytorch_lightning/core/hooks.py
@@ -204,17 +204,23 @@ def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader
"""
# do something when the batch ends

+    def on_test_model_train(self) -> None:
+        """
+        Sets the model to train during the test loop
+        """
+        self.train()
+
     def on_test_model_eval(self) -> None:
         """
         Sets the model to eval during the test loop
         """
         self.eval()

-    def on_test_model_train(self) -> None:
+    def on_predict_model_eval(self) -> None:
         """
-        Sets the model to train during the test loop
+        Sets the model to eval during the predict loop
         """
-        self.train()
+        self.eval()

def on_epoch_start(self) -> None:
"""
@@ -518,6 +524,31 @@ def val_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement one or multiple PyTorch DataLoaders for prediction.

It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`train_dataloader`
- :meth:`val_dataloader`
- :meth:`test_dataloader`

Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.

Return:
Single or multiple PyTorch DataLoaders.

Note:
In the case where you return multiple prediction dataloaders, the :meth:`predict`
will have an argument ``dataloader_idx`` which matches the order here.
"""

def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
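On the module side, the hook driven by the new loop is called predict at this point in the PR (later releases renamed it predict_step). A hedged sketch of a module wired for it; the model, the softmax post-processing, and the default for dataloader_idx are illustrative assumptions:

import torch
from torch import nn
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(16, 4)

    def forward(self, x):
        return self.layer(x)

    def predict(self, batch, batch_idx, dataloader_idx=None):
        # Mirrors the batch / batch_idx / dataloader_idx signature
        # documented in Accelerator.predict above.
        (x,) = batch  # TensorDataset batches arrive as 1-tuples
        return torch.softmax(self(x), dim=-1)

Combined with a predict_dataloader, trainer.predict(model), the entry point this PR wires up, runs the loop and collects the per-batch outputs.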
10 changes: 9 additions & 1 deletion pytorch_lightning/overrides/base.py
@@ -54,14 +54,22 @@ def forward(self, *inputs, **kwargs):
if not self.module.automatic_optimization:
self.module.trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")

elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")

elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")
-        else:
+
+        elif running_stage == RunningStage.PREDICTING:
+            output = self.module.predict(*inputs, **kwargs)
+            warn_if_output_is_none(output, "predict")
+
+        else:
output = self.module(*inputs, **kwargs)

return output


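The forward above is a plain stage dispatch: one wrapper routes a single call to whichever step method matches the current RunningStage, falling back to the module's own forward. A stripped-down illustration of the pattern (the enum and wrapper are stand-ins, not the Lightning originals):

from enum import Enum

class Stage(Enum):
    TRAINING = "training"
    TESTING = "testing"
    EVALUATING = "evaluating"
    PREDICTING = "predicting"

class StageDispatcher:
    """Simplified stand-in for the LightningModule wrapper above."""

    def __init__(self, module, stage=None):
        self.module = module  # expected to be callable, like an nn.Module
        self.stage = stage

    def forward(self, *args, **kwargs):
        dispatch = {
            Stage.TRAINING: self.module.training_step,
            Stage.TESTING: self.module.test_step,
            Stage.EVALUATING: self.module.validation_step,
            Stage.PREDICTING: self.module.predict,
        }
        # Unknown or unset stage falls back to the plain forward call.
        handler = dispatch.get(self.stage, self.module)
        return handler(*args, **kwargs)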
13 changes: 9 additions & 4 deletions pytorch_lightning/plugins/base_plugin.py
@@ -33,11 +33,11 @@ def connect(
Will be called by the accelerator.
"""

-    def pre_training(self) -> None:
-        """Hook to do something before the training starts."""
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""

-    def post_training(self) -> None:
-        """Hook to do something after the training finishes."""
+    def post_dispatch(self) -> None:
+        """Hook to do something after the training/evaluation/prediction finishes."""

@contextlib.contextmanager
def train_step_context(self) -> Generator:
@@ -53,3 +53,8 @@ def val_step_context(self) -> Generator:
def test_step_context(self) -> Generator:
"""A contextmanager for the teststep"""
yield

@contextlib.contextmanager
def predict_context(self) -> Generator:
"""A contextmanager for the predict step"""
yield
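How the renamed hooks and the new context manager compose in a concrete plugin: a hedged sketch of a subclass, where Plugin refers to the base class defined in this file, and wrapping prediction in torch.no_grad() is an example choice, not something this diff prescribes:

import contextlib
from typing import Generator

import torch

class VerbosePlugin(Plugin):  # `Plugin` being the base class above
    def pre_dispatch(self) -> None:
        print("before training/evaluation/prediction starts")

    def post_dispatch(self) -> None:
        print("after training/evaluation/prediction finishes")

    @contextlib.contextmanager
    def predict_context(self) -> Generator:
        # For instance: ensure no gradients are tracked while predicting.
        with torch.no_grad():
            yield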
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -215,7 +215,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

-    def pre_training(self):
+    def pre_dispatch(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
@@ -232,7 +232,7 @@ def pre_training(self):
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

-        # TODO: we moved it to the trainer.fit after calling pre_training
+        # TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

@@ -257,7 +257,7 @@ def pre_training(self):

self.barrier()

-    def post_training(self):
+    def post_dispatch(self):
if "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]

12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -110,6 +110,9 @@ def start_training(self, trainer):
def start_testing(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)

def start_predicting(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)

def new_process(self, process_idx, trainer, mp_queue):
self.mp_queue = mp_queue

@@ -128,7 +131,7 @@ def new_process(self, process_idx, trainer, mp_queue):
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

-        # TODO: we moved it to the trainer.fit after calling pre_training
+        # TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

@@ -153,15 +156,12 @@ def new_process(self, process_idx, trainer, mp_queue):

self.barrier()

-        if trainer.testing:
-            results = trainer.run_test()
-        else:
-            results = trainer.train()
+        results = trainer.train_or_test_or_predict()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)

-    def post_training(self):
+    def post_dispatch(self):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
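start_predicting reuses the same pattern as start_training and start_testing: torch.multiprocessing.spawn launches one copy of the worker per process and passes the process index as the first argument. A self-contained illustration of that mechanism (the worker body is invented for the example):

import torch.multiprocessing as mp

def new_process(process_idx: int, nprocs: int) -> None:
    # Each spawned worker receives its index first, then the extra args.
    print(f"worker {process_idx + 1}/{nprocs} running")

if __name__ == "__main__":
    # Roughly what mp.spawn(self.new_process, **self.mp_spawn_kwargs) does,
    # with the plugin's kwargs supplying nprocs and the trainer/queue args.
    mp.spawn(new_process, args=(2,), nprocs=2)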