Create training session + minor improvements (#4668)
Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Thiago Crepaldi and Rayan Krishnan committed Aug 12, 2020
1 parent a1ea7dc commit 5f2a140
Showing 6 changed files with 154 additions and 23 deletions.
@@ -40,6 +40,9 @@ def __init__(self, model_desc):
else:
self._validated['outputs'][idx] = self._OutputDescription(*output)

# Hard-code learning rate descriptor for the model
self._validated['learning_rate'] = self._InputDescriptionTyped('Learning_Rate', [1], torch.float32)

# Convert dict in object
for k, v in self._validated.items():
setattr(self, k, self._wrap(v))
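For context, a minimal sketch of what the hard-coded descriptor looks like once validated, assuming `_InputDescriptionTyped` behaves like a namedtuple with `name`, `shape`, and `dtype` fields (the field names follow the test assertions further down; the real class lives in the model description validation module):

```python
from collections import namedtuple
import torch

# Hypothetical stand-in for the validated descriptor type used above
_InputDescriptionTyped = namedtuple('_InputDescriptionTyped', ['name', 'shape', 'dtype'])

lr_desc = _InputDescriptionTyped('Learning_Rate', [1], torch.float32)
assert lr_desc.name == 'Learning_Rate'
assert lr_desc.shape == [1]
assert lr_desc.dtype == torch.float32
```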
102 changes: 89 additions & 13 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -3,10 +3,13 @@
import torch
from inspect import signature

import onnxruntime as ort
from . import ORTTrainerOptions
from . import optim
from .model_desc_validation import _ORTTrainerModelDesc
from .. import postprocess
from onnxruntime.capi._pybind_state import set_cuda_mem_limit
from onnxruntime.capi._pybind_state import set_cuda_device_id


class TrainStepInfo(object):
@@ -146,6 +149,13 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
self.optim_config = optim_config
self.options = options

# Set GPU device and memory limit
device_id = self.options.device.id
if 'cuda' in device_id.lower():
set_cuda_mem_limit(int(self.options.device.mem_limit))
if ':' in device_id:
set_cuda_device_id(device_id.split(':')[1])

def eval_step(self, *input, **kwargs):
r"""Evaluation step method
@@ -162,7 +172,6 @@ def eval_step(self, *input, **kwargs):
self.model_desc.inputs, None, None, *input, **kwargs)
self._init_onnx_model(sample_input)


def save_as_onnx(self, path):
r"""Persists ONNX model into :py:attr:`path`
@@ -260,18 +269,16 @@ def forward(self, *inputs):

return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names)

def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
def _convert_torch_model_loss_fn_to_onnx(self, inputs):
device = torch.device(self.options.device.id)
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
if isinstance(inputs, dict):
sample_inputs = [inputs[k.name_].to(
device=device) for k in self.model_desc.inputs]
sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs]
elif isinstance(inputs, (list, tuple)):
sample_inputs = [input.to(device=device) for i, input in enumerate(
inputs) if i < len(self.model_desc.inputs)]
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)]
else:
raise RuntimeError(
"Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")

# PyTorch ONNX exporter does not match argument names
# This is an issue because the ONNX graph depends on all inputs to be specified
@@ -341,7 +348,7 @@ def _init_onnx_model(self, inputs):
self.options.utils.frozen_weights.extend(torch_buffers)

# Export to ONNX
self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(torch.device('cpu'), inputs)
self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs)

self._init_session()

@@ -356,18 +363,87 @@ def _init_session(self):
# Perform user-specified post-processing
if self.options._internal_use.extra_postprocess:
self.options._internal_use.extra_postprocess(self._onnx_model)

# Create training session used by train_step
self._create_ort_training_session()
return

def _prepare_model_input(self, inputs_desc, lr, loss_scale, *args, **kwargs):
def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):
# Normalize input to tuple of samples
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
input = tuple(args[0])
if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list:
input = tuple(inputs[0])
else:
input = args
input = inputs

# Append input from 'kwargs'
for input_desc in inputs_desc:
if input_desc[0] in kwargs:
input = input + (kwargs[input_desc[0]],)

return input

# TODO: Test this thoroughly along with train step, including
# various optimizer parameter groups, frozen weights, loss and lr
def _create_ort_training_session(self):
# Validating frozen_weights names
unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
if n not in [i.name for i in self._onnx_model.graph.initializer]]
if unused_frozen_weights:
raise RuntimeError("{} params from 'frozen_weights' not found in the ONNX model.".format(
unused_frozen_weights))

# Get loss name from model description
loss_name = [item.name for item in self.model_desc.outputs if len(item) == 4 and item[2]]
assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)"
loss_name = loss_name[0]

# Parse optimizer parameters
optimizer_attributes_map = {}
optimizer_int_attributes_map = {}
trainable_params = set()
for initializer in self._onnx_model.graph.initializer:
if initializer.name in self.options.utils.frozen_weights:
continue # only trainable parameters are passed to the backend
trainable_params.add(initializer.name)
optimizer_attributes_map[initializer.name] = {}
optimizer_int_attributes_map[initializer.name] = {}
for param_group in self.optim_config.params:
if initializer.name not in param_group['params']:
continue # keep looking for a matching param_group
for k, v in param_group.items():
if k == 'params':
continue # 'params' is not a hyper parameter, skip it
if isinstance(v, float):
optimizer_attributes_map[initializer.name][k] = v
elif isinstance(v, int):
optimizer_int_attributes_map[initializer.name][k] = v
else:
raise ValueError("Optimizer attributes must be either float or int.")

# TrainingParameters
ort_parameters = ort.TrainingParameters()
ort_parameters.loss_output_name = loss_name
ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
ort_parameters.world_rank = self.options.distributed.world_rank
ort_parameters.world_size = self.options.distributed.world_size
ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_stage
ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
ort_parameters.set_gradients_as_graph_outputs = False
ort_parameters.training_optimizer_name = self.optim_config.name
ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
ort_parameters.weights_to_train = trainable_params
ort_parameters.optimizer_attributes_map = optimizer_attributes_map
ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

# SessionOptions
session_options = ort.SessionOptions()
session_options.use_deterministic_compute = self.options.debug.deterministic_compute

# TrainingSession
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters, session_options)

# I/O bindings
self._train_io_binding = self._training_session.io_binding()
self._eval_io_binding = self._training_session.io_binding()
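To make the parameter-group handling in `_create_ort_training_session` easier to follow, here is a hedged, self-contained sketch of how the per-weight attribute maps are built; the initializer and group names below are hypothetical, while the real code walks `self._onnx_model.graph.initializer` and `self.optim_config.params`:

```python
# Hypothetical model weights and optimizer parameter groups
initializer_names = ['embedding.weight', 'decoder.weight', 'decoder.bias']
frozen_weights = ['embedding.weight']                      # excluded from training
param_groups = [{'params': ['decoder.bias'], 'lr': 0.0, 'alpha': 0.9}]

trainable_params = set()
optimizer_attributes_map = {}      # float hyperparameters, keyed by weight name
optimizer_int_attributes_map = {}  # int hyperparameters, keyed by weight name
for name in initializer_names:
    if name in frozen_weights:
        continue                   # frozen weights never reach the backend
    trainable_params.add(name)
    optimizer_attributes_map[name] = {}
    optimizer_int_attributes_map[name] = {}
    for group in param_groups:
        if name not in group['params']:
            continue               # keep looking for a matching group
        for k, v in group.items():
            if k == 'params':
                continue           # 'params' is not a hyperparameter
            if isinstance(v, float):
                optimizer_attributes_map[name][k] = v
            elif isinstance(v, int):
                optimizer_int_attributes_map[name][k] = v

# 'decoder.bias' ends up with {'lr': 0.0, 'alpha': 0.9}; 'decoder.weight' keeps {}
```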
34 changes: 29 additions & 5 deletions orttraining/orttraining/python/training/orttrainer_options.py
@@ -43,8 +43,7 @@ class ORTTrainerOptions(object):
'schema' : {
'id' : {
'type' : 'string',
'nullable' : True,
'default' : None
'default' : 'cpu'
},
'mem_limit' : {
'type' : 'integer',
@@ -125,6 +124,17 @@ class ORTTrainerOptions(object):
}
}
},
'debug' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'deterministic_compute' : {
'type' : 'boolean',
'default' : False
},
}
},
'_internal_use' : {
'type' : 'dict',
'required': False,
@@ -156,7 +166,7 @@ class ORTTrainerOptions(object):
number of steps to accumulate before do collective gradient reduction
device (dict):
compute device related settings
device.id (string, default is None):
device.id (string, default is 'cpu'):
device to run training
device.mem_limit (int):
maximum memory size (in bytes) used by device.id
@@ -192,6 +202,10 @@ class ORTTrainerOptions(object):
list of model parameter names to skip training (weights don't change)
utils.grad_norm_clip (bool, default is False):
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
debug (dict):
debug options
debug.deterministic_compute (bool, default is False):
forces compute to be deterministic across runs
_internal_use (dict):
internal options, possibly undocumented, that might be removed without notice
_internal_use.enable_internal_postprocess (bool, default is True):
@@ -325,8 +339,7 @@ def _check_is_callable(field, value, error):
'schema': {
'id': {
'type': 'string',
'nullable': True,
'default': None
'default': 'cpu'
},
'mem_limit': {
'type': 'integer',
@@ -408,6 +421,17 @@ def _check_is_callable(field, value, error):
}
}
},
'debug': {
'type': 'dict',
'default_setter': lambda _: {},
'required': False,
'schema': {
'deterministic_compute': {
'type': 'boolean',
'default': False
},
}
},
'_internal_use': {
'type': 'dict',
'default_setter': lambda _: {},
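Putting the new schema entries together, a hedged usage sketch of the options added in this file; the attribute access mirrors how `orttrainer.py` reads them above, and the `onnxruntime.training` import path is assumed:

```python
from onnxruntime.training import ORTTrainerOptions  # assumed public import path

opts = ORTTrainerOptions({
    'device': {
        'id': 'cuda:0',            # default is now 'cpu' instead of None
        'mem_limit': 2 * 1024**3,  # bytes, forwarded to set_cuda_mem_limit
    },
    'debug': {
        'deterministic_compute': True,  # new flag, maps to SessionOptions.use_deterministic_compute
    },
})

assert opts.device.id == 'cuda:0'
assert opts.debug.deterministic_compute is True
```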
@@ -28,7 +28,7 @@ def testORTTrainerOptionsDefaultValues(test_input):
'gradient_accumulation_steps': 0
},
'device': {
'id': None,
'id': 'cpu',
'mem_limit': 0
},
'distributed': {
@@ -48,6 +48,9 @@
'frozen_weights': [],
'grad_norm_clip': False
},
'debug': {
'deterministic_compute': False
},
'_internal_use': {
'enable_internal_postprocess': True,
'extra_postprocess': None,
@@ -81,8 +84,14 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema():
def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype):
r''' Test different ways of using default values for incomplete input'''

# Validating model description from user
model_description = md_val._ORTTrainerModelDesc(input_dict)

# Validating hard-coded learning rate description
assert model_description.learning_rate.name == "Learning_Rate"
assert model_description.learning_rate.shape == [1]
assert model_description.learning_rate.dtype == torch.float32

# Validating model description from user
for idx, i_desc in enumerate(model_description.inputs):
assert isinstance(i_desc, model_description._InputDescription)
assert len(i_desc) == 2
@@ -514,3 +523,4 @@ def testInstantiateORTTrainer(step_fn):
assert output_dim[dim_idx] == dim.dim_value
assert output_type == _utils.dtype_onnx_to_torch(
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

21 changes: 21 additions & 0 deletions samples/python/pytorch_transformer/README.md
@@ -0,0 +1,21 @@
# TransformerModel example

This example was adapted from PyTorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial.

## Requirements

* PyTorch 1.6+
* TorchText 0.6+
* ONNX Runtime 1.5+

## Running PyTorch version

```bash
python pt_model.py
```

## Running ONNX Runtime version

```bash
python ort_model.py
```
3 changes: 0 additions & 3 deletions samples/python/pytorch_transformer/pt_model.py
@@ -1,6 +1,3 @@
'''
From https://pytorch.org/tutorials/beginner/transformer_tutorial.html
'''
import math
import torch
import torch.nn as nn

