Fixes #2972 #2946 #2986

Merged: 11 commits, Aug 15, 2020
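In short: this PR re-keys validation step-level metrics with an `/epoch_{n}` suffix and logs them at `step=batch_idx`, so each epoch shows up as a separate line in the same logger graph instead of all epochs collapsing onto one curve. A minimal standalone sketch of the re-keying (the helper name and example values are illustrative, not part of the PR):

# Minimal sketch of the re-keying this PR introduces (helper name and
# values are illustrative, not part of the PR's code).
def rekey_step_metrics(step_log_metrics, current_epoch):
    """Suffix each metric key with the epoch so every epoch gets its own curve."""
    return {f'{k}/epoch_{current_epoch}': v for k, v in step_log_metrics.items()}

# e.g. validation accuracy logged during epoch 2:
assert rekey_step_metrics({'val_step_log_acc': 0.91}, 2) == {'val_step_log_acc/epoch_2': 0.91}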
13 changes: 9 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -217,7 +217,7 @@ def add_progress_bar_metrics(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def log_metrics(self, *args):
+    def log_metrics(self, *args, **kwargs):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -379,7 +379,7 @@ def _evaluate(
 
             dl_outputs.append(output)
 
-            self.__eval_add_step_metrics(output)
+            self.__eval_add_step_metrics(output, batch_idx)
 
             # track debug metrics
             self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
@@ -505,14 +505,19 @@ def __gather_epoch_end_eval_results(self, outputs):
             eval_results = eval_results[0]
         return eval_results
 
-    def __eval_add_step_metrics(self, output):
+    def __eval_add_step_metrics(self, output, batch_idx):
         # track step level metrics
         if isinstance(output, EvalResult) and not self.running_sanity_check:
             step_log_metrics = output.batch_log_metrics
             step_pbar_metrics = output.batch_pbar_metrics
 
             if len(step_log_metrics) > 0:
-                self.log_metrics(step_log_metrics, {})
+                # make the metrics appear as a different line in the same graph
+                metrics_by_epoch = {}
+                for k, v in step_log_metrics.items():
+                    metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v
+
+                self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
 
             if len(step_pbar_metrics) > 0:
                 self.add_progress_bar_metrics(step_pbar_metrics)
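Taken together, the eval-loop changes mean each validation batch now logs under an epoch-suffixed key at `step=batch_idx`, so every epoch reuses the same x-axis. A hedged sketch of the resulting call pattern (the `FakeTrainer` class and the accuracy values are invented for illustration):

# Invented stand-in for the trainer, just to show the call pattern.
class FakeTrainer:
    current_epoch = 0

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        print(step, metrics)

trainer = FakeTrainer()
for epoch in range(2):
    trainer.current_epoch = epoch
    for batch_idx, acc in enumerate([0.61, 0.64, 0.66]):
        metrics_by_epoch = {f'val_acc/epoch_{epoch}': acc}
        # same step axis (batch_idx) every epoch -> one line per epoch, same graph
        trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)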
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/logging.py
@@ -64,10 +64,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
 
         if "step" in scalar_metrics and step is None:
             step = scalar_metrics.pop("step")
-        else:
+
+        elif step is None:
             # added metrics by Lightning for convenience
             scalar_metrics['epoch'] = self.current_epoch
+            step = step if step is not None else self.global_step
 
         # log actual metrics
         if self.is_global_zero and self.logger is not None:
             self.logger.agg_and_log_metrics(scalar_metrics, step=step)
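After this change, `log_metrics` resolves the step in order: a `step` key embedded in the metrics dict, then an explicitly passed `step` (as the eval loop now supplies via `batch_idx`), then `global_step` as the fallback that also tags the epoch. A standalone sketch of that resolution order (an illustrative reimplementation, not the Trainer's actual method):

# Illustrative reimplementation mirroring the diff above (assumption: not
# the actual Trainer method; names stand in for the Trainer's attributes).
def resolve_step(scalar_metrics, step, current_epoch, global_step):
    if "step" in scalar_metrics and step is None:
        step = scalar_metrics.pop("step")        # step embedded in the metrics dict
    elif step is None:
        scalar_metrics['epoch'] = current_epoch  # convenience key added by Lightning
        step = global_step                       # fallback
    return scalar_metrics, step

# An explicit step (e.g. batch_idx from the eval loop) is now left untouched:
metrics, step = resolve_step({'val_acc/epoch_0': 0.9}, step=7, current_epoch=0, global_step=123)
assert step == 7 and 'epoch' not in metrics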
18 changes: 12 additions & 6 deletions tests/trainer/test_validation_steps_result_return.py
@@ -214,12 +214,15 @@ def test_val_step_only_step_metrics(tmpdir):
 
     # make sure we logged the correct epoch metrics
     total_empty_epoch_metrics = 0
+    epoch = 0
     for metric in trainer.dev_debugger.logged_metrics:
+        if 'epoch' in metric:
+            epoch += 1
         if len(metric) > 2:
             assert 'no_val_no_pbar' not in metric
             assert 'val_step_pbar_acc' not in metric
-            assert metric['val_step_log_acc']
-            assert metric['val_step_log_pbar_acc']
+            assert metric[f'val_step_log_acc/epoch_{epoch}']
+            assert metric[f'val_step_log_pbar_acc/epoch_{epoch}']
         else:
             total_empty_epoch_metrics += 1

@@ -228,6 +231,8 @@
     # make sure we logged the correct epoch pbar metrics
     total_empty_epoch_metrics = 0
     for metric in trainer.dev_debugger.pbar_added_metrics:
+        if 'epoch' in metric:
+            epoch += 1
         if len(metric) > 2:
             assert 'no_val_no_pbar' not in metric
             assert 'val_step_log_acc' not in metric
@@ -288,11 +293,12 @@ def test_val_step_epoch_step_metrics(tmpdir):
     for metric_idx in range(0, len(trainer.dev_debugger.logged_metrics), batches + 1):
         batch_metrics = trainer.dev_debugger.logged_metrics[metric_idx: metric_idx + batches]
         epoch_metric = trainer.dev_debugger.logged_metrics[metric_idx + batches]
+        epoch = epoch_metric['epoch']
 
         # make sure the metric was split
         for batch_metric in batch_metrics:
-            assert 'step_val_step_log_acc' in batch_metric
-            assert 'step_val_step_log_pbar_acc' in batch_metric
+            assert f'step_val_step_log_acc/epoch_{epoch}' in batch_metric
+            assert f'step_val_step_log_pbar_acc/epoch_{epoch}' in batch_metric
 
         # make sure the epoch split was correct
         assert 'epoch_val_step_log_acc' in epoch_metric
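The `batches + 1` stride in this test reflects the logged stream the PR produces: per epoch, one dict per validation batch followed by a single epoch-level dict. A sketch of that grouping on made-up data:

# Made-up logged stream: 2 epochs x 3 batches, each epoch ending in a summary dict.
batches = 3
logged = (
    [{'step_acc/epoch_0': 0.1}] * batches + [{'epoch_acc': 0.15, 'epoch': 0}]
    + [{'step_acc/epoch_1': 0.2}] * batches + [{'epoch_acc': 0.25, 'epoch': 1}]
)
for metric_idx in range(0, len(logged), batches + 1):
    batch_metrics = logged[metric_idx: metric_idx + batches]  # the per-batch dicts
    epoch_metric = logged[metric_idx + batches]               # the epoch summary
    assert all('step_acc' in k for m in batch_metrics for k in m)
    assert 'epoch' in epoch_metric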
@@ -421,11 +427,11 @@ def test_val_step_full_loop_result_dp(tmpdir):
     assert 'train_step_metric' in seen_keys
     assert 'train_step_end_metric' in seen_keys
     assert 'epoch_train_epoch_end_metric' in seen_keys
-    assert 'step_validation_step_metric' in seen_keys
+    assert 'step_validation_step_metric/epoch_0' in seen_keys
     assert 'epoch_validation_step_metric' in seen_keys
     assert 'validation_step_end_metric' in seen_keys
     assert 'validation_epoch_end_metric' in seen_keys
-    assert 'step_test_step_metric' in seen_keys
+    assert 'step_test_step_metric/epoch_2' in seen_keys
    assert 'epoch_test_step_metric' in seen_keys
     assert 'test_step_end_metric' in seen_keys
     assert 'test_epoch_end_metric' in seen_keys