From 89d061da8bdc2d37adea97096d1190c8698c2fa3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 11 Jul 2020 16:49:56 -0400
Subject: [PATCH 1/5] add tests for single scalar return from training

---
 tests/base/deterministic_model.py | 38 +++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py
index 529d64f799fcd..ac3aaa6856e54 100644
--- a/tests/base/deterministic_model.py
+++ b/tests/base/deterministic_model.py
@@ -52,6 +52,44 @@ def count_num_graphs(self, result, num_graphs=0):

         return num_graphs

+    # ---------------------------
+    # scalar return
+    # ---------------------------
+    def training_step_scalar_return(self, batch, batch_idx):
+        # TODO: verify
+        acc = self.step(batch, batch_idx)
+        self.training_step_called = True
+        return acc
+
+    def training_step_end_scalar(self, output):
+        # TODO: verify
+        self.training_step_end_called = True
+
+        # make sure loss has the grad
+        assert 'loss' in output
+        assert output['loss'].grad_fn is not None
+
+        # make sure nothing else has grads
+        assert self.count_num_graphs(output) == 1
+
+        return output
+
+    def training_epoch_end_scalar(self, outputs):
+        # TODO: verify
+        self.training_epoch_end_called = True
+
+        if self.use_dp or self.use_ddp2:
+            pass
+        else:
+            # only saw 4 batches
+            assert len(outputs) == 4
+            for batch_out in outputs:
+                # TODO: verify
+                assert batch_out == (42.0 * 3) + (15.0 * 3)
+
+        prototype_loss = outputs[0]['loss']
+        return prototype_loss
+
     # --------------------------
     # dictionary returns
     # --------------------------

From 33367e2600ea2c8a1918c16b00da04f45fcfadd4 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 11 Jul 2020 17:07:37 -0400
Subject: [PATCH 2/5] add tests for single scalar return from training

---
 pytorch_lightning/trainer/logging.py       | 11 ++++++
 pytorch_lightning/trainer/training_loop.py |  5 ++-
 tests/trainer/test_trainer_steps.py        | 39 +++++++++++++++++++++-
 3 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 5349849e09b89..35f5d5d35b9ca 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -98,6 +98,17 @@ def process_output(self, output, train=False):

         Separates loss from logging and progress bar metrics
         """
+        # --------------------------
+        # handle single scalar only
+        # --------------------------
+        # single scalar returned from a xx_step
+        if isinstance(output, torch.Tensor):
+            progress_bar_metrics = {}
+            log_metrics = {}
+            callback_metrics = {}
+            hiddens = None
+            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+
         # ---------------
         # EXTRACT CALLBACK KEYS
         # ---------------
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 008faa20eebcf..fa493f2e1b09a 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -792,7 +792,10 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
             )

             # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+            else:
+                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

         # accumulate loss
         # (if accumulate_grad_batches = 1 no effect)
diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py
index 6091f486257c7..c9b6344bb282b 100644
--- a/tests/trainer/test_trainer_steps.py
+++ b/tests/trainer/test_trainer_steps.py
@@ -4,7 +4,6 @@
 import torch


-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_training_step_dict(tmpdir):
     """
     Tests that only training_step can be used
@@ -156,3 +155,41 @@ def test_train_step_epoch_end(tmpdir):
     pbar_metrics = train_step_end_out['progress_bar']
     assert pbar_metrics['pbar_acc1'] == 17.0
     assert pbar_metrics['pbar_acc2'] == 19.0
+
+
+def test_training_step_scalar(tmpdir):
+    """
+    Tests that only training_step can be used
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
\ No newline at end of file

From 7567ec951fcbe48037ab8d7deddbd0971aaa539c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 11 Jul 2020 17:07:47 -0400
Subject: [PATCH 3/5] add tests for single scalar return from training

---
 tests/trainer/test_trainer_steps.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py
index c9b6344bb282b..540d735bbf2ff 100644
--- a/tests/trainer/test_trainer_steps.py
+++ b/tests/trainer/test_trainer_steps.py
@@ -192,4 +192,4 @@ def test_training_step_scalar(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
-    assert opt_closure_result['loss'].item() == 171
\ No newline at end of file
+    assert opt_closure_result['loss'].item() == 171

From 899d83f3ea5c4329453051deb9794944c59c7274 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 11 Jul 2020 17:18:17 -0400
Subject: [PATCH 4/5] add tests for single scalar return from training

---
 tests/base/deterministic_model.py   |  21 +++--
 tests/trainer/test_trainer_steps.py | 123 +++++++++++++++++++++++++++-
 2 files changed, 134 insertions(+), 10 deletions(-)

diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py
index ac3aaa6856e54..c387997da57d7 100644
--- a/tests/base/deterministic_model.py
+++ b/tests/base/deterministic_model.py
@@ -56,26 +56,28 @@ def count_num_graphs(self, result, num_graphs=0):
     # scalar return
     # ---------------------------
     def training_step_scalar_return(self, batch, batch_idx):
-        # TODO: verify
         acc = self.step(batch, batch_idx)
         self.training_step_called = True
         return acc

     def training_step_end_scalar(self, output):
-        # TODO: verify
         self.training_step_end_called = True

         # make sure loss has the grad
-        assert 'loss' in output
-        assert output['loss'].grad_fn is not None
+        assert isinstance(output, torch.Tensor)
+        assert output.grad_fn is not None

         # make sure nothing else has grads
-        assert self.count_num_graphs(output) == 1
+        assert self.count_num_graphs({'loss': output}) == 1
+
+        assert output == 171

         return output

     def training_epoch_end_scalar(self, outputs):
-        # TODO: verify
+        """
+        There should be an array of scalars without graphs that are all 171 (4 of them)
+        """
         self.training_epoch_end_called = True

         if self.use_dp or self.use_ddp2:
@@ -84,10 +86,11 @@ def training_epoch_end_scalar(self, outputs):
             # only saw 4 batches
             assert len(outputs) == 4
             for batch_out in outputs:
-                # TODO: verify
-                assert batch_out == (42.0 * 3) + (15.0 * 3)
+                assert batch_out == 171
+                assert batch_out.grad_fn is None
+                assert isinstance(batch_out, torch.Tensor)

-        prototype_loss = outputs[0]['loss']
+        prototype_loss = outputs[0]
         return prototype_loss

     # --------------------------
diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py
index 540d735bbf2ff..9dd12038c0831 100644
--- a/tests/trainer/test_trainer_steps.py
+++ b/tests/trainer/test_trainer_steps.py
@@ -159,7 +159,7 @@ def test_train_step_epoch_end(tmpdir):
 def test_training_step_scalar(tmpdir):
     """
-    Tests that only training_step can be used
+    Tests that only training_step that returns a single scalar can be used
     """
     model = DeterministicModel()
     model.training_step = model.training_step_scalar_return
     model.val_dataloader = None
@@ -193,3 +193,124 @@ def test_training_step_scalar(tmpdir):
     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'].item() == 171
+
+
+def training_step_scalar_with_step_end(tmpdir):
+    """
+    Checks train_step with scalar only + training_step_end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_full_training_loop_scalar(tmpdir):
+    """
+    Checks train_step + training_step_end + training_epoch_end
+    (all with scalar return from train_step)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_train_step_epoch_end_scalar(tmpdir):
+    """
+    Checks train_step + training_epoch_end (NO training_step_end)
+    (with scalar return)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = None
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171

From 76c5e95701db487751f8f197287557f8d59b60b0 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 11 Jul 2020 17:26:46 -0400
Subject: [PATCH 5/5] add tests for single scalar return from training

---
 .../trainer/test_trainer_steps_dict_return.py | 158 ++++++++++++++++++
 ...py => test_trainer_steps_scalar_return.py} | 157 +----------------
 2 files changed, 161 insertions(+), 154 deletions(-)
 create mode 100644 tests/trainer/test_trainer_steps_dict_return.py
 rename tests/trainer/{test_trainer_steps.py => test_trainer_steps_scalar_return.py} (53%)

diff --git a/tests/trainer/test_trainer_steps_dict_return.py b/tests/trainer/test_trainer_steps_dict_return.py
new file mode 100644
index 0000000000000..290983fbf6a5c
--- /dev/null
+++ b/tests/trainer/test_trainer_steps_dict_return.py
@@ -0,0 +1,158 @@
+"""
+Tests to ensure that the training loop works with a dict
+"""
+from pytorch_lightning import Trainer
+from tests.base.deterministic_model import DeterministicModel
+
+
+def test_training_step_dict(tmpdir):
+    """
+    Tests that only training_step can be used
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
+
+    train_step_out = out.training_step_output_for_epoch_end
+    pbar_metrics = train_step_out['progress_bar']
+    assert 'log' in train_step_out
+    assert 'progress_bar' in train_step_out
+    assert train_step_out['train_step_test'] == 549
+    assert pbar_metrics['pbar_acc1'] == 17.0
+    assert pbar_metrics['pbar_acc2'] == 19.0
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
+
+
+def training_step_with_step_end(tmpdir):
+    """
+    Checks train_step + training_step_end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_for_step_end_dict
+    model.training_step_end = model.training_step_end_dict
+    model.val_dataloader = None
+
+    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 14.0
+    assert out.batch_log_metrics['log_acc2'] == 9.0
+
+    train_step_end_out = out.training_step_output_for_epoch_end
+    pbar_metrics = train_step_end_out['progress_bar']
+    assert 'train_step_end' in train_step_end_out
+    assert pbar_metrics['pbar_acc1'] == 19.0
+    assert pbar_metrics['pbar_acc2'] == 21.0
+
+
+def test_full_training_loop_dict(tmpdir):
+    """
+    Checks train_step + training_step_end + training_epoch_end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_for_step_end_dict
+    model.training_step_end = model.training_step_end_dict
+    model.training_epoch_end = model.training_epoch_end_dict
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert trainer.callback_metrics['epoch_end_log_1'] == 178
+    assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 14.0
+    assert out.batch_log_metrics['log_acc2'] == 9.0
+
+    train_step_end_out = out.training_step_output_for_epoch_end
+    pbar_metrics = train_step_end_out['progress_bar']
+    assert pbar_metrics['pbar_acc1'] == 19.0
+    assert pbar_metrics['pbar_acc2'] == 21.0
+
+
+def test_train_step_epoch_end(tmpdir):
+    """
+    Checks train_step + training_epoch_end (NO training_step_end)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.training_step_end = None
+    model.training_epoch_end = model.training_epoch_end_dict
+    model.val_dataloader = None
+
+    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert trainer.callback_metrics['epoch_end_log_1'] == 178
+    assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert out.batch_log_metrics['log_acc1'] == 12.0
+    assert out.batch_log_metrics['log_acc2'] == 7.0
+
+    train_step_end_out = out.training_step_output_for_epoch_end
+    pbar_metrics = train_step_end_out['progress_bar']
+    assert pbar_metrics['pbar_acc1'] == 17.0
+    assert pbar_metrics['pbar_acc2'] == 19.0
diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps_scalar_return.py
similarity index 53%
rename from tests/trainer/test_trainer_steps.py
rename to tests/trainer/test_trainer_steps_scalar_return.py
index 9dd12038c0831..b893b58310dc3 100644
--- a/tests/trainer/test_trainer_steps.py
+++ b/tests/trainer/test_trainer_steps_scalar_return.py
@@ -1,162 +1,11 @@
+"""
+Tests to ensure that the training loop works with a scalar
+"""
 from pytorch_lightning import Trainer
 from tests.base.deterministic_model import DeterministicModel
-import pytest
 import torch


-def test_training_step_dict(tmpdir):
-    """
-    Tests that only training_step can be used
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step_dict_return
-    model.val_dataloader = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    # make sure correct steps were called
-    assert model.training_step_called
-    assert not model.training_step_end_called
-    assert not model.training_epoch_end_called
-
-    # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
-
-    out = trainer.run_training_batch(batch, batch_idx)
-    assert out.signal == 0
-    assert out.batch_log_metrics['log_acc1'] == 12.0
-    assert out.batch_log_metrics['log_acc2'] == 7.0
-
-    train_step_out = out.training_step_output_for_epoch_end
-    pbar_metrics = train_step_out['progress_bar']
-    assert 'log' in train_step_out
-    assert 'progress_bar' in train_step_out
-    assert train_step_out['train_step_test'] == 549
-    assert pbar_metrics['pbar_acc1'] == 17.0
-    assert pbar_metrics['pbar_acc2'] == 19.0
-
-    # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
-    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
-
-
-def training_step_with_step_end(tmpdir):
-    """
-    Checks train_step + training_step_end
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step_for_step_end_dict
-    model.training_step_end = model.training_step_end_dict
-    model.val_dataloader = None
-
-    trainer = Trainer(fast_dev_run=True, weights_summary=None)
-    trainer.fit(model)
-
-    # make sure correct steps were called
-    assert model.training_step_called
-    assert model.training_step_end_called
-    assert not model.training_epoch_end_called
-
-    # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
-
-    out = trainer.run_training_batch(batch, batch_idx)
-    assert out.signal == 0
-    assert out.batch_log_metrics['log_acc1'] == 14.0
-    assert out.batch_log_metrics['log_acc2'] == 9.0
-
-    train_step_end_out = out.training_step_output_for_epoch_end
-    pbar_metrics = train_step_end_out['progress_bar']
-    assert 'train_step_end' in train_step_end_out
-    assert pbar_metrics['pbar_acc1'] == 19.0
-    assert pbar_metrics['pbar_acc2'] == 21.0
-
-
-def test_full_training_loop_dict(tmpdir):
-    """
-    Checks train_step + training_step_end + training_epoch_end
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step_for_step_end_dict
-    model.training_step_end = model.training_step_end_dict
-    model.training_epoch_end = model.training_epoch_end_dict
-    model.val_dataloader = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    # make sure correct steps were called
-    assert model.training_step_called
-    assert model.training_step_end_called
-    assert model.training_epoch_end_called
-
-    # assert epoch end metrics were added
-    assert trainer.callback_metrics['epoch_end_log_1'] == 178
-    assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234
-
-    # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
-
-    out = trainer.run_training_batch(batch, batch_idx)
-    assert out.signal == 0
-    assert out.batch_log_metrics['log_acc1'] == 14.0
-    assert out.batch_log_metrics['log_acc2'] == 9.0
-
-    train_step_end_out = out.training_step_output_for_epoch_end
-    pbar_metrics = train_step_end_out['progress_bar']
-    assert pbar_metrics['pbar_acc1'] == 19.0
-    assert pbar_metrics['pbar_acc2'] == 21.0
-
-
-def test_train_step_epoch_end(tmpdir):
-    """
-    Checks train_step + training_epoch_end (NO training_step_end)
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step_dict_return
-    model.training_step_end = None
-    model.training_epoch_end = model.training_epoch_end_dict
-    model.val_dataloader = None
-
-    trainer = Trainer(max_epochs=1, weights_summary=None)
-    trainer.fit(model)
-
-    # make sure correct steps were called
-    assert model.training_step_called
-    assert not model.training_step_end_called
-    assert model.training_epoch_end_called
-
-    # assert epoch end metrics were added
-    assert trainer.callback_metrics['epoch_end_log_1'] == 178
-    assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234
-
-    # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
-
-    out = trainer.run_training_batch(batch, batch_idx)
-    assert out.signal == 0
-    assert out.batch_log_metrics['log_acc1'] == 12.0
-    assert out.batch_log_metrics['log_acc2'] == 7.0
-
-    train_step_end_out = out.training_step_output_for_epoch_end
-    pbar_metrics = train_step_end_out['progress_bar']
-    assert pbar_metrics['pbar_acc1'] == 17.0
-    assert pbar_metrics['pbar_acc2'] == 19.0
-
-
 def test_training_step_scalar(tmpdir):
     """
     Tests that only training_step that returns a single scalar can be used
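
For context on what this series enables from the user side (not code from the patches above): with the logging.py and training_loop.py changes, a training_step may return a bare scalar loss tensor instead of a dict, and process_output passes it through with empty metric dicts. Below is a minimal sketch of that usage pattern; the model, dataset, and hyperparameters are illustrative assumptions only.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ScalarReturnModel(LightningModule):
    # training_step returns the loss tensor directly; no {'loss': ...} dict wrapping

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # bare scalar tensor return, exercised by the tests in this series
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        x, y = torch.randn(64, 16), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=16)


# example usage:
# trainer = Trainer(fast_dev_run=True)
# trainer.fit(ScalarReturnModel())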