Commit

Add support to fetches (#4777)
Thiago Crepaldi committed Aug 14, 2020
1 parent 60b0735 commit adad823
Showing 2 changed files with 102 additions and 57 deletions.
83 changes: 50 additions & 33 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -24,35 +24,39 @@ class TrainStepInfo(object):
such as :py:meth:`._LRScheduler.get_lr` or :py:class:`.LossScaler.update`.
Args:
all_finite (bool): flag that indicates whether all gradients are still finite after last step
step (int): indicates current training step. Used for gradient accumulation
optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling
optimizer_config (optim._OptimizerConfig): reference to optimizer config
all_finite (bool, default is True): flag that indicates whether all gradients are still finite after last step
fetches (list of str, default is []): list of output names to fetch from train_step/eval_step
optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling
step (int): indicates current training step. Used for gradient accumulation
Example:
.. code-block:: python
info = TrainStepInfo(all_finite=True, step=0, optimization_step=0, optimizer_config=optim.SGDConfig(lr=0.01))
info = TrainStepInfo(optimizer_config=optim.SGDConfig(lr=0.01))
if info.all_finite:
print(f'Yay, all gradients are finite at step {info.step}!')
"""

def __init__(self, all_finite=None, step=None, optimization_step=None, optimizer_config=None):
assert all_finite is None or isinstance(all_finite, bool),\
"all_finite must be either None or a bool"
assert step is None or (isinstance(step, int) and step >= 0),\
"step must be either None or a positive int"
assert optimization_step is None or (isinstance(optimization_step, int) and step >= 0),\
"optimization_step must be either None or a positive int"
assert optimizer_config is None or isinstance(optimizer_config, optim._OptimizerConfig),\
"optimizer_config must be either None or optim._OptimizerConfig"
def __init__(self, optimizer_config, all_finite=True, fetches=[], optimization_step=0, step=0):
assert isinstance(optimizer_config, optim._OptimizerConfig),\
"optimizer_config must be an optim._OptimizerConfig"
assert isinstance(all_finite, bool),\
"all_finite must be a bool"
assert isinstance(fetches, list) and all([isinstance(item, str) for item in fetches]),\
"fetches must be a list of str"
assert isinstance(optimization_step, int) and optimization_step >= 0,\
"optimization_step must be a non-negative int"
assert (isinstance(step, int) and step >= 0),\
"step must be a non-negative int"

self.optimizer_config = optimizer_config
self.all_finite = all_finite
self.step = step
self.fetches = fetches
self.optimization_step = optimization_step
self.optimizer_config = optimizer_config
self.step = step


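With this change, TrainStepInfo requires only an optimizer config and gives every other field a default, while the new fetches attribute lets callers name the outputs they want back from train_step/eval_step. A minimal sketch of the new constructor, assuming the onnxruntime.training import path exposes optim and orttrainer:

from onnxruntime.training import optim, orttrainer

# Only the optimizer config is mandatory; the remaining fields default to
# all_finite=True, fetches=[], optimization_step=0 and step=0.
info = orttrainer.TrainStepInfo(optim.SGDConfig(lr=0.01))
assert info.all_finite is True and info.fetches == []

# Asking for specific outputs on the next step is just an attribute assignment.
info.fetches = ['loss']
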
class ORTTrainer(object):
@@ -193,7 +197,7 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
set_cuda_mem_limit(self.options.device.mem_limit)
set_cuda_device_id(_utils.get_device_index(self.options.device.id))

self._train_step_info = TrainStepInfo(all_finite=True, step=0, optimization_step=0, optimizer_config=self.optim_config)
self._train_step_info = TrainStepInfo(self.optim_config)
self._init_session()

def eval_step(self, *args, **kwargs):
@@ -218,8 +222,12 @@ def eval_step(self, *args, **kwargs):
raise RuntimeError("Model is uninitialized. Only ONNX and PyTorch models are supported")

# Prepare input/output description
input_desc = self.model_desc.inputs
output_desc = self.model_desc.outputs
inputs_desc = self.model_desc.inputs
outputs_desc = self.model_desc.outputs
if self._train_step_info.fetches:
outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches]
if len(outputs_desc) != len(self._train_step_info.fetches):
raise RuntimeError("The specified fetches list contains invalid output names")

# Normalize input
if not isinstance(sample_input, (list, tuple)):
@@ -233,12 +241,12 @@
# Run an eval step and return
session_run_results = self._training_session_run_helper(False,
sample_input,
input_desc,
output_desc,
inputs_desc,
outputs_desc,
run_options)

# Output must be returned in the same order as defined in the model description
results = [session_run_results[output_desc.name] for output_desc in self.model_desc.outputs]
results = [session_run_results[o_desc.name] for o_desc in outputs_desc]
return results[0] if len (results) == 1 else results

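Both eval_step and (below) train_step now narrow the output descriptions to the names listed in TrainStepInfo.fetches and reject unknown names. The filtering rule itself is small; a self-contained sketch of it, with a hypothetical namedtuple standing in for a model-description output entry:

from collections import namedtuple

# Hypothetical stand-in for an entry of model_desc.outputs; only .name matters here.
OutputDesc = namedtuple('OutputDesc', ['name', 'shape'])

def filter_outputs_desc(outputs_desc, fetches):
    # Keep the requested outputs in model-description order and fail loudly
    # when a fetch name does not match any known output.
    filtered = [o_desc for o_desc in outputs_desc if o_desc.name in fetches]
    if len(filtered) != len(fetches):
        raise RuntimeError("The specified fetches list contains invalid output names")
    return filtered

outputs = [OutputDesc('loss', [1]), OutputDesc('preds', [-1, 10])]
print([o.name for o in filter_outputs_desc(outputs, ['loss'])])  # ['loss']
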
def save_as_onnx(self, path):
@@ -289,8 +297,8 @@ def train_step(self, *args, **kwargs):
self._init_onnx_model(sample_input)

# Prepare inputs+lr and output descriptions
input_desc = self._model_desc_inputs_with_lr
output_desc = self.model_desc.outputs
inputs_desc = self._model_desc_inputs_with_lr
outputs_desc = self.model_desc.outputs

# Train step must be incremented *before* gradient accumulation code
# Gradients are accumulated when
@@ -300,13 +308,18 @@

# RunOptions
run_options = None
if self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0:
mixed_precision_without_fetches = False
if self._train_step_info.fetches:
outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches]
if len(outputs_desc) != len(self._train_step_info.fetches):
raise RuntimeError("The specified fetches list contains invalid output names")
elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0:
run_options = ort.RunOptions()
run_options.only_execute_path_to_fetches = True
run_options.training_mode = True
output_desc = self._model_desc_outputs_with_gradient_accumulation
outputs_desc = self._model_desc_outputs_with_gradient_accumulation
elif self.options.mixed_precision.enabled:
output_desc = self._model_desc_outputs_with_is_finite
mixed_precision_without_fetches = True
outputs_desc = self._model_desc_outputs_with_is_finite

# Update Learning Rate if Necessary
if self.options.lr_scheduler:
@@ -318,19 +331,19 @@
loss_scaler = self.options.mixed_precision.loss_scaler
assert loss_scaler, "Loss scaler is required when mixed precision is enabled"
loss_scale = torch.tensor([loss_scaler.loss_scale])
input_desc = self._model_desc_inputs_with_lr_and_loss_scale
inputs_desc = self._model_desc_inputs_with_lr_and_loss_scale

# Get data. CombineTorchModelLossFn takes label as last input and outputs loss first
input = self._prepare_model_input(input_desc, self.optim_config.lr, loss_scale, *args, **kwargs)
input = self._prepare_model_input(inputs_desc, self.optim_config.lr, loss_scale, *args, **kwargs)

# Normalize input
if not isinstance(args, (list, tuple)):
args = (args,)

# Run a train step and return
session_run_results = self._training_session_run_helper(True, input, input_desc,
output_desc, run_options)
if self.options.mixed_precision.enabled:
session_run_results = self._training_session_run_helper(True, input, inputs_desc,
outputs_desc, run_options)
if mixed_precision_without_fetches:
# After session run with all_fp32_gradients_finite, we need to clear the training I/O binding's output
# Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce
# because all_fp32_gradients_finite is still in the feed.
@@ -348,7 +361,11 @@ def train_step(self, *args, **kwargs):
self._train_step_info.optimization_step += 1

# Output must be returned in the same order as defined in the model description
results = [session_run_results[output_desc.name] for output_desc in self.model_desc.outputs]
# or in the order specified by TrainStepInfo.fetches, if applicable
if self._train_step_info.fetches:
results = [session_run_results[o_desc] for o_desc in self._train_step_info.fetches]
else:
results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs]
return results[0] if len (results) == 1 else results

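The return value of train_step now follows the order of TrainStepInfo.fetches when it is set, instead of the model-description order. A small sketch of that ordering logic, with a plain dict standing in for session_run_results (the names are illustrative):

session_run_results = {'loss': 0.693, 'preds': [0.1, 0.9]}
model_desc_output_names = ['loss', 'preds']
fetches = ['preds', 'loss']

if fetches:
    # Honor the caller's requested order.
    results = [session_run_results[name] for name in fetches]
else:
    # Fall back to the order declared in the model description.
    results = [session_run_results[name] for name in model_desc_output_names]

result = results[0] if len(results) == 1 else results
print(result)  # [[0.1, 0.9], 0.693]
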
def _combine_torch_model_with_loss_fn(self):
76 changes: 52 additions & 24 deletions in the ORTTrainer frontend test module
@@ -214,9 +214,7 @@ def testDynamicLossScaler():
default_scaler = amp.loss_scaler.DynamicLossScaler()

# Initial state
train_step_info = orttrainer.TrainStepInfo(all_finite=True, step=0,
optimization_step=0,
optimizer_config=None)
train_step_info = orttrainer.TrainStepInfo(optim.LambConfig())
assert_allclose(default_scaler.loss_scale, float(1 << 16),
rtol=rtol, err_msg="loss scale mismatch")
assert default_scaler.up_scale_window == 2000
@@ -300,31 +298,48 @@ def testDynamicLossScalerCustomValues():
def testTrainStepInfo():
'''Test valid initializations of TrainStepInfo'''

step_info = orttrainer.TrainStepInfo(all_finite=True, step=2, optimizer_config=optim.LambConfig())
assert step_info.all_finite is True
assert step_info.step == 2
assert isinstance(step_info.optimizer_config, optim._OptimizerConfig)

step_info = orttrainer.TrainStepInfo()
assert step_info.all_finite is None
assert step_info.step is None
assert step_info.optimizer_config is None


@pytest.mark.parametrize("test_input", [
optimizer_config = optim.LambConfig()
fetches=['out1','out2']
step_info = orttrainer.TrainStepInfo(optimizer_config=optimizer_config,
all_finite=False,
fetches=fetches,
optimization_step=123,
step=456)
assert step_info.optimizer_config == optimizer_config
assert step_info.all_finite == False
assert step_info.fetches == fetches
assert step_info.optimization_step == 123
assert step_info.step == 456

step_info = orttrainer.TrainStepInfo(optimizer_config)
assert step_info.optimizer_config == optimizer_config
assert step_info.all_finite == True
assert step_info.fetches == []
assert step_info.optimization_step == 0
assert step_info.step == 0


@pytest.mark.parametrize("invalid_input", [
(-1),
('Hello'),
])
def testTrainStepInfoInvalidAllFinite(test_input):
def testTrainStepInfoInvalidInput(invalid_input):
'''Test invalid initialization of TrainStepInfo'''
optimizer_config = optim.LambConfig()
with pytest.raises(AssertionError):
orttrainer.TrainStepInfo(optimizer_config=invalid_input)

with pytest.raises(AssertionError):
orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input)

with pytest.raises(AssertionError):
orttrainer.TrainStepInfo(all_finite=test_input)
orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input)

with pytest.raises(AssertionError):
orttrainer.TrainStepInfo(step=test_input)
orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input)

with pytest.raises(AssertionError):
orttrainer.TrainStepInfo(optimizer_config=test_input)
orttrainer.TrainStepInfo(optimizer_config, step=invalid_input)


@pytest.mark.parametrize("optim_name,lr,alpha,default_alpha", [
@@ -503,7 +518,7 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
# First half is warmup
for optimization_step in range(total_steps):
# Emulate ORTTrainer.train_step() call that updates its train_step_info
train_step_info = TrainStepInfo(step=0, optimization_step=optimization_step, optimizer_config=optimizer_config)
train_step_info = TrainStepInfo(optimizer_config=optimizer_config, optimization_step=optimization_step)

lr_scheduler.step(train_step_info)
lr_list = lr_scheduler.get_last_lr()
@@ -651,10 +666,11 @@ def testORTDeterministicCompute(seed, device):
_test_helpers.assert_onnx_weights(first_trainer, second_trainer)


@pytest.mark.parametrize("seed,device,expected_loss", [
(321, 'cuda', [10.5774, 10.4403, 10.4175, 10.2886, 10.2760]),
@pytest.mark.parametrize("seed,device,expected_loss,fetches", [
(321, 'cuda', [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False),
(321, 'cuda', [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True),
])
def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss):
def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches):
total_steps = len(expected_loss)
torch.manual_seed(seed)
set_seed(seed)
Expand All @@ -675,9 +691,21 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss):
actual_loss = []
for i in range(total_steps):
data, targets = batcher_fn(train_data, i)
loss, preds = trainer.train_step(data, targets)
if fetches:
trainer._train_step_info.fetches=['loss']
loss = trainer.train_step(data, targets)
else:
loss, _ = trainer.train_step(data, targets)
actual_loss.append(loss.cpu())

# Eval once just to test fetches in action
val_data, val_targets = batcher_fn(val_data, 0)
if fetches:
trainer._train_step_info.fetches=['loss']
loss = trainer.eval_step(val_data, val_targets)
trainer._train_step_info.fetches=[]
loss, preds = trainer.eval_step(val_data, val_targets)

# Compare loss to ground truth computed from current ORTTrainer API
_test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
assert trainer._onnx_model is not None

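Taken together, the updated test above exercises the whole workflow: assign TrainStepInfo.fetches before a step to get only those outputs back, and reset it to an empty list to return to the default behavior. A hedged end-to-end sketch of that usage, where model, model_desc, batcher_fn and train_data are placeholders for a real setup and the onnxruntime.training import path is assumed:

from onnxruntime.training import optim, orttrainer

# model, model_desc, batcher_fn and train_data are placeholders for a real setup.
trainer = orttrainer.ORTTrainer(model, model_desc, optim.SGDConfig(lr=0.01))

for i in range(5):
    data, targets = batcher_fn(train_data, i)
    trainer._train_step_info.fetches = ['loss']   # request only the loss output
    loss = trainer.train_step(data, targets)

trainer._train_step_info.fetches = []             # back to the default: all outputs
loss, preds = trainer.eval_step(data, targets)

Note that fetches persists across steps until it is cleared, which is why the test resets it to an empty list before its second eval_step call.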