Create training session + minor improvements #4668

Merged
@@ -40,6 +40,9 @@ def __init__(self, model_desc):
else:
self._validated['outputs'][idx] = self._OutputDescription(*output)

# Hard-code learning rate descriptor for the model
self._validated['learning_rate'] = self._InputDescriptionTyped('Learning_Rate', [1], torch.float32)

# Convert dict entries into object attributes
for k, v in self._validated.items():
setattr(self, k, self._wrap(v))
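To make the new descriptor concrete, here is a minimal sketch of how it surfaces to callers (the import path and the exact dict format are assumptions, mirroring the test changes further down):

```python
import torch
from onnxruntime.training import model_desc_validation as md_val  # assumed import path

# A deliberately tiny model description; names and shapes are illustrative only
desc = md_val._ORTTrainerModelDesc({
    'inputs': [('x', [2, 4])],
    'outputs': [('loss', [2, 1], True)]  # third element flags the loss output
})

# The learning-rate descriptor is now always present, even though the user never declared it
assert desc.learning_rate.name == 'Learning_Rate'
assert desc.learning_rate.shape == [1]
assert desc.learning_rate.dtype == torch.float32
```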
102 changes: 89 additions & 13 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -3,10 +3,13 @@
import torch
from inspect import signature

import onnxruntime as ort
from . import ORTTrainerOptions
from . import optim
from .model_desc_validation import _ORTTrainerModelDesc
from .. import postprocess
from onnxruntime.capi._pybind_state import set_cuda_mem_limit
from onnxruntime.capi._pybind_state import set_cuda_device_id


class TrainStepInfo(object):
@@ -146,6 +149,13 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
self.optim_config = optim_config
self.options = options

# Set GPU device and memory limit
device_id = self.options.device.id
if 'cuda' in device_id.lower():
set_cuda_mem_limit(int(self.options.device.mem_limit))
if ':' in device_id:
set_cuda_device_id(device_id.split(':')[1])

def eval_step(self, *input, **kwargs):
r"""Evaluation step method
@@ -162,7 +172,6 @@ def eval_step(self, *input, **kwargs):
self.model_desc.inputs, None, None, *input, **kwargs)
self._init_onnx_model(sample_input)


def save_as_onnx(self, path):
r"""Persists ONNX model into :py:attr:`path`
@@ -260,18 +269,16 @@ def forward(self, *inputs):

return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names)

def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
def _convert_torch_model_loss_fn_to_onnx(self, inputs):
device = torch.device(self.options.device.id)
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
if isinstance(inputs, dict):
sample_inputs = [inputs[k.name_].to(
device=device) for k in self.model_desc.inputs]
sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs]
elif isinstance(inputs, (list, tuple)):
sample_inputs = [input.to(device=device) for i, input in enumerate(
inputs) if i < len(self.model_desc.inputs)]
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)]
else:
raise RuntimeError(
"Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")

# PyTorch ONNX exporter does not match argument names
# This is an issue because the ONNX graph depends on all inputs to be specified
@@ -341,7 +348,7 @@ def _init_onnx_model(self, inputs):
self.options.utils.frozen_weights.extend(torch_buffers)

# Export to ONNX
self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(torch.device('cpu'), inputs)
self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs)

self._init_session()

@@ -356,18 +363,87 @@ def _init_session(self):
# Perform user-specified post-processing
if self.options._internal_use.extra_postprocess:
self.options._internal_use.extra_postprocess(self._onnx_model)

# Create training session used by train_step
self._create_ort_training_session()
return

def _prepare_model_input(self, inputs_desc, lr, loss_scale, *args, **kwargs):
def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):
# Normalize input to tuple of samples
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
input = tuple(args[0])
if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list:
input = tuple(inputs[0])
else:
input = args
input = inputs

# Append input from 'kwargs'
for input_desc in inputs_desc:
if input_desc[0] in kwargs:
input = input + (kwargs[input_desc[0]],)

return input

# TODO: Test this thoroughly along with train step, including
# various optimizer parameter groups, frozen weights, loss and lr
def _create_ort_training_session(self):
# Validating frozen_weights names
unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
if n not in [i.name for i in self._onnx_model.graph.initializer]]
if unused_frozen_weights:
raise RuntimeError("{} params from 'frozen_weights' not found in the ONNX model.".format(
unused_frozen_weights))

# Get loss name from model description
loss_name = [item.name for item in self.model_desc.outputs if len(item) == 4 and item[2]]
assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)"
loss_name = loss_name[0]

# Parse optimizer parameters
optimizer_attributes_map = {}
optimizer_int_attributes_map = {}
trainable_params = set()
for initializer in self._onnx_model.graph.initializer:
if initializer.name in self.options.utils.frozen_weights:
continue # only trainable parameters are passed to the backend
trainable_params.add(initializer.name)
optimizer_attributes_map[initializer.name] = {}
optimizer_int_attributes_map[initializer.name] = {}
for param_group in self.optim_config.params:
if initializer.name not in param_group['params']:
continue # keep looking for a matching param_group
for k, v in param_group.items():
if k == 'params':
continue # 'params' is not a hyper parameter, skip it
if isinstance(v, float):
optimizer_attributes_map[initializer.name][k] = v
elif isinstance(v, int):
optimizer_int_attributes_map[initializer.name][k] = v
else:
raise ValueError("Optimizer attributes must be either float or int.")

# TrainingParameters
ort_parameters = ort.TrainingParameters()
ort_parameters.loss_output_name = loss_name
ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
ort_parameters.world_rank = self.options.distributed.world_rank
ort_parameters.world_size = self.options.distributed.world_size
ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_stage
ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
ort_parameters.set_gradients_as_graph_outputs = False
ort_parameters.training_optimizer_name = self.optim_config.name
ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
ort_parameters.weights_to_train = trainable_params
ort_parameters.optimizer_attributes_map = optimizer_attributes_map
ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

# SessionOptions
session_options = ort.SessionOptions()
session_options.use_deterministic_compute = self.options.debug.deterministic_compute

# TrainingSession
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters, session_options)

# I/O bindings
self._train_io_binding = self._training_session.io_binding()
self._eval_io_binding = self._training_session.io_binding()
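Putting the new pieces together, a minimal end-to-end sketch might look like the following. The import path, the optimizer config class, and the model-description format are assumptions; a model that returns its own loss avoids the separate loss_fn wiring.

```python
import torch
from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim  # assumed import path

class ModelWithLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x, target):
        # The model returns its loss directly, so no separate loss_fn is needed
        return torch.nn.functional.mse_loss(self.linear(x), target)

model_desc = {
    'inputs': [('x', [4, 10]), ('target', [4, 1])],
    'outputs': [('loss', [], True)]  # flagged as the loss output
}

opts = ORTTrainerOptions({
    'device': {'id': 'cuda:0', 'mem_limit': 2 * 1024 ** 3},  # exercises set_cuda_device_id / set_cuda_mem_limit above
    'debug': {'deterministic_compute': True},
})

trainer = ORTTrainer(ModelWithLoss(), model_desc, optim.SGDConfig(lr=1e-3), options=opts)

# eval_step exports the model to ONNX and, with this PR, also builds the ORT TrainingSession
loss = trainer.eval_step(torch.randn(4, 10), torch.randn(4, 1))
```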
34 changes: 29 additions & 5 deletions orttraining/orttraining/python/training/orttrainer_options.py
@@ -43,8 +43,7 @@ class ORTTrainerOptions(object):
'schema' : {
'id' : {
'type' : 'string',
'nullable' : True,
'default' : None
'default' : 'cpu'
},
'mem_limit' : {
'type' : 'integer',
@@ -125,6 +124,17 @@ class ORTTrainerOptions(object):
}
}
},
'debug' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'deterministic_compute' : {
'type' : 'boolean',
'default' : False
},
}
},
'_internal_use' : {
'type' : 'dict',
'required': False,
@@ -156,7 +166,7 @@ class ORTTrainerOptions(object):
number of steps to accumulate before performing collective gradient reduction
device (dict):
compute device related settings
device.id (string, default is None):
device.id (string, default is 'cpu'):
device to run training
device.mem_limit (int):
maximum memory size (in bytes) used by device.id
@@ -192,6 +202,10 @@ class ORTTrainerOptions(object):
list of model parameter names to skip training (weights don't change)
utils.grad_norm_clip (bool, default is False):
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
debug (dict):
debug options
debug.deterministic_compute (bool, default is False):
forces compute to be deterministic across runs
_internal_use (dict):
internal options, possibly undocumented, that might be removed without notice
_internal_use.enable_internal_postprocess (bool, default is True):
@@ -325,8 +339,7 @@ def _check_is_callable(field, value, error):
'schema': {
'id': {
'type': 'string',
'nullable': True,
'default': None
'default': 'cpu'
},
'mem_limit': {
'type': 'integer',
@@ -408,6 +421,17 @@ def _check_is_callable(field, value, error):
}
}
},
'debug': {
'type': 'dict',
'default_setter': lambda _: {},
'required': False,
'schema': {
'deterministic_compute': {
'type': 'boolean',
'default': False
},
}
},
'_internal_use': {
'type': 'dict',
'default_setter': lambda _: {},
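For reference, the new 'cpu' default and the debug group can be exercised directly on ORTTrainerOptions, as the tests below do (a small sketch; the import path is an assumption):

```python
from onnxruntime.training import ORTTrainerOptions  # assumed import path

opts = ORTTrainerOptions({})                     # rely entirely on the schema defaults
assert opts.device.id == 'cpu'                   # was None before this change
assert opts.debug.deterministic_compute is False

opts = ORTTrainerOptions({'debug': {'deterministic_compute': True}})
assert opts.debug.deterministic_compute is True
```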
@@ -28,7 +28,7 @@ def testORTTrainerOptionsDefaultValues(test_input):
'gradient_accumulation_steps': 0
},
'device': {
'id': None,
'id': 'cpu',
'mem_limit': 0
},
'distributed': {
@@ -48,6 +48,9 @@ def testORTTrainerOptionsDefaultValues(test_input):
'frozen_weights': [],
'grad_norm_clip': False
},
'debug': {
'deterministic_compute': False
},
'_internal_use': {
'enable_internal_postprocess': True,
'extra_postprocess': None,
@@ -81,8 +84,14 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema():
def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype):
r''' Test different ways of using default values for incomplete input'''

# Validating model description from user
model_description = md_val._ORTTrainerModelDesc(input_dict)

# Validating hard-coded learning rate description
assert model_description.learning_rate.name == "Learning_Rate"
assert model_description.learning_rate.shape == [1]
assert model_description.learning_rate.dtype == torch.float32

# Validating model description from user
for idx, i_desc in enumerate(model_description.inputs):
assert isinstance(i_desc, model_description._InputDescription)
assert len(i_desc) == 2
@@ -514,3 +523,4 @@ def testInstantiateORTTrainer(step_fn):
assert output_dim[dim_idx] == dim.dim_value
assert output_type == _utils.dtype_onnx_to_torch(
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

21 changes: 21 additions & 0 deletions samples/python/pytorch_transformer/README.md
@@ -0,0 +1,21 @@
# TransformerModel example

This example was adapted from PyTorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial.

## Requirements

* PyTorch 1.6+
* TorchText 0.6+
* ONNX Runtime 1.5+

## Running PyTorch version

```bash
python pt_model.py
```

## Running ONNX Runtime version

```bash
python ort_model.py
```
3 changes: 0 additions & 3 deletions samples/python/pytorch_transformer/pt_model.py
@@ -1,6 +1,3 @@
'''
From https://pytorch.org/tutorials/beginner/transformer_tutorial.html
'''
import math
import torch
import torch.nn as nn