
[train] New training options for logging/validation based on number of steps #3379

Merged · 12 commits · Mar 8, 2021
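This PR adds step-based counterparts to the existing time- and epoch-based training controls: -tstep / --max-train-steps ends training after a fixed number of model updates, -lstep / --log-every-n-steps logs on a step schedule, and -vstep / --validation-every-n-steps validates on a step schedule. The sketch below shows one way to drive the new options through TrainLoop, modeled on the test added in this PR; the task, model, and step counts are illustrative only, not prescribed defaults.

import os

import parlai.scripts.train_model as tms
import parlai.utils.testing as testing_utils

with testing_utils.tempdir() as tmpdir:
    opt = {
        'task': 'integration_tests',
        'model': 'parlai.agents.test_agents.test_agents:MockTrainUpdatesAgent',
        'model_file': os.path.join(tmpdir, 'model'),
        'dict_file': os.path.join(tmpdir, 'model.dict'),
        'max_train_steps': 1000,          # stop after 1000 optimizer updates
        'log_every_n_steps': 10,          # log a training report every 10 updates
        'validation_every_n_steps': 100,  # run validation every 100 updates
    }
    parser = tms.setup_args()
    parser.set_params(**opt)
    popt = parser.parse_args([])
    for k, v in opt.items():  # mirror the test below: force the step opts onto popt
        popt[k] = v
    valid, test = tms.TrainLoop(popt).train()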
16 changes: 16 additions & 0 deletions parlai/agents/test_agents/test_agents.py
@@ -164,3 +164,19 @@ def eval_step(self, batch):
Null output.
"""
return Output()


class MockTrainUpdatesAgent(MockTorchAgent):
"""
Simulate training updates.
"""

def train_step(self, batch):
ret = super().train_step(batch)
update_freq = self.opt.get('update_freq', 1)
if update_freq == 1:
self._number_training_updates += 1
else:
self._number_grad_accum = (self._number_grad_accum + 1) % update_freq
self._number_training_updates += int(self._number_grad_accum == 0)
return ret
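
For context on the mock agent above: with gradient accumulation (--update-freq greater than 1), a training update is only counted once per update_freq calls to train_step. A quick standalone illustration of that counter arithmetic, assuming both counters start at zero (illustrative only, not part of the diff):

# update_freq = 2: one optimizer update per two train_step calls
grad_accum, training_updates = 0, 0
update_freq = 2
for parley in range(4):                        # four calls to train_step
    grad_accum = (grad_accum + 1) % update_freq
    training_updates += int(grad_accum == 0)   # increments on every 2nd call
print(training_updates)                        # -> 2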
10 changes: 8 additions & 2 deletions parlai/nn/lr_scheduler.py
@@ -161,7 +161,9 @@ def add_cmdline_args(
'--max-lr-steps',
type=int,
default=-1,
help='Number of train steps the scheduler should take after warmup. '
hidden=True,
help='**DEPRECATED: please use --max-train-steps instead** '
'Number of train steps the scheduler should take after warmup. '
'Training is terminated after this many steps. This should only be '
'set for --lr-scheduler cosine or linear',
)
@@ -222,7 +224,11 @@ def lr_scheduler_factory(cls, opt, optimizer, states, hard_reset=False):
decay = opt.get('lr_scheduler_decay', 0.5)
warmup_updates = opt.get('warmup_updates', -1)
warmup_rate = opt.get('warmup_rate', 1e-4)
max_lr_steps = opt.get('max_lr_steps', -1)
max_lr_steps = opt.get('max_train_steps', -1)
if opt.get('max_lr_steps', -1) > 0:
raise ValueError(
'--max-lr-steps is **DEPRECATED**; please set --max-train-steps directly'
)
invsqrt_lr_decay_gamma = opt.get('invsqrt_lr_decay_gamma', -1)

if opt.get('lr_scheduler') == 'none':
113 changes: 85 additions & 28 deletions parlai/scripts/train_model.py
@@ -81,7 +81,21 @@ def setup_args(parser=None) -> ParlaiParser:
train.add_argument('--display-examples', type='bool', default=False, hidden=True)
train.add_argument('-eps', '--num-epochs', type=float, default=-1)
train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
train.add_argument(
'-tstep',
'--max-train-steps',
type=int,
default=-1,
help='End training after n model updates',
)
train.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
train.add_argument(
'-lstep',
'--log-every-n-steps',
type=int,
default=-1,
help='Log every n training steps',
)
train.add_argument(
'-vtim',
'--validation-every-n-secs',
@@ -90,6 +104,14 @@ def setup_args(parser=None) -> ParlaiParser:
help='Validate every n seconds. Saves model to model_file '
'(if set) whenever best val metric is found',
)
train.add_argument(
'-vstep',
'--validation-every-n-steps',
type=float,
default=-1,
help='Validate every n training steps. Saves model to model_file '
'(if set) whenever best val metric is found',
)
train.add_argument(
'-stim',
'--save-every-n-secs',
@@ -289,28 +311,22 @@ def __init__(self, opt):
self.save_time = Timer()

self.parleys = 0
self.max_num_epochs = (
opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
)
self.max_train_time = (
opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
)
self.log_every_n_secs = (
opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
)
self.val_every_n_secs = (
opt['validation_every_n_secs']
if opt['validation_every_n_secs'] > 0
else float('inf')
)
self.save_every_n_secs = (
opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
)
self.val_every_n_epochs = (
opt['validation_every_n_epochs']
if opt['validation_every_n_epochs'] > 0
else float('inf')
)
self._train_steps = 0
self._last_log_steps = 0
self.update_freq = opt.get('update_freq', 1)

def _num_else_inf(key: str):
return opt[key] if opt[key] > 0 else float('inf')

self.max_num_epochs = _num_else_inf('num_epochs')
self.max_train_time = _num_else_inf('max_train_time')
self.max_train_steps = _num_else_inf('max_train_steps')
self.log_every_n_secs = _num_else_inf('log_every_n_secs')
self.log_every_n_steps = _num_else_inf('log_every_n_steps')
self.val_every_n_secs = _num_else_inf('validation_every_n_secs')
self.val_every_n_epochs = _num_else_inf('validation_every_n_epochs')
self.val_every_n_steps = _num_else_inf('validation_every_n_steps')
self.save_every_n_secs = _num_else_inf('save_every_n_secs')

# smart defaults for --validation-metric-mode
if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
@@ -321,6 +337,7 @@ def __init__(self, opt):
opt['validation_metric_mode'] = 'max'

self.last_valid_epoch = 0
self._last_valid_steps = 0
self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
self.train_reports = []
self.valid_reports = []
@@ -343,6 +360,7 @@ def __init__(self, opt):
self.parleys = obj.get('parleys', 0)
self._preempted_epochs = obj.get('total_epochs', 0)
self.train_time.total = obj.get('train_time', 0)
self._train_steps = obj.get('train_steps', 0)
self.impatience = obj.get('impatience', 0)
self.valid_reports = obj.get('valid_reports', [])
self.train_reports = obj.get('train_reports', [])
@@ -397,6 +415,7 @@ def _save_train_stats(self, suffix=None):
{
'parleys': self.parleys,
'train_time': self.train_time.time(),
'train_steps': self._train_steps,
'total_epochs': self._total_epochs,
'train_reports': self.train_reports,
'valid_reports': self.valid_reports,
@@ -426,6 +445,7 @@ def validate(self):
v = dict_report(valid_report)
v['train_time'] = self.train_time.time()
v['parleys'] = self.parleys
v['train_steps'] = self._train_steps
v['total_exs'] = self._total_exs
v['total_epochs'] = self._total_epochs
self.valid_reports.append(v)
@@ -579,7 +599,9 @@ def _sync_metrics(self, metrics):
all_versions = all_gather_list(metrics)
return aggregate_unnamed_reports(all_versions)

def _compute_eta(self, epochs_completed, time_elapsed):
def _compute_eta(
self, epochs_completed: float, time_elapsed: float, steps_taken: int
):
"""
Compute the estimated seconds remaining in training.

@@ -602,6 +624,11 @@ def _compute_eta(self, epochs_completed, time_elapsed):
if eta is None or time_left < eta:
eta = time_left

max_train_steps = self.opt.get('max_train_steps', -1)
if max_train_steps > 0:
steps_progress = steps_taken / max_train_steps
eta = (1 - steps_progress) * time_elapsed / steps_progress
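            # Worked example (added for illustration, not part of the diff): with
            # max_train_steps=1000, steps_taken=250, and time_elapsed=100s,
            # steps_progress = 0.25, so eta = (1 - 0.25) * 100 / 0.25 = 300s.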

return eta

def log(self):
@@ -621,24 +648,29 @@ def log(self):
train_report_trainstats['total_epochs'] = self._total_epochs
train_report_trainstats['total_exs'] = self._total_exs
train_report_trainstats['parleys'] = self.parleys
train_report_trainstats['train_steps'] = self._train_steps
train_report_trainstats['train_time'] = self.train_time.time()
self.train_reports.append(train_report_trainstats)

# time elapsed
logs.append(f'time:{self.train_time.time():.0f}s')
logs.append(f'total_exs:{self._total_exs}')
logs.append(f'total_steps:{self._train_steps}')

if self._total_epochs >= 0:
# only if it's unbounded
logs.append(f'epochs:{self._total_epochs:.2f}')

time_left = self._compute_eta(self._total_epochs, self.train_time.time())
time_left = self._compute_eta(
self._total_epochs, self.train_time.time(), self._train_steps
)
if time_left is not None:
logs.append(f'time_left:{max(0,time_left):.0f}s')

log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
logging.info(log)
self.log_time.reset()
self._last_log_steps = 0

if opt['tensorboard_log'] and is_primary_worker():
self.tb_logger.log_metrics('train', self.parleys, train_report)
@@ -662,6 +694,8 @@ def train(self):
break

self.parleys += 1
self._train_steps = int(self.parleys / self.update_freq)
self._last_log_steps += 1 / self.update_freq
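                # Clarifying note (not part of the diff): one optimizer step spans
                # update_freq parleys, so update_freq=2 with parleys=10 yields
                # _train_steps=5; _last_log_steps likewise advances 1/update_freq
                # per parley.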

# get the total training examples done, compute epochs
self._total_epochs = self._preempted_epochs + sum(
@@ -670,13 +704,27 @@
exs_per_epoch = world.num_examples()
self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
# and use the primary worker's timings for everything
train_time, log_time, validate_time = sync_object(
(
if any(
getattr(self, k) < float('inf')
for k in [
'max_train_time',
'log_every_n_secs',
'validation_every_n_secs',
]
):
train_time, log_time, validate_time = sync_object(
(
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
)
)
else:
train_time, log_time, validate_time = (
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
)
)

# check counters and timers
if self._total_epochs >= self.max_num_epochs:
@@ -688,12 +736,20 @@ def train(self):
if train_time > self.max_train_time:
logging.info(f'max_train_time elapsed:{train_time}s')
break
if log_time > self.log_every_n_secs:
if self._train_steps >= self.max_train_steps:
logging.info(f'max_train_steps elapsed:{self._train_steps}')
break
if (
log_time > self.log_every_n_secs
or self._last_log_steps >= self.log_every_n_steps
):
self.log()
if (
validate_time > self.val_every_n_secs
or self._total_epochs - self.last_valid_epoch
>= self.val_every_n_epochs
or self._train_steps - self._last_valid_steps
>= self.val_every_n_steps
):
try:
# log before we validate
@@ -705,6 +761,7 @@ def train(self):
# reset the log time because we logged right before validating
self.log_time.reset()
self.last_valid_epoch = self._total_epochs
self._last_valid_steps = self._train_steps
if stop_training:
break
# make sure metrics are clean before we log
54 changes: 53 additions & 1 deletion tests/test_train_model.py
@@ -7,7 +7,7 @@
"""
Basic tests that ensure train_model.py behaves in predictable ways.
"""

import os
import unittest
import parlai.utils.testing as testing_utils
from parlai.core.worlds import create_task
@@ -112,6 +112,58 @@ def test_multitasking_id_overlap(self):
in str(context.exception)
)

def _test_opt_step_opts(self, update_freq: int):
"""
Test -tstep, -vstep, -lstep.

:param update_freq:
update frequency

        We copy train_model from testing_utils to directly access the train loop.
"""
import parlai.scripts.train_model as tms

num_train_steps = 1001
num_validations = 10
num_logs = 100

def get_tl(tmpdir):
opt = {
'task': 'integration_tests',
'model': 'parlai.agents.test_agents.test_agents:MockTrainUpdatesAgent',
'model_file': os.path.join(tmpdir, 'model'),
'dict_file': os.path.join(tmpdir, 'model.dict'),
# step opts
'max_train_steps': num_train_steps,
'validation_every_n_steps': int(num_train_steps / num_validations),
'log_every_n_steps': int(num_train_steps / num_logs),
'update_freq': update_freq,
}
parser = tms.setup_args()
parser.set_params(**opt)
popt = parser.parse_args([])
for k, v in opt.items():
popt[k] = v
return tms.TrainLoop(popt)

with testing_utils.capture_output(), testing_utils.tempdir() as tmpdir:
tl = get_tl(tmpdir)
valid, _ = tl.train()

self.assertEqual(
tl.valid_reports[-1]['total_train_updates'], num_train_steps - 1
)
self.assertEqual(len(tl.valid_reports), num_validations)
self.assertEqual(
len(tl.train_reports), num_logs + num_validations
) # log every valid as well

def test_opt_step(self):
self._test_opt_step_opts(1)

def test_opt_step_update_freq_2(self):
self._test_opt_step_opts(2)


if __name__ == '__main__':
unittest.main()