Add LR Scheduler (#4694)
Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Thiago Crepaldi <[email protected]>
3 people committed Aug 14, 2020
1 parent 527ae5b commit 92c5ba4
Showing 4 changed files with 48 additions and 16 deletions.
3 changes: 1 addition & 2 deletions orttraining/orttraining/python/training/optim/config.py
@@ -56,8 +56,7 @@ def __init__(self, name, params, defaults):
            assert k in defaults, f"'params' has 'k' hyper parameter not present at 'defaults'"

        self.name = name
-        self.lr = defaults['lr']
-        self.base_lrs = [defaults['lr']]
+        self.lr = float(defaults['lr'])
        self.defaults = defaults
        self.params = []
6 changes: 3 additions & 3 deletions orttraining/orttraining/python/training/optim/lr_scheduler.py
@@ -37,10 +37,10 @@ def get_last_lr(self):
        r""" Return last computed learning rate by LR Scheduler"""
        return self._last_lr

-    def _step(self, train_step_info):
-        r"""Private method called to update learning rate
-        NOTE: This class should never be called by the user.
+    def step(self, train_step_info):
+        r"""Public method called to update learning rate
+        NOTE: This class is used internally.
        """

        # Store last lr for future inquiry
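The rename from _step() to step() means a scheduler can now be driven directly with a TrainStepInfo, the same way the updated test further down exercises it. A minimal sketch of that flow; the onnxruntime.training import path is an assumption, and the positional (max_train_step, warmup) constructor arguments are taken from the test changes below:

    # Assumed import path for the training package; adjust to your install
    from onnxruntime.training import optim, TrainStepInfo

    optim_config = optim.SGDConfig(lr=1)
    # Positional args mirror the test: max_train_step, then warmup fraction
    scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(10, 0.5)

    for step in range(10):
        # step() recomputes the lr from the trainer's step counter
        scheduler.step(TrainStepInfo(step=step, optimizer_config=optim_config))
        print(step, scheduler.get_last_lr()[0])  # last lr computed by the scheduler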
9 changes: 9 additions & 0 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -254,6 +254,10 @@ def train_step(self, *args, **kwargs):
        input_desc = [*self.model_desc.inputs, self.model_desc.learning_rate]
        output_desc = self.model_desc.outputs

+        # Update Learning Rate if Necessary
+        if self.options.lr_scheduler:
+            self.options.lr_scheduler.step(self._train_step_info)
+
        # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first
        input = self._prepare_model_input(input_desc, self.optim_config.lr, None, *args, **kwargs)

@@ -267,6 +271,11 @@ def train_step(self, *args, **kwargs):
        # Run a train step and return
        session_run_results = self._training_session_run_helper(True, input, input_desc,
                                                                 output_desc, run_options)
+
+        # Train step incremented after first train step based on lr scheduler implementation
+        # which handles initial train step of 0.
+        self._train_step_info.step += 1
+
        return session_run_results[output_desc.name][0] if len(session_run_results) == 1\
            else [session_run_results[output_desc.name] for output_desc in self.model_desc.outputs]
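With this hook in train_step(), a scheduler registered under the 'lr_scheduler' option is stepped automatically on every call, and the step counter is advanced afterwards. A rough end-to-end sketch based on the test changes below; the onnxruntime.training import path is assumed, and model, model_desc, my_loss, data, and targets are placeholders for a user's own model, model description, loss function, and batch:

    # Assumed import path; model / model_desc / my_loss / data / targets are placeholders
    from onnxruntime.training import optim, orttrainer

    optim_config = optim.SGDConfig(lr=1)
    opts = orttrainer.ORTTrainerOptions({
        'lr_scheduler': optim.lr_scheduler.ConstantWarmupLRScheduler(10, 0.5)  # max_train_step, warmup
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    for _ in range(10):
        loss = trainer.train_step(data, targets)               # scheduler.step() runs inside train_step
        print(trainer.options.lr_scheduler.get_last_lr()[0])   # lr used for this step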
@@ -450,18 +450,26 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
        # Emulate ORTTrainer.train_step() call that updates its train_step_info
        train_step_info = TrainStepInfo(step=step, optimizer_config=optimizer_config)

-        lr_scheduler._step(train_step_info)
+        lr_scheduler.step(train_step_info)
        lr_list = lr_scheduler.get_last_lr()
        assert len(lr_list) == 1
        assert_allclose(lr_list[0],
                        expected_values[step], rtol=rtol, err_msg="lr mismatch")

-@pytest.mark.parametrize("step_fn", [
-    ('train_step'),
-    ('eval_step')
+@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [
+    ('train_step', None, None),
+    ('eval_step', None, None),
+    ('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                                  0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
+    ('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                                0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
+    ('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                                0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
+    ('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                              0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
])
-def testInstantiateORTTrainer(step_fn):
+def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
    # Loading external TransformerModel model for testing
    # A manual import is done as this example is not part of onnxruntime package,
    # but resides on the onnxruntime repo
@@ -480,10 +488,21 @@ def testInstantiateORTTrainer(step_fn):
    model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2)
    my_loss = ort_utils.my_loss
    model_desc = ort_utils.transformer_model_description()
-    optim_config = optim.LambConfig()

-    # Create ORTTrainer
-    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss)
+    max_train_step = 1
+    warmup = 0.5
+    initial_lr = 1
+    optim_config = optim.SGDConfig(lr=initial_lr)
+    tolerance = 1e-4  # used in lr comparison
+
+    # Set up relevant options
+    opts = {}
+    if lr_scheduler:
+        max_train_step = 10
+        opts.update({'lr_scheduler' : lr_scheduler(max_train_step, warmup)})
+
+    opts = orttrainer.ORTTrainerOptions(opts)
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    # Preparing data
    train_data, val_data, _ = utils.prepare_data('cpu', 20, 20)
@@ -495,8 +514,12 @@ def testInstantiateORTTrainer(step_fn):
        output = trainer.eval_step(data, targets)
    elif step_fn == 'train_step':
        step_fn = trainer.train_step
-        data, targets = utils.get_batch(train_data, 0)
-        output = trainer.train_step(data, targets)
+        for i in range(max_train_step):
+            data, targets = utils.get_batch(train_data, 0)
+            output = trainer.train_step(data, targets)
+            if lr_scheduler:
+                lr_list = trainer.options.lr_scheduler.get_last_lr()
+                assert_allclose(lr_list[0], expected_lr_values[i], rtol=tolerance, err_msg="lr mismatch")
    else:
        raise ValueError('Invalid step_fn')
    assert trainer._onnx_model is not None

@@ -549,3 +572,4 @@ def testInstantiateORTTrainer(step_fn):
    assert (trainer_from_onnx._onnx_model == trainer._onnx_model)
    assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph)
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))
+