diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py
index aac9977bd57d9..d71f44cec291f 100644
--- a/pytorch_lightning/loggers/neptune.py
+++ b/pytorch_lightning/loggers/neptune.py
@@ -33,7 +33,7 @@ def __init__(self,
                  api_key: Optional[str] = None,
                  project_name: Optional[str] = None,
                  close_after_fit: Optional[bool] = True,
-                 offline_mode: bool = True,
+                 offline_mode: bool = False,
                  experiment_name: Optional[str] = None,
                  upload_source_files: Optional[List[str]] = None,
                  params: Optional[Dict[str, Any]] = None,
@@ -140,7 +140,7 @@ def any_lightning_module_function_or_hook(...):
                 "namespace/project_name" for example "tom/minst-classification".
                 If None, the value of NEPTUNE_PROJECT environment variable will be taken.
                 You need to create the project in https://neptune.ai first.
-            offline_mode: Optional default True. If offline_mode=True no logs will be send
+            offline_mode: Optional default False. If offline_mode=True no logs will be sent
                 to neptune. Usually used for debug and test purposes.
             close_after_fit: Optional default True. If close_after_fit=False the experiment
                 will not be closed after training and additional metrics,
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index 0065b358e9e80..3ae86495e92e3 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -10,6 +10,15 @@
 from tests.base import LightningTestModel
 
 
+def _get_logger_args(logger_class, save_dir):
+    logger_args = {}
+    if 'save_dir' in inspect.getfullargspec(logger_class).args:
+        logger_args.update(save_dir=str(save_dir))
+    if 'offline_mode' in inspect.getfullargspec(logger_class).args:
+        logger_args.update(offline_mode=True)
+    return logger_args
+
+
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
@@ -40,10 +49,8 @@ def log_metrics(self, metrics, step):
             super().log_metrics(metrics, step)
             self.history.append((step, metrics))
 
-    if 'save_dir' in inspect.getfullargspec(logger_class).args:
-        logger = StoreHistoryLogger(save_dir=str(tmpdir))
-    else:
-        logger = StoreHistoryLogger()
+    logger_args = _get_logger_args(logger_class, tmpdir)
+    logger = StoreHistoryLogger(**logger_args)
 
     trainer = Trainer(
         max_epochs=1,
@@ -80,10 +87,8 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
     import atexit
     monkeypatch.setattr(atexit, 'register', lambda _: None)
 
-    if 'save_dir' in inspect.getfullargspec(logger_class).args:
-        logger = logger_class(save_dir=str(tmpdir))
-    else:
-        logger = logger_class()
+    logger_args = _get_logger_args(logger_class, tmpdir)
+    logger = logger_class(**logger_args)
 
     trainer = Trainer(
         max_epochs=1,