Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] 1/2 Add trainer.predict #5579

Merged
merged 60 commits into from
Jan 27, 2021
Merged

[feat] 1/2 Add trainer.predict #5579

merged 60 commits into from
Jan 27, 2021

Conversation

tchaton
Copy link
Contributor

@tchaton tchaton commented Jan 19, 2021

What does this PR do?

This PR:

By calling Trainer.predict(model, dataloaders), it will call forward function and gather predictions.

Trainer

  • predict

LightningModule

  • predict -> calls forward

Accelerators

  • predict

Uniformize use of RunningState across LoggerConnector, DDPWrapper, DP.

Add tests for DDP, DP, DDP_SPAWN, 1 GPU, DDP_CPU, DDP_SHARDED.

Fixes # (issue) <- this links related issue to this PR

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified
  • Check that target branch and milestone match!

Did you have fun?

Make sure you had fun coding 🙃

@tchaton tchaton added this to the 1.2 milestone Jan 19, 2021
@pep8speaks
Copy link

pep8speaks commented Jan 19, 2021

Hello @tchaton! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-01-27 15:21:22 UTC

@tchaton tchaton closed this Jan 19, 2021
@tchaton tchaton deleted the feat_predict branch January 20, 2021 08:12
@tchaton tchaton restored the feat_predict branch January 20, 2021 08:13
@tchaton tchaton reopened this Jan 20, 2021
@tchaton tchaton marked this pull request as ready for review January 20, 2021 10:08
@tchaton tchaton self-assigned this Jan 20, 2021
@tchaton tchaton added feature Is an improvement or enhancement priority: 0 High priority task labels Jan 20, 2021
Copy link
Contributor

@SeanNaren SeanNaren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic wise everything looks clear to me! Just a few nits that are not super critical :)

As mentioned, it would be nice to display predict with just dataloader rather than test_dataloader to differentiate that this works on any dataloader (I can pass my train dataloader for example)

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could somehow make the logic that test_step is using predcit_step?

@tchaton tchaton enabled auto-merge (squash) January 27, 2021 05:28
@mergify mergify bot removed the has conflicts label Jan 27, 2021
@tchaton tchaton disabled auto-merge January 27, 2021 06:31
@tchaton tchaton enabled auto-merge (squash) January 27, 2021 06:31
@tchaton tchaton disabled auto-merge January 27, 2021 07:57
@tchaton tchaton enabled auto-merge (squash) January 27, 2021 07:57
@s-rog
Copy link
Contributor

s-rog commented Jan 27, 2021

Does predict() deal with ddp data padding from distributed samplers?
If not, I suppose we'll need a disclaimer that we're waiting on a pytorch fix

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just check the new _predition attribute

docs/source/starter/introduction_guide.rst Show resolved Hide resolved
pytorch_lightning/callbacks/progress.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/trainer.py Show resolved Hide resolved
pytorch_lightning/trainer/trainer.py Show resolved Hide resolved
pytorch_lightning/trainer/trainer.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/trainer.py Show resolved Hide resolved
pytorch_lightning/trainer/trainer.py Outdated Show resolved Hide resolved
tests/overrides/test_data_parallel.py Show resolved Hide resolved
@Borda Borda added ready PRs ready to be merged design Includes a design discussion labels Jan 27, 2021
@carmocca carmocca mentioned this pull request Jan 27, 2021
12 tasks
@mergify mergify bot removed the has conflicts label Jan 27, 2021
@tchaton tchaton merged commit 3da28fd into release/1.2-dev Jan 27, 2021
@tchaton tchaton deleted the feat_predict branch January 27, 2021 16:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
design Includes a design discussion feature Is an improvement or enhancement priority: 0 High priority task ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.