Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Debuggability to Check Model Weights #4716

31 changes: 31 additions & 0 deletions orttraining/orttraining/python/training/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

import numpy as np
import os
import sys
import torch

from numpy.testing import assert_allclose
from onnxruntime.capi.training import orttrainer

def compare_onnx_weights(model_a, model_b, verbose=False, rtol=1e-4):
    r"""Compare whether weights between 'model_a' and 'model_b' ONNX models are within
    a certain tolerance 'rtol'.

    Compares the weights of two different ONNX models and throws an error when they diverge.

    Args:
        model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure
        verbose (bool, default is False): Indicates if the max absolute difference for each layer should be
            calculated and printed for debug information.
        rtol (float, default is 1e-4): Tolerance for divergence.

    Raises:
        AssertionError: If the two state dicts differ in length, weight names,
            weight shapes, or if any pair of weights diverges beyond 'rtol'.
    """
    assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer)
    state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state()
    assert len(state_dict_a.items()) == len(state_dict_b.items())
    for (a_name, a_val), (b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()):
        # The two state dicts are compared positionally; guard against silently
        # comparing weights from different layers if the orderings ever diverge.
        assert a_name == b_name, f"Weight name mismatch: {a_name} vs {b_name}"
        np_a_vals = np.array(a_val).flatten()
        np_b_vals = np.array(b_val).flatten()
        assert np_a_vals.shape == np_b_vals.shape
        if verbose:
            print(f'Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}')
        assert_allclose(a_val, b_val, rtol=rtol, err_msg=f"Weight mismatch for {a_name}")

4 changes: 2 additions & 2 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
if 'cuda' in device_id.lower():
set_cuda_mem_limit(int(self.options.device.mem_limit))
if ':' in device_id:
set_cuda_device_id(device_id.split(':')[1])
set_cuda_device_id(int(device_id.split(':')[1]))

self._train_step_info = TrainStepInfo(all_finite=True, step=0, optimizer_config=self.optim_config)

Expand Down Expand Up @@ -345,7 +345,7 @@ def forward(self, *inputs):
return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names)

def _convert_torch_model_loss_fn_to_onnx(self, inputs):
device = torch.device(self.options.device.id)
device = torch.device('cpu') #torch.device(self.options.device.id)
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
if isinstance(inputs, dict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from onnxruntime.capi.training import orttrainer_options as orttrainer_options
from onnxruntime.capi.training import model_desc_validation as md_val
from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils
from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils, debug
from onnxruntime.capi._pybind_state import set_seed


@pytest.mark.parametrize("test_input", [
Expand Down Expand Up @@ -456,20 +457,7 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
assert_allclose(lr_list[0],
expected_values[step], rtol=rtol, err_msg="lr mismatch")


@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [
('train_step', None, None),
('eval_step', None, None),
('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
])
def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
def generate_pytorch_transformer_model_sample(optim_config, options={}, step_fn='train_step', device='cpu'):
# Loading external TransformerModel model for testing
# A manual import is done as this example is not part of onnxruntime package,
# but resides on the onnxruntime repo
Expand All @@ -489,6 +477,36 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
my_loss = ort_utils.my_loss
model_desc = ort_utils.transformer_model_description()

# Set up relevant options
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

# Preparing data
train_data, val_data, _ = utils.prepare_data(device, 20, 20)

if step_fn == 'eval_step':
data, targets = utils.get_batch(val_data, 0)
elif step_fn == 'train_step':
data, targets = utils.get_batch(train_data, 0)
else:
raise ValueError('Invalid step_fn')

data, targets = data.to(trainer.options.device.id), targets.to(trainer.options.device.id)

return model, model_desc, trainer, data, targets

@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [
('train_step', None, None),
('eval_step', None, None),
('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
])
def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
max_train_step = 1
warmup = 0.5
initial_lr = 1
Expand All @@ -502,20 +520,17 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
opts.update({'lr_scheduler' : lr_scheduler(max_train_step, warmup)})

opts = orttrainer.ORTTrainerOptions(opts)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

# Preparing data
train_data, val_data, _ = utils.prepare_data('cpu', 20, 20)
# Using PyTorch Transformer model as example
model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config, opts, step_fn)
thiagocrepaldi marked this conversation as resolved.
Show resolved Hide resolved

# Export model to ONNX
if step_fn == 'eval_step':
step_fn = trainer.eval_step
data, targets = utils.get_batch(val_data, 0)
output = trainer.eval_step(data, targets)
elif step_fn == 'train_step':
step_fn = trainer.train_step
for i in range(max_train_step):
data, targets = utils.get_batch(train_data, 0)
output = trainer.train_step(data, targets)
if lr_scheduler:
lr_list = trainer.options.lr_scheduler.get_last_lr()
Expand Down Expand Up @@ -573,3 +588,43 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph)
assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))


@pytest.mark.parametrize("seed, device_id", [
    (0, 'cpu'),
    (42, 'cpu'),
    (0, 'cuda:0'),
    (24, 'cuda')
])
def testORTDeterministicCompute(seed, device_id):
    """Verify that two identically-seeded ORTTrainer runs produce matching weights."""

    def _reseed():
        # Reset both PyTorch's and ONNX Runtime's RNGs to the same starting state.
        torch.manual_seed(seed)
        set_seed(seed)

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device' : {
            'id' : device_id,
            'mem_limit' : 10*1024*1024
        }
    })

    # First trainer: seed, build from the sample PyTorch Transformer model, run one step.
    _reseed()
    model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config, opts, device=device_id)
    output = trainer.train_step(data, targets)
    assert trainer._onnx_model is not None

    # Second trainer: identical seeding and construction, stepped on the same batch.
    _reseed()
    _, _, second_trainer, _, _ = generate_pytorch_transformer_model_sample(optim_config, opts, device=device_id)
    output = second_trainer.train_step(data, targets)
    assert second_trainer._onnx_model is not None
    # Sanity check: the two trainers hold distinct model objects, not a shared reference.
    assert id(trainer._onnx_model) != id(second_trainer._onnx_model)

    # With deterministic compute enabled, the resulting weights must agree.
    debug.compare_onnx_weights(trainer, second_trainer)