This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[train] New training options for logging/validation based on number of steps #3379

Merged · 12 commits · Mar 8, 2021
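
As a usage sketch (not part of the PR itself; the task/model values are placeholders, and `TrainModel.main` is assumed to mirror the CLI flags added below), the new step-based options can be combined like this:

from parlai.scripts.train_model import TrainModel

TrainModel.main(
    task='convai2',                   # placeholder task
    model='transformer/generator',    # placeholder model
    model_file='/tmp/model',
    max_train_steps=10000,            # -tstep: stop after 10k model updates
    log_every_n_steps=50,             # -lstep: log every 50 updates
    validation_every_n_steps=1000,    # -vstep: validate every 1k updates
)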
16 changes: 16 additions & 0 deletions parlai/agents/test_agents/test_agents.py
@@ -164,3 +164,19 @@ def eval_step(self, batch):
Null output.
"""
return Output()


class MockTrainUpdatesAgent(MockTorchAgent):
"""
Simulate training updates.
"""

def train_step(self, batch):
ret = super().train_step(batch)
update_freq = self.opt.get('update_freq', 1)
if update_freq == 1:
self._number_training_updates += 1
else:
self._number_grad_accum = (self._number_grad_accum + 1) % update_freq
self._number_training_updates += int(self._number_grad_accum == 0)
return ret
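
As a quick sanity check of the counter logic above, here is a standalone sketch (not part of the diff) of how the update count advances when update_freq > 1:

update_freq = 2
grad_accum = 0
training_updates = 0
for parley in range(1, 7):
    grad_accum = (grad_accum + 1) % update_freq
    training_updates += int(grad_accum == 0)
# after parleys 1..6, training_updates is 0, 1, 1, 2, 2, 3:
# one counted model update per update_freq forward/backward passes.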
14 changes: 5 additions & 9 deletions parlai/nn/lr_scheduler.py
@@ -157,14 +157,6 @@ def add_cmdline_args(
help='Decay factor for LR scheduler, or how much LR is multiplied by '
'when it is lowered.',
)
lr_group.add_argument(
'--max-lr-steps',
type=int,
default=-1,
help='Number of train steps the scheduler should take after warmup. '
'Training is terminated after this many steps. This should only be '
'set for --lr-scheduler cosine or linear',
)
lr_group.add_argument(
'--invsqrt-lr-decay-gamma',
type=int,
@@ -222,7 +214,11 @@ def lr_scheduler_factory(cls, opt, optimizer, states, hard_reset=False):
decay = opt.get('lr_scheduler_decay', 0.5)
warmup_updates = opt.get('warmup_updates', -1)
warmup_rate = opt.get('warmup_rate', 1e-4)
max_lr_steps = opt.get('max_lr_steps', -1)
max_lr_steps = opt.get('max_train_steps', -1)
if opt.get('max_lr_steps', -1) > 0:
raise ValueError(
'--max-lr-steps is **DEPRECATED**; please set --max-train-steps directly'
)
invsqrt_lr_decay_gamma = opt.get('invsqrt_lr_decay_gamma', -1)

if opt.get('lr_scheduler') == 'none':
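
For anyone migrating configs, a minimal sketch (assumed option values) of the deprecation path above: the scheduler's step budget is now read from --max-train-steps, and a positive --max-lr-steps raises instead of being honored.

opt = {'max_lr_steps': -1, 'max_train_steps': 5000}  # old flag unset, new flag set
max_lr_steps = opt.get('max_train_steps', -1)        # scheduler budgets 5000 steps
# setting opt['max_lr_steps'] = 5000 instead would raise the ValueError above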
2 changes: 1 addition & 1 deletion parlai/scripts/build_dict.py
@@ -53,7 +53,7 @@ def setup_args(parser=None, hidden=True):
hidden=hidden,
)
dict_loop.add_argument(
'-ltim', '--log-every-n-secs', type=float, default=10, hidden=hidden
'-ltim', '--log-every-n-secs', type=float, default=-1, hidden=hidden
)
DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
return parser
170 changes: 137 additions & 33 deletions parlai/scripts/train_model.py
@@ -28,6 +28,7 @@
import json
import numpy as np
import signal
from typing import Tuple

from parlai.core.metrics import Metric
from parlai.core.agents import create_agent, create_agent_from_shared
@@ -38,8 +39,9 @@
aggregate_unnamed_reports,
dict_report,
)
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser, print_announcements
from parlai.core.worlds import create_task
from parlai.core.worlds import create_task, World
from parlai.scripts.build_dict import build_dict, setup_args as setup_dict_args
from parlai.utils.distributed import (
sync_object,
@@ -54,6 +56,20 @@
from parlai.utils.io import PathManager


def _num_else_inf(opt: Opt, key: str, distributed_warn=False):
if opt[key] > 0:
if distributed_warn and is_distributed():
nicekey = '--' + key.replace('_', '-')
logging.warn(
f'Using {nicekey} in distributed mode can lead to slowdowns. '
'See https://github.com/facebookresearch/ParlAI/pull/3379 for more info.'
)
value = opt[key]
else:
value = float('inf')
return value
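
A minimal illustration (assumed option values, not from the diff) of what the helper above returns:

opt = {'max_train_time': -1, 'num_epochs': 5}
_num_else_inf(opt, 'max_train_time')                     # -> inf: trigger disabled
_num_else_inf(opt, 'num_epochs', distributed_warn=True)  # -> 5; also warns, but
                                                         #    only when distributed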


def setup_args(parser=None) -> ParlaiParser:
"""
Build the ParlAI parser, adding command line args if necessary.
@@ -92,7 +108,22 @@ def setup_args(parser=None) -> ParlaiParser:
train.add_argument('--display-examples', type='bool', default=False, hidden=True)
train.add_argument('-eps', '--num-epochs', type=float, default=-1)
train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
train.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
train.add_argument(
'-tstep',
'--max-train-steps',
'--max-lr-steps',
type=int,
default=-1,
help='End training after n model updates',
)
train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
train.add_argument(
'-lstep',
'--log-every-n-steps',
type=int,
default=50,
help='Log every n training steps',
)
train.add_argument(
'-vtim',
'--validation-every-n-secs',
@@ -101,6 +132,14 @@
help='Validate every n seconds. Saves model to model_file '
'(if set) whenever best val metric is found',
)
train.add_argument(
'-vstep',
'--validation-every-n-steps',
type=int,
default=-1,
help='Validate every n training steps. Saves model to model_file '
'(if set) whenever best val metric is found',
)
train.add_argument(
'-stim',
'--save-every-n-secs',
@@ -310,27 +349,28 @@ def __init__(self, opt):
self.save_time = Timer()

self.parleys = 0
self.max_num_epochs = (
opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
)
self.max_train_time = (
opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
self._train_steps = 0
self._last_log_steps = 0
self.update_freq = opt.get('update_freq', 1)

self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
self.max_train_time = _num_else_inf(
opt, 'max_train_time', distributed_warn=True
)
self.log_every_n_secs = (
opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
self.log_every_n_secs = _num_else_inf(
opt, 'log_every_n_secs', distributed_warn=True
)
self.val_every_n_secs = (
opt['validation_every_n_secs']
if opt['validation_every_n_secs'] > 0
else float('inf')
self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
self.val_every_n_secs = _num_else_inf(
opt, 'validation_every_n_secs', distributed_warn=True
)
self.save_every_n_secs = (
opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
self.val_every_n_epochs = _num_else_inf(
opt, 'validation_every_n_epochs', distributed_warn=True
)
self.val_every_n_epochs = (
opt['validation_every_n_epochs']
if opt['validation_every_n_epochs'] > 0
else float('inf')
self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
self.save_every_n_secs = _num_else_inf(
opt, 'save_every_n_secs', distributed_warn=True
)

# smart defaults for --validation-metric-mode
@@ -342,6 +382,7 @@ def __init__(self, opt):
opt['validation_metric_mode'] = 'max'

self.last_valid_epoch = 0
self._last_valid_steps = 0
self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
self.train_reports = []
self.valid_reports = []
@@ -364,6 +405,7 @@ def __init__(self, opt):
self.parleys = obj.get('parleys', 0)
self._preempted_epochs = obj.get('total_epochs', 0)
self.train_time.total = obj.get('train_time', 0)
self._train_steps = obj.get('train_steps', 0)
self.impatience = obj.get('impatience', 0)
self.valid_reports = obj.get('valid_reports', [])
if self.valid_reports:
@@ -425,6 +467,7 @@ def _save_train_stats(self, suffix=None):
{
'parleys': self.parleys,
'train_time': self.train_time.time(),
'train_steps': self._train_steps,
'total_epochs': self._total_epochs,
'train_reports': self.train_reports,
'valid_reports': self.valid_reports,
@@ -454,6 +497,7 @@ def validate(self):
v = dict_report(valid_report)
v['train_time'] = self.train_time.time()
v['parleys'] = self.parleys
v['train_steps'] = self._train_steps
v['total_exs'] = self._total_exs
v['total_epochs'] = self._total_epochs
self.valid_reports.append(v)
@@ -611,7 +655,9 @@ def _sync_metrics(self, metrics):
all_versions = all_gather_list(metrics)
return aggregate_unnamed_reports(all_versions)

def _compute_eta(self, epochs_completed, time_elapsed):
def _compute_eta(
self, epochs_completed: float, time_elapsed: float, steps_taken: int
):
"""
Compute the estimated seconds remaining in training.

@@ -634,8 +680,59 @@ def _compute_eta(self, epochs_completed, time_elapsed):
if eta is None or time_left < eta:
eta = time_left

max_train_steps = self.opt.get('max_train_steps', -1)
if max_train_steps > 0:
steps_progress = steps_taken / max_train_steps
eta = (1 - steps_progress) * time_elapsed / steps_progress

return eta
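
A worked check of the step-based branch above (illustrative numbers): after 2,500 of 10,000 steps in 600 seconds, a quarter of the step budget is done, so three times the elapsed time remains.

steps_taken, max_train_steps, time_elapsed = 2500, 10000, 600.0
steps_progress = steps_taken / max_train_steps              # 0.25
eta = (1 - steps_progress) * time_elapsed / steps_progress  # 1800.0 seconds left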

def _get_time(self, world: World) -> Tuple[float, float, float]:
"""
Return train, log, and validate timing.

If relying on the time for validation/logging/max train time purposes,
we sync and return primary worker's time.

Otherwise, it's not super relevant what we do here.

**SIDE EFFECT**: Update _total_epochs trained.

:param world:
current running world

:return (train, log, valid):
return time for each of train, log, and validation
"""
if (
self.max_train_time < float('inf')
or self.log_every_n_secs < float('inf')
or self.val_every_n_secs < float('inf')
or self.val_every_n_epochs < float('inf')
or self.max_num_epochs < float('inf')
):
self._total_epochs = self._preempted_epochs + sum(
all_gather_list(world.get_total_epochs())
)
train_time, log_time, validate_time = sync_object(
(
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
)
)
else:
train_time, log_time, validate_time = (
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
)
self._total_epochs = self._preempted_epochs + (
num_workers() * world.get_total_epochs()
)

return train_time, log_time, validate_time

def log(self):
"""
Output a training log entry.
@@ -653,24 +750,29 @@ def log(self):
train_report_trainstats['total_epochs'] = self._total_epochs
train_report_trainstats['total_exs'] = self._total_exs
train_report_trainstats['parleys'] = self.parleys
train_report_trainstats['train_steps'] = self._train_steps
train_report_trainstats['train_time'] = self.train_time.time()
self.train_reports.append(train_report_trainstats)

# time elapsed
logs.append(f'time:{self.train_time.time():.0f}s')
logs.append(f'total_exs:{self._total_exs}')
logs.append(f'total_steps:{self._train_steps}')

if self._total_epochs >= 0:
# only if it's unbounded
logs.append(f'epochs:{self._total_epochs:.2f}')

time_left = self._compute_eta(self._total_epochs, self.train_time.time())
time_left = self._compute_eta(
self._total_epochs, self.train_time.time(), self._train_steps
)
if time_left is not None:
logs.append(f'time_left:{max(0,time_left):.0f}s')

log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
logging.info(log)
self.log_time.reset()
self._last_log_steps = 0

if opt['tensorboard_log'] and is_primary_worker():
self.tb_logger.log_metrics('train', self.parleys, train_report)
@@ -696,21 +798,14 @@ def train(self):
break

self.parleys += 1
self._train_steps = self.parleys // self.update_freq
self._last_log_steps += 1 / self.update_freq

# the following additionally updates self._total_epochs
train_time, log_time, validate_time = self._get_time(world)
# get the total training examples done, compute epochs
self._total_epochs = self._preempted_epochs + sum(
all_gather_list(world.get_total_epochs())
)
exs_per_epoch = world.num_examples()
self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
# and use the primary worker's timings for everything
train_time, log_time, validate_time = sync_object(
(
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
)
)

# check counters and timers
if self._total_epochs >= self.max_num_epochs:
@@ -722,12 +817,20 @@ def train(self):
if train_time > self.max_train_time:
logging.info(f'max_train_time elapsed:{train_time}s')
break
if log_time > self.log_every_n_secs:
if self._train_steps >= self.max_train_steps:
logging.info(f'max_train_steps elapsed:{self._train_steps}')
break
if (
log_time > self.log_every_n_secs
or self._last_log_steps >= self.log_every_n_steps
):
self.log()
if (
validate_time > self.val_every_n_secs
or self._total_epochs - self.last_valid_epoch
>= self.val_every_n_epochs
or self._train_steps - self._last_valid_steps
>= self.val_every_n_steps
):
try:
# log before we validate
@@ -739,6 +842,7 @@ def train(self):
# reset the log time because we logged right before validating
self.log_time.reset()
self.last_valid_epoch = self._total_epochs
self._last_valid_steps = self._train_steps
if stop_training:
break
# make sure metrics are clean before we log
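
To summarize the new gating, a standalone sketch (not part of the diff; assumes update_freq=1 so each parley is one training step) of when the step-based triggers fire:

log_every_n_steps = 50     # -lstep
val_every_n_steps = 1000   # -vstep
last_log_steps = 0         # mirrors self._last_log_steps
last_valid_steps = 0       # mirrors self._last_valid_steps
for train_steps in range(1, 2001):
    last_log_steps += 1
    if last_log_steps >= log_every_n_steps:
        last_log_steps = 0                  # log() resets this counter
        # logs at steps 50, 100, 150, ...
    if train_steps - last_valid_steps >= val_every_n_steps:
        last_valid_steps = train_steps
        # validates at steps 1000 and 2000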