Add Loss Scaler classes to the new frontend #4306

Merged

2 changes: 2 additions & 0 deletions orttraining/orttraining/python/training/__init__.py
@@ -5,5 +5,7 @@
from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession


from .orttrainer_options import ORTTrainerOptions
from . import model_desc_validation
from . import amp
1 change: 1 addition & 0 deletions orttraining/orttraining/python/training/amp/__init__.py
@@ -0,0 +1 @@
from . import loss_scaler
91 changes: 90 additions & 1 deletion orttraining/orttraining/python/training/amp/loss_scaler.py
@@ -1,2 +1,91 @@
class LossScaler(object):
    r"""Base class for implementing custom loss scaler strategies

    Once the scaler is configured, no user intervention is needed to update the loss scale during training.

    Note:
        This class should never be instantiated directly; it is meant to be used as an abstract base
        class for custom loss scaling strategies.
    """

    def __init__(self):
        pass

    def reset(self):
        r"""Resets loss scaler internal state"""
        raise NotImplementedError

    def update(self, train_step_info):
        r"""Updates the loss scale based on user input and training session info

        Args:
            train_step_info (TrainStepInfo): last step state information
        """
        raise NotImplementedError


class DynamicLossScaler(LossScaler):
    r"""Default implementation of :py:class:`.LossScaler`, used for mixed precision training

    This loss scaler starts from an initial scale, which is doubled every time a certain number of
    (stable) training steps is performed without exploding gradients (gradients that overflow or
    reach infinity). When at least one gradient explodes, the loss scale is halved.

    Users can use this class in two ways:

    1. Enable mixed precision without setting a loss scaler class; default values are used
    2. Enable mixed precision and instantiate this class to override the default arguments

    Static loss scaling can be achieved by setting :py:attr:`.automatic_update` to :py:obj:`False`
    and not performing a manual :py:meth:`update` in the train loop.

    Args:
        automatic_update (bool, default is True): boolean switch that allows :py:meth:`ORTTrainer.train_step`
            to perform loss scaling automatically. If False, the user must call :py:meth:`.update` explicitly;
            otherwise static loss scaling results.
        loss_scale (float, default is 1 << 16): current loss scale
        up_scale_window (int, default is 2000): number of stable train steps before doubling the loss scale
        min_loss_scale (float, default is 1): minimum value for the loss scale, used when the loss scale is decreased
        max_loss_scale (float, default is 1 << 24): maximum value for the loss scale, used when the loss scale is increased

    Example with default values:
        .. code-block:: python

            scaler1 = amp.DynamicLossScaler()
            print(f'Default loss scale is {scaler1.loss_scale}')

    Example with user specified values:
        .. code-block:: python

            scaler2 = amp.DynamicLossScaler(loss_scale=1<<8)
            print(f'Custom loss scale is {scaler2.loss_scale}')
    """

    def __init__(self, automatic_update=True,
                 loss_scale=float(1 << 16),
                 up_scale_window=2000,
                 min_loss_scale=1.0,
                 max_loss_scale=float(1 << 24)):
        super().__init__()
        self.automatic_update = automatic_update
        self.loss_scale = loss_scale
        self.up_scale_window = up_scale_window
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale

        self._initial_loss_scale = loss_scale
        self._stable_steps_count = 0

    def reset(self):
        self.loss_scale = self._initial_loss_scale
        self._stable_steps_count = 0

    def update(self, train_step_info):
        if train_step_info.all_finite:
            self._stable_steps_count += 1

            if self._stable_steps_count >= self.up_scale_window:
                self.loss_scale = min(self.max_loss_scale, self.loss_scale * 2)
                self._stable_steps_count = 0
        else:
            self.loss_scale = max(self.min_loss_scale, self.loss_scale / 2)
            self._stable_steps_count = 0
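
The docstring above spells out the doubling/halving policy; the short sketch below drives DynamicLossScaler.update by hand to show the resulting scale trajectory. It is illustrative only: it assumes the import path used by the test file in this PR (onnxruntime.capi.training) and picks a small up_scale_window so the doubling is visible after a few steps.

# Illustrative sketch, not part of this PR. Assumes the package layout used by
# the tests below (onnxruntime.capi.training); adjust the imports if it differs.
from onnxruntime.capi.training import amp, orttrainer

scaler = amp.loss_scaler.DynamicLossScaler(up_scale_window=4)

# Four stable steps (all gradients finite) double the scale: 2**16 -> 2**17
stable_step = orttrainer.TrainStepInfo(all_finite=True, epoch=0, step=0)
for _ in range(4):
    scaler.update(stable_step)
print(scaler.loss_scale)  # 131072.0

# A single step with exploding gradients halves it again: 2**17 -> 2**16
overflow_step = orttrainer.TrainStepInfo(all_finite=False, epoch=0, step=0)
scaler.update(overflow_step)
print(scaler.loss_scale)  # 65536.0
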
@@ -1,8 +1,10 @@
import pytest
import torch
from numpy.testing import assert_allclose

from onnxruntime.capi.training import orttrainer_options as orttrainer_options
from onnxruntime.capi.training import model_desc_validation as md_val
from onnxruntime.capi.training import orttrainer, amp


@pytest.mark.parametrize("test_input", [
@@ -97,3 +99,90 @@ def testORTTrainerModelDescInvalidSchemas(test_input, error_msg):
    with pytest.raises(ValueError) as e:
        md_val._ORTTrainerModelDesc(test_input)
    assert str(e.value) == error_msg


def testDynamicLossScaler():
    rtol = 1e-5
    default_scaler = amp.loss_scaler.DynamicLossScaler()

    # Initial state
    train_step_info = orttrainer.TrainStepInfo(
        all_finite=True, epoch=0, step=0)
    assert_allclose(default_scaler.loss_scale, float(1 << 16),
                    rtol=rtol, err_msg="loss scale mismatch")
    assert default_scaler.up_scale_window == 2000
    assert_allclose(default_scaler.min_loss_scale, 1.0,
                    rtol=rtol, err_msg="min loss scale mismatch")
    assert_allclose(default_scaler.max_loss_scale, float(
        1 << 24), rtol=rtol, err_msg="max loss scale mismatch")

    # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True)
    loss_scale = float(1 << 16)
    for cycles in range(1, 10):

        # 1999 updates without overflow produce 1999 stable steps
        for i in range(1, 2000):
            default_scaler.update(train_step_info)
            assert default_scaler._stable_steps_count == i
            assert_allclose(default_scaler.loss_scale, loss_scale,
                            rtol=rtol, err_msg=f"loss scale mismatch at update {i}")

        # The 2000th update without overflow doubles the loss scale and zeroes the stable steps until max_loss_scale is reached
        default_scaler.update(train_step_info)
        if cycles <= 8:
            loss_scale *= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(default_scaler.loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 8 cycles, loss scale should be float(1 << 16)*(2**8)
    assert_allclose(default_scaler.loss_scale, float(1 << 16)
                    * (2**8), rtol=rtol, err_msg="loss scale mismatch")

    # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on
    loss_scale = float(1 << 16)*(2**8)
    for count in range(1, 2050):
        default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == (count % 2000)
        assert_allclose(default_scaler.loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # Setting train_step_info.all_finite = False to test down scaling
    train_step_info.all_finite = False

    # Performing 24 updates to halve the loss scale each time
    loss_scale = float(1 << 16)*(2**8)
    for count in range(1, 25):
        default_scaler.update(train_step_info)
        loss_scale /= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(default_scaler.loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 24 updates with gradient overflow, loss scale is 1.0
    assert_allclose(default_scaler.loss_scale, 1.,
                    rtol=rtol, err_msg="loss scale mismatch")

    # After 25 updates, min_loss_scale is reached and the loss scale is not halved from that point on
    for count in range(1, 5):
        default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == 0
        assert_allclose(default_scaler.loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")


def testDynamicLossScalerCustomValues():
    rtol = 1e-5
    scaler = amp.loss_scaler.DynamicLossScaler(automatic_update=False,
                                               loss_scale=3,
                                               up_scale_window=7,
                                               min_loss_scale=5,
                                               max_loss_scale=10)
    assert scaler.automatic_update == False
    assert_allclose(scaler.loss_scale, 3, rtol=rtol,
                    err_msg="loss scale mismatch")
    assert_allclose(scaler.min_loss_scale, 5, rtol=rtol,
                    err_msg="min loss scale mismatch")
    assert_allclose(scaler.max_loss_scale, 10, rtol=rtol,
                    err_msg="max loss scale mismatch")
    assert scaler.up_scale_window == 7
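

The LossScaler base class above is intended to be subclassed for custom strategies. The sketch below is a hypothetical example of such a subclass, not part of this PR: a static scaler whose scale never changes, filling in the two extension points (reset and update) under the same import-path assumption as the tests.

# Hypothetical example (not part of this PR): a custom strategy built on the
# LossScaler base class that keeps a fixed scale for every train step.
from onnxruntime.capi.training import amp


class StaticLossScaler(amp.loss_scaler.LossScaler):
    def __init__(self, loss_scale=float(1 << 10)):
        super().__init__()
        self.loss_scale = loss_scale
        self._initial_loss_scale = loss_scale

    def reset(self):
        # Restore the scale chosen at construction time
        self.loss_scale = self._initial_loss_scale

    def update(self, train_step_info):
        # Static strategy: ignore the step information and never adjust the scale
        pass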