diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index a6ee9e4d45583f..179eafdc139a16 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -109,9 +109,12 @@ def predict_dataloader(self): def cli_main(): - cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True) - cli.trainer.test(cli.model, datamodule=cli.datamodule) - predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule) + cli = LightningCLI( + LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False + ) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best") + predictions = cli.trainer.predict(ckpt_path="best") print(predictions[0]) diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 0f2e39f0c88268..26d83522ce6f39 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -124,9 +124,10 @@ def predict_dataloader(self): def cli_main(): - cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True) - cli.trainer.test(cli.model, datamodule=cli.datamodule) - predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule) + cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best") + predictions = cli.trainer.predict(ckpt_path="best") print(predictions[0]) diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 9a0b6e002ffa71..8eda150cbb620d 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -198,8 +198,9 @@ def cli_main(): if not _DALI_AVAILABLE: return - cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True) - cli.trainer.test(cli.model, datamodule=cli.datamodule) + cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best") if __name__ == "__main__": diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 5fd13723d66536..5fdcb8d8c3bb2b 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -72,8 +72,11 @@ def configure_optimizers(self): def cli_main(): - cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True) - cli.trainer.test(cli.model, datamodule=cli.datamodule) + cli = LightningCLI( + LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False + ) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best") if __name__ == "__main__": diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 466730990f67b1..6a73dc5ee3b919 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -277,10 +277,10 @@ def add_arguments_to_parser(self, parser): } ) - def instantiate_trainer(self): - finetuning_callback = MilestonesFinetuning(**self.config_init["finetuning"]) + def instantiate_trainer(self, *args): + finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning")) self.trainer_defaults["callbacks"] = [finetuning_callback] - super().instantiate_trainer() + return super().instantiate_trainer(*args) def cli_main(): diff --git a/pl_examples/run_examples.sh b/pl_examples/run_examples.sh index 66714df6727d9c..4dfad17ae8970a 100644 --- a/pl_examples/run_examples.sh +++ b/pl_examples/run_examples.sh @@ -2,7 +2,12 @@ set -ex dir_path=$(dirname "${BASH_SOURCE[0]}") -args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2" +args="--trainer.max_epochs=1 " \ + "--data.batch_size=32 " \ + "--trainer.limit_train_batches=2 " \ + "--trainer.limit_val_batches=2 " \ + "--trainer.limit_test_batches=2 "\ + "--trainer.limit_predict_batches=2" python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@" python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@" diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 2c9397125e751b..223f3f02d2c60b 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -87,7 +87,7 @@ fi # report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n" # test that a user can manually launch individual processes -args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.fast_dev_run 1" +args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --limit_test_batches=1" MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} & MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python pl_examples/basic_examples/simple_image_classifier.py ${args} report+="Ran\tmanual ddp launch test\n"