Fix bug comparing max_steps to global step which inits at 0 #4278

Merged (10 commits) on Oct 22, 2020
23 changes: 16 additions & 7 deletions pytorch_lightning/trainer/training_loop.py
@@ -582,7 +582,10 @@ def run_training_epoch(self):

             # max steps reached, end training
             if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
-                break
+                accumulation_done = self.accumulated_batches_reached()
+                # Ensure accumulation across batches has completed before breaking loop
+                if accumulation_done:
+                    break
 
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
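
The hunk above only gates the existing break on accumulation_done. The sketch below is a rough, self-contained illustration of why that matters (it is not Lightning's actual loop, and the relative placement of the break check and the global-step bump is inferred from the surrounding hunks): global_step counts completed optimizer steps and starts at 0, so with gradient accumulation the condition `max_steps == global_step + 1` becomes true while the last accumulation window is still open, and an unconditional break ends training one optimizer step short.

```python
import itertools

def simulate(max_steps, accumulate_grad_batches, gate_on_accumulation):
    """Toy loop mirroring the diff; runs until the max-steps break fires."""
    global_step = 0        # completed optimizer steps, inits at 0
    backward_calls = 0
    optimizer_steps = 0
    for batch_idx in itertools.count():
        # run_training_batch: backward on every micro-batch, optimizer.step()
        # only once the accumulation window is full
        backward_calls += 1
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        if accumulation_done:
            optimizer_steps += 1

        # the max-steps check from the hunk above; global_step has not been
        # bumped for this window yet, hence the `+ 1`
        if max_steps is not None and max_steps == global_step + 1:
            if not gate_on_accumulation or accumulation_done:
                break

        # increment_accumulated_grad_global_step(), later in the iteration
        if accumulation_done:
            global_step += 1

    # assumed epilogue: one more bump for the window that triggered the break,
    # but only if that window actually completed
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        global_step += 1

    return global_step, optimizer_steps, backward_calls

# old behaviour: breaks mid-window, one optimizer step short of max_steps
print(simulate(6, 2, gate_on_accumulation=False))  # (5, 5, 11)
# fixed behaviour: 6 optimizer steps and 6 * 2 = 12 backward calls
print(simulate(6, 2, gate_on_accumulation=True))   # (6, 6, 12)
```

Running the sketch gives (5, 5, 11) for the ungated variant and (6, 6, 12) for the gated one, which lines up with the updated backward-count expectations in the tests below.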
@@ -648,8 +651,8 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
         # checks if backward or backward + optimizer step (via closure)
-        accumulation_done = (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
-        is_final_batch = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+        accumulation_done = self.accumulated_batches_reached()
+        is_final_batch = self.num_training_batches_reached()
 
         # lightning module hook
         splits = self.tbptt_split_batch(batch)
@@ -822,8 +825,8 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
             )
 
     def update_train_loop_lr_schedulers(self, monitor_metrics=None):
-        num_accumulated_batches_reached = (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
-        num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+        num_accumulated_batches_reached = self.accumulated_batches_reached()
+        num_training_batches_reached = self.num_training_batches_reached()
 
         if num_accumulated_batches_reached or num_training_batches_reached:
             # update lr
@@ -834,13 +837,19 @@ def run_on_epoch_end_hook(self, epoch_output):
         self.trainer.call_hook("on_train_epoch_end", epoch_output)
 
     def increment_accumulated_grad_global_step(self):
-        num_accumulated_batches_reached = (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
-        num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+        num_accumulated_batches_reached = self.accumulated_batches_reached()
+        num_training_batches_reached = self.num_training_batches_reached()
 
         # progress global step according to grads progress
         if num_accumulated_batches_reached or num_training_batches_reached:
             self.trainer.global_step += 1
 
+    def accumulated_batches_reached(self):
+        return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
+
+    def num_training_batches_reached(self):
+        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+
     def should_check_val_fx(self, batch_idx, is_last_batch):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0

Review comments on the new accumulated_batches_reached helper:
Member: can we have it rather as protected as it is not meant to be used by the user...
Contributor Author: Will make a followup PR
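
The same two predicates now drive the break check, the LR-scheduler update, and the global-step increment, so the three call sites can no longer drift apart. Paraphrased as free functions over a stub trainer (the stub is purely illustrative and not Lightning code):

```python
from types import SimpleNamespace

def accumulated_batches_reached(trainer):
    # the current accumulation window is full after this batch
    return (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0

def num_training_batches_reached(trainer):
    # this is the last batch of the epoch
    return (trainer.batch_idx + 1) == trainer.num_training_batches

trainer = SimpleNamespace(batch_idx=3, accumulate_grad_batches=2, num_training_batches=10)
assert accumulated_batches_reached(trainer)       # 4 batches seen, window of 2 is full
assert not num_training_batches_reached(trainer)  # only 4 of 10 batches processed
```

As the review exchange above notes, a follow-up PR was planned to make these helpers protected, since they are not meant to be user-facing API.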
13 changes: 5 additions & 8 deletions tests/trainer/optimization/test_backward_calls.py
@@ -31,10 +31,9 @@ def test_backward_count_with_grad_accumulation(torch_backward):

     torch_backward.reset_mock()
 
-    # TODO: this test fails on master, max_steps seems to fail together with accumulation
-    # trainer = Trainer(max_steps=6, accumulate_grad_batches=2)
-    # trainer.fit(model)
-    # assert torch_backward.call_count == 6
+    trainer = Trainer(max_steps=6, accumulate_grad_batches=2)
+    trainer.fit(model)
+    assert torch_backward.call_count == 12
 
 
 @patch("torch.Tensor.backward")
@@ -48,8 +47,6 @@ def test_backward_count_with_closure(torch_backward):

     torch_backward.reset_mock()
 
-    # TODO: max_steps seems to fail together with accumulation
-    # trainer = Trainer(max_steps=5, accumulate_grad_batches=2
-    trainer = Trainer(max_epochs=1, limit_train_batches=5, accumulate_grad_batches=2)
+    trainer = Trainer(max_steps=5, accumulate_grad_batches=2)
     trainer.fit(model)
-    assert torch_backward.call_count == 5
+    assert torch_backward.call_count == 10
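
The rewritten assertions follow from torch.Tensor.backward being patched and counted once per micro-batch: with accumulate_grad_batches=2, every optimizer step consumes two batches, so max_steps optimizer steps yield 2 * max_steps backward calls. A quick sanity check of the two new expected counts (plain arithmetic, not Lightning code):

```python
def expected_backward_calls(max_steps, accumulate_grad_batches):
    # one backward per micro-batch, accumulate_grad_batches micro-batches per optimizer step
    return max_steps * accumulate_grad_batches

assert expected_backward_calls(max_steps=6, accumulate_grad_batches=2) == 12
assert expected_backward_calls(max_steps=5, accumulate_grad_batches=2) == 10
```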
7 changes: 5 additions & 2 deletions tests/trainer/test_lr_finder.py
@@ -209,11 +209,14 @@ def test_accumulation_and_early_stopping(tmpdir):
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
     after_lr = lrfinder.suggestion()
 
+    expected_num_lrs = 100
+    expected_batch_idx = 200 - 1
+
     assert before_lr != after_lr, \
         'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == 99, \
+    assert len(lrfinder.results['lr']) == expected_num_lrs, \
         'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == 99 * 2, \
+    assert lrfinder._total_batch_idx == expected_batch_idx, \
         'Accumulation parameter did not work'
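
The new constants read most naturally as follows, assuming the finder tries its default of 100 candidate learning rates and this test keeps accumulate_grad_batches=2 (both values are assumptions, inferred from the previous hard-coded 99 and 99 * 2): each candidate now consumes a full accumulation window before stopping, and _total_batch_idx is a zero-based index.

```python
# assumed values, for illustration only
num_lr_candidates = 100        # presumed lr_find default number of training steps
accumulate_grad_batches = 2    # presumed accumulation factor used by this test

expected_num_lrs = num_lr_candidates
expected_batch_idx = num_lr_candidates * accumulate_grad_batches - 1  # zero-based index

assert expected_num_lrs == 100
assert expected_batch_idx == 199  # the test spells this as 200 - 1
```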


20 changes: 20 additions & 0 deletions tests/trainer/test_trainer.py
@@ -661,6 +661,26 @@ def test_trainer_min_steps_and_epochs(tmpdir):
), "Model did not train for at least min_steps"


def test_trainer_max_steps_accumulate_batches(tmpdir):
"""Verify model trains according to specified max steps with grad accumulated batches"""
model, trainer_options, num_train_samples = _init_steps_model()

# define less train steps than epochs
trainer_options.update(
default_root_dir=tmpdir,
max_steps=(num_train_samples + 10),
accumulate_grad_batches=10,
)

# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1, "Training did not complete"

# check training stopped at max_steps
assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"


def test_benchmark_option(tmpdir):
"""Verify benchmark option."""
