Add Mixed precision/LossScaler + several fixes (#4739)
In addition to the mixed precision/loss scaler code, this PR includes:

* Fix CUDA training
* Add optimization_step into TrainStepInfo class
* Refactor LRScheduler to use optimization_step instead of step
* Update several default values in ORTTrainerOptions
* Add initial Gradient Accumulation support (untested)
* Fix ONNX model post processing
* Refactor unit tests
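As a rough usage sketch (not part of this diff), the new LossScaler base class below can back a custom scaling strategy; the train_step_info.all_finite attribute and the onnxruntime.capi.training.amp import path are assumptions inferred from this PR rather than shown here:

from onnxruntime.capi.training.amp import LossScaler  # package path inferred from debug.py's imports below

class HalveOnOverflowScaler(LossScaler):
    # Toy strategy: halve the loss scale whenever a step produced non-finite gradients
    def __init__(self):
        super().__init__(loss_scale=float(1 << 16))

    def reset(self):
        self.loss_scale = float(1 << 16)

    def update(self, train_step_info):
        # 'all_finite' on TrainStepInfo is an assumption, not shown in this diff
        if not train_step_info.all_finite:
            self.loss_scale = max(1.0, self.loss_scale / 2)
        return self.loss_scale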
Thiago Crepaldi committed Aug 14, 2020
1 parent 5ed1384 commit ab0249a
Showing 10 changed files with 720 additions and 292 deletions.
24 changes: 24 additions & 0 deletions orttraining/orttraining/python/training/_utils.py
@@ -6,20 +6,42 @@


def get_device_index(device):
'''Returns device index from a device'''

if type(device) == str:
# Could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
device = torch.device(device)
return 0 if device.index is None else device.index


def get_device_index_from_input(input):
'''Returns device index from an input PyTorch Tensor'''

if isinstance(input, (list, tuple)):
device_index = get_device_index(input[0].device)
else:
device_index = get_device_index(input.device)
return device_index


def get_all_gradients_finite_name_from_session(session):
'''Find all_gradients_finite node on Session graph and return its name'''

nodes = [x for x in session._outputs_meta if 'all_gradients_finite' in x.name]
if len(nodes) != 1:
raise RuntimeError("'all_gradients_finite' node not found within training session")
return nodes[0].name


def get_gradient_accumulation_name_from_session(session):
'''Find Group_Accumulated_Gradients node on Session graph and return its name'''

nodes = [x for x in session._outputs_meta if 'Group_Accumulated_Gradients' in x.name]
if len(nodes) != 1:
raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session")
return nodes[0].name
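Both lookups above only inspect session._outputs_meta, so the matching logic can be illustrated with a stand-in session object (FakeMeta and the output names are test scaffolding, not ORT API):

from types import SimpleNamespace

class FakeMeta:
    def __init__(self, name):
        self.name = name

fake_session = SimpleNamespace(_outputs_meta=[FakeMeta('loss'),
                                              FakeMeta('all_gradients_finite'),
                                              FakeMeta('Group_Accumulated_Gradients_Out')])

assert get_all_gradients_finite_name_from_session(fake_session) == 'all_gradients_finite'
assert get_gradient_accumulation_name_from_session(fake_session) == 'Group_Accumulated_Gradients_Out'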


def dtype_torch_to_numpy(torch_dtype):
'''Converts PyTorch types to Numpy types
@@ -140,6 +162,8 @@ def decorate(func):


def import_module_from_file(file_path, module_name):
'''Import a Python module from a file into interpreter'''

assert isinstance(file_path, str) and os.path.exists(file_path),\
"'file_path' must be a full path string with the python file to load"
assert isinstance(module_name, str) and module_name,\
2 changes: 1 addition & 1 deletion orttraining/orttraining/python/training/amp/__init__.py
@@ -1 +1 @@
from . import loss_scaler
from .loss_scaler import LossScaler, DynamicLossScaler
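With this re-export, both classes become importable directly from the amp package; a one-line sketch (package path inferred from debug.py's imports below):

from onnxruntime.capi.training.amp import LossScaler, DynamicLossScaler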
31 changes: 27 additions & 4 deletions orttraining/orttraining/python/training/amp/loss_scaler.py
@@ -7,8 +7,28 @@ class LossScaler(object):
This class should never be instantiated, but used as an abstract class for custom loss scaling strategies.
"""

def __init__(self):
pass
def __init__(self, loss_scale):
self._input_name = None
self._loss_scale = loss_scale

@property
def input_name(self):
return self._input_name

@input_name.setter
def input_name(self, input_name):
assert input_name is None or isinstance(input_name, str), "'input_name' must be a string or None"
assert input_name is None or len(input_name) > 0, "'input_name' cannot be empty"
self._input_name = input_name

@property
def loss_scale(self):
return self._loss_scale

@loss_scale.setter
def loss_scale(self, loss_scale):
assert isinstance(loss_scale, float) and loss_scale > 0, "'loss_scale' must be a positive float"
self._loss_scale = loss_scale

def reset(self):
r"""Resets loss scaler internal state"""
@@ -19,6 +39,9 @@ def update(self, train_step_info):
Args:
train_step_info (TrainStepInfo): last step state information
Returns:
Updated loss scale (float)
"""
raise NotImplementedError

@@ -65,9 +88,8 @@ def __init__(self, automatic_update=True,
up_scale_window=2000,
min_loss_scale=1.0,
max_loss_scale=float(1 << 24)):
super().__init__()
super().__init__(loss_scale)
self.automatic_update = automatic_update
self.loss_scale = loss_scale
self.up_scale_window = up_scale_window
self.min_loss_scale = min_loss_scale
self.max_loss_scale = max_loss_scale
@@ -89,3 +111,4 @@ def update(self, train_step_info):
else:
self.loss_scale = max(self.min_loss_scale, self.loss_scale / 2)
self._stable_steps_count = 0
return self.loss_scale
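A minimal sketch of the DynamicLossScaler pieces visible in this hunk; passing loss_scale as a keyword is an assumption implied by the super().__init__(loss_scale) call, and only the overflow branch shown above is exercised conceptually:

scaler = DynamicLossScaler(loss_scale=float(1 << 16), min_loss_scale=1.0)
print(scaler.loss_scale)                 # 65536.0, via the new LossScaler property

# The overflow branch above clamps the halved scale at min_loss_scale:
#     self.loss_scale = max(self.min_loss_scale, self.loss_scale / 2)

try:
    scaler.loss_scale = -1.0             # rejected by the new validating setter
except AssertionError as err:
    print(err)                           # "'loss_scale' must be a positive float"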
73 changes: 53 additions & 20 deletions orttraining/orttraining/python/training/debug.py
@@ -1,4 +1,3 @@

import numpy as np
import os
import sys
@@ -8,44 +7,78 @@
from onnxruntime.capi.training import orttrainer
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer

def compare_onnx_weights(model_a, model_b, verbose=False, rtol=1e-4):
r"""Compare whether weights between 'model_a' and 'model_b' ONNX models are within
a certain tolerance 'rtol'

Compares the weights of two different ONNX models and throws an error when they diverge
def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether output_a and output_b difference is within specified tolerance
Args:
output_a, output_b (list): Two list with of numeric values
verbose (bool, default is False): if True, prints absolute difference for each weight
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""
assert isinstance(output_a, list) and isinstance(output_b, list),\
"output_a and output_b must be a list of numbers"
assert len(output_a) == len(output_b), "output_a and output_b must have the same length"

assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output mismatch")
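A self-contained check of assert_model_outputs above (toy values, not from the commit):

assert_model_outputs([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])        # passes: within rtol=1e-7
# assert_model_outputs([1.0, 2.0, 3.0], [1.0, 2.0, 3.1])      # would raise AssertionError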

def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
Compares the weights of two different ONNX models (model_a and model_b)
and raises AssertError when they diverge by more than atol or rtol
Args:
model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure
verbose (bool, default is False): Indicates if the max absolute difference for each layer should be
calculated and printed for debug information.
rtol (float, default is 1e-4): Tolerance for divergence.
verbose (bool, default is False): if True, prints absolute difference for each weight
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 0): Max absolute difference
"""
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer)
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state()
assert len(state_dict_a.items()) == len(state_dict_b.items())
_compare_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol)
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)


def compare_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-4):
r"""Compare whether weights between 'model_a' (legacy API ONNX model) and 'model_b' (new API ONNX model)
are within a certain tolerance 'rtol'
def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
Compares the weights of a legacy model model_a and experimental model_b model
and raises AssertError when they diverge by more than atol or rtol.
Compares the weights of two different ONNX models and throws an error when they diverge
Args:
model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure
verbose (bool, default is False): Indicates if the max absolute difference for each layer should be
calculated and printed for debug information.
rtol (float, default is 1e-4): Tolerance for divergence.
model_a (ORTTrainer): Instance of experimental ORTTrainer
model_b (Legacy_ORTTrainer): Instance of legacy ORTTrainer
verbose (bool, default is False): if True, prints absolute difference for each weight.
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 0): Max absolute difference
"""
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer)
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state()
assert len(state_dict_a.items()) == len(state_dict_b.items())
_compare_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol)
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)


def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol):
r"""Asserts whether dicts a and b value differences are within specified tolerance
Compares the weights of two model's state_dict dicts and raises AssertError
when they diverge by more than atol or rtol
Args:
model_a (ORTTrainer): Instance of legacy ORTTrainer
model_b (ORTTrainer): Instance of experimental ORTTrainer
verbose (bool, default is False): if True, prints absolute difference for each weight.
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""

def _compare_state_dict_weights(state_dict_a, state_dict_b, verbose=False, rtol=1e-4):
for (a_name, a_val), (b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()):
np_a_vals = np.array(a_val).flatten()
np_b_vals = np.array(b_val).flatten()
assert np_a_vals.shape == np_b_vals.shape
if verbose:
print(f'Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}')
assert_allclose(a_val, b_val, rtol=rtol, err_msg=f"Weight mismatch for {a_name}")
assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}")
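_assert_state_dict_weights only needs two dicts with matching keys and array-like values, so it can be exercised without a training session (toy data, not from the commit):

state_a = {'fc.weight': [[1.0, 2.0], [3.0, 4.0]], 'fc.bias': [0.1, 0.2]}
state_b = {'fc.weight': [[1.0, 2.0], [3.0, 4.0]], 'fc.bias': [0.1, 0.2]}

# Prints the max absolute difference per weight, then passes the assert_allclose checks
_assert_state_dict_weights(state_a, state_b, verbose=True, rtol=1e-7, atol=0)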
