Skip to content

Commit

Permalink
Merge branch 'release/1.2-dev' into fix/update-mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Jan 24, 2021
2 parents 980c4bd + f0fafa2 commit 8a8d3fd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
44 changes: 26 additions & 18 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,20 @@ class WandbLogger(LightningLoggerBase):
Args:
name: Display name for the run.
save_dir: Path where data is saved.
save_dir: Path where data is saved (wandb dir by default).
offline: Run offline (data can be streamed later to wandb servers).
id: Sets the version, mainly used to resume a previous run.
version: Same as id.
anonymous: Enables or explicitly disables anonymous logging.
version: Sets the version, mainly used to resume a previous run.
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
experiment: WandB experiment object.
prefix: A string to put at the beginning of metric keys.
sync_step: Sync Trainer step with wandb step.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
Example::
Example:
.. code-block:: python
Expand All @@ -74,9 +75,9 @@ class WandbLogger(LightningLoggerBase):
make sure to use `commit=False` so the logging step does not increase.
See Also:
- `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
on how to use W&B with Pytorch Lightning.
- `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
on how to use W&B with PyTorch Lightning
- `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
"""

Expand All @@ -86,14 +87,15 @@ def __init__(
self,
name: Optional[str] = None,
save_dir: Optional[str] = None,
offline: bool = False,
offline: Optional[bool] = False,
id: Optional[str] = None,
anonymous: bool = False,
anonymous: Optional[bool] = False,
version: Optional[str] = None,
project: Optional[str] = None,
log_model: bool = False,
log_model: Optional[bool] = False,
experiment=None,
prefix: str = '',
prefix: Optional[str] = '',
sync_step: Optional[bool] = True,
**kwargs
):
if wandb is None:
Expand All @@ -102,13 +104,14 @@ def __init__(
super().__init__()
self._name = name
self._save_dir = save_dir
self._anonymous = 'allow' if anonymous else None
self._offline = offline
self._id = version or id
self._anonymous = 'allow' if anonymous else None
self._project = project
self._experiment = experiment
self._offline = offline
self._log_model = log_model
self._prefix = prefix
self._sync_step = sync_step
self._experiment = experiment
self._kwargs = kwargs
# logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
Expand Down Expand Up @@ -164,11 +167,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

metrics = self._add_prefix(metrics)
if step is not None and step + self._step_offset < self.experiment.step:
if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
self.warning_cache.warn(
'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
)
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
' or try logging with `commit=False` when calling manually `wandb.log`.')
if self._sync_step:
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
elif step is not None:
self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
else:
self.experiment.log(metrics)

@property
def save_dir(self) -> Optional[str]:
Expand Down
12 changes: 12 additions & 0 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def test_wandb_logger_init(wandb, recwarn):
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

# test sync_step functionality
wandb.init().log.reset_mock()
wandb.init.reset_mock()
wandb.run = None
wandb.init().step = 0
logger = WandbLogger(sync_step=False)
logger.log_metrics({'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0})
wandb.init().log.reset_mock()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})

# mock wandb step
wandb.init().step = 0

Expand Down

0 comments on commit 8a8d3fd

Please sign in to comment.