Skip to content

Commit

Permalink
Refactor name, shape and dtype for inputs/outputs on ONNX models
Browse files Browse the repository at this point in the history
  • Loading branch information
Thiago Crepaldi committed Jul 29, 2020
1 parent 1757ed5 commit 25b4adf
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 91 deletions.
91 changes: 91 additions & 0 deletions orttraining/orttraining/python/training/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,97 @@
import importlib.util
import os
import sys
import torch


def dtype_torch_to_numpy(torch_dtype):
'''Converts PyTorch types to Numpy types
Also must map to types accepted by:
MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type)
References:
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html
https://pytorch.org/docs/stable/tensors.html
'''
if torch_dtype == torch.float64 or torch_dtype == torch.double:
return np.float64
elif torch_dtype == torch.float32 or torch_dtype == torch.float:
return np.float32
elif torch_dtype == torch.float16 or torch_dtype == torch.half or torch_dtype == torch.bfloat16:
# NOTE: numpy doesn't support bfloat16
return np.float16
elif torch_dtype == torch.int64 or torch_dtype == torch.long:
return np.int64
elif torch_dtype == torch.int32 or torch_dtype == torch.int:
return np.int32
elif torch_dtype == torch.int16 or torch_dtype == torch.short:
return np.int16
elif torch_dtype == torch.int8:
return np.int8
elif torch_dtype == torch.uint8:
return np.uint8
elif torch_dtype == torch.complex32 or torch_dtype == torch.complex64:
# NOTE: numpy doesn't support complex32
return np.complex64
elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble:
return np.complex128
elif torch_dtype == torch.bool:
return np.bool_
else:
raise ValueError(
f'torch_dtype ({str(torch_dtype)}) type is not supported by Numpy')


def dtype_onnx_to_torch(onnx_type):
'''Converts ONNX types to PyTorch types
Reference: https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto (enum DataType)
https://pytorch.org/docs/stable/tensors.html
'''
onnx_types = ['UNDEFINED', 'FLOAT', 'UINT8', 'INT8', 'UINT16', 'INT16', 'INT32', 'INT64', 'STRING',
'BOOL', 'FLOAT16', 'DOUBLE', 'UINT32', 'UINT64', 'COMPLEX64', 'COMPLEX128', 'BFLOAT16']

if isinstance(onnx_type, int):
assert onnx_type < len(onnx_types), "Invalid onnx_type integer"
elif isinstance(onnx_type, str):
onnx_type = onnx_type.upper()
assert onnx_type in onnx_types, "Invalid onnx_type string"
onnx_type = onnx_types.index(onnx_type)
else:
raise ValueError(
"'onnx_type' must be an ONNX type represented by either a string or integer")

if onnx_type == 0:
return None
elif onnx_type == 1:
return torch.float
elif onnx_type >= 2 and onnx_type <= 3:
# NOTE: Pytorch doesn't support uint8
return torch.int8
elif onnx_type >= 4 and onnx_type <= 5:
# NOTE: Pytorch doesn't support int16
return torch.int16
elif onnx_type == 6 or onnx_type == 12:
# NOTE: Pytorch doesn't support uint32
return torch.int32
elif onnx_type == 7 or onnx_type == 13:
# NOTE: Pytorch doesn't support uint64
return torch.int64
elif onnx_type == 8:
return str
elif onnx_type == 9:
return torch.bool
elif onnx_type == 10:
return torch.float16
elif onnx_type == 11:
return torch.double
elif onnx_type == 14:
return torch.complex64
elif onnx_type == 15:
return torch.complex128
elif onnx_type == 16:
return torch.bfloat


def static_vars(**kwargs):
Expand Down
47 changes: 23 additions & 24 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import io
import onnx
import torch

from distutils.version import LooseVersion
from inspect import signature

import onnxruntime.capi.postprocess as postprocess
from . import ORTTrainerOptions
from . import optim
from .model_desc_validation import _ORTTrainerModelDesc
from .. import postprocess


class TrainStepInfo(object):
Expand Down Expand Up @@ -146,11 +144,7 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):

self.model_desc = _ORTTrainerModelDesc(model_desc)
self.optim_config = optim_config

if options:
self.options = ORTTrainerOptions(options)
else:
self.options = ORTTrainerOptions()
self.options = options

def eval_step(self, *input, **kwargs):
r"""Evaluation step method
Expand All @@ -162,7 +156,12 @@ def eval_step(self, *input, **kwargs):
Returns:
ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc`
"""
pass
# Export model to ONNX
if self._onnx_model is None:
sample_input = self._prepare_model_input(
self.model_desc.inputs, None, None, *input, **kwargs)
self._init_onnx_model(sample_input)


def save_as_onnx(self, path):
r"""Persists ONNX model into :py:attr:`path`
Expand Down Expand Up @@ -192,11 +191,11 @@ def train_step(self, *input, **kwargs):

# Export model to ONNX
if self._onnx_model is None:
sample_input = self._prepare_input_and_fetches(
sample_input = self._prepare_model_input(
self.model_desc.inputs, None, None, *input, **kwargs)
self._init_onnx_model(sample_input)

def _combine_torch_model_with_loss(self):
def _combine_torch_model_with_loss_fn(self):
# Don't need to wrap model when loss_fn is not set
if not self.loss_fn:
return self._torch_model
Expand Down Expand Up @@ -236,12 +235,11 @@ def __init__(self, model, loss_fn, input_names):
self.input_names = input_names

def forward(self, *inputs):
# *inputs is given by torch trace. It is in the order of input_names.
# model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names.
# '*inputs' is given by torch trace and matches the order of 'input_names'
# The 'model' input might differ from 'input_names'
if is_list_input:
input, label = inputs[:-1], inputs[-1]
preds = self.model(*input)
# TODO: order this according to self.model_desc.outputs?
return self.loss_fn(preds, label), preds
else:
sig = signature(self.model.forward)
Expand Down Expand Up @@ -275,10 +273,9 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
raise RuntimeError(
"Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")

# pytorch onnx exporter/trace does not try to match argument names.
# e.g. for models with optional inputs, it requires all inputs be present.
# this is a problem because the model graph depends on inputs provided.
model = self._combine_torch_model_with_loss()
# PyTorch ONNX exporter does not match argument names
# This is an issue because the ONNX graph depends on all inputs to be specified
model = self._combine_torch_model_with_loss_fn()

# Do an inference to grab output types
model.eval()
Expand All @@ -288,7 +285,7 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
if isinstance(sample_outputs, torch.Tensor):
sample_outputs = [sample_outputs]

# Append 'dtypes' for model description inputs/outputs
# Append 'dtype' for model description's inputs/outputs
for i, sample_input in enumerate(sample_inputs):
if i < len(self.model_desc.inputs):
self.model_desc.add_type_to_input_description(
Expand All @@ -310,7 +307,7 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
training=torch.onnx.TrainingMode.TRAINING)
onnx_model = onnx.load_model_from_string(f.getvalue())

# Remove 'model.' prefix introduced by model wrapper for initializers.
# Remove 'model.' prefix introduced by CombineTorchModelLossFn class
replace_name_dict = {}
for n in onnx_model.graph.initializer:
if n.name.startswith('model.'):
Expand All @@ -321,8 +318,8 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
if name in replace_name_dict:
n.input[i] = replace_name_dict[name]

# onnx model initializer may contain non-trainable registered buffers that are not part
# of pytorch model named parameteres.
# ONNX model initializers may contain non-trainable registered buffers
# that are not part of PyTorch model named parameteres
named_parameters = model.model.named_parameters() if hasattr(model, 'model') else model.named_parameters()
assert set([n for n, t in named_parameters]).issubset(
set([n.name for n in onnx_model.graph.initializer])), \
Expand Down Expand Up @@ -352,21 +349,23 @@ def _init_session(self):
if self._onnx_model is None:
return

# Perform internal post-processing
if self.options._internal_use.enable_internal_postprocess:
self._onnx_model = postprocess.run_postprocess(self._onnx_model)

# Perform user-specified post-processing
if self.options._internal_use.extra_postprocess:
self.options._internal_use.extra_postprocess(self._onnx_model)
return

def _prepare_input_and_fetches(self, inputs_desc, lr, loss_scale, *args, **kwargs):
def _prepare_model_input(self, inputs_desc, lr, loss_scale, *args, **kwargs):
# Normalize input to tuple of samples
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
input = tuple(args[0])
else:
input = args

# Append input from kwargs
# Append input from 'kwargs'
for input_desc in inputs_desc:
if input_desc[0] in kwargs:
input = input + (kwargs[input_desc[0]],)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from onnxruntime.capi.training import orttrainer_options as orttrainer_options
from onnxruntime.capi.training import model_desc_validation as md_val
from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils
#from pt_model import TransformerModel


@pytest.mark.parametrize("test_input", [
Expand Down Expand Up @@ -449,17 +448,22 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
expected_values[step], rtol=rtol, err_msg="lr mismatch")


def testInstantiateORTTrainer():

# Loading external samples for testing
pytorch_transformer_path = os.path.join('..','..','..','samples','python','pytorch_transformer')
pt_model_path = os.path.join(pytorch_transformer_path,'pt_model.py')
@pytest.mark.parametrize("step_fn", [
('train_step'),
('eval_step')
])
def testInstantiateORTTrainer(step_fn):
# Loading external TransformerModel model for testing
# A manual import is done as this example is not part of onnxruntime package,
# but resides on the onnxruntime repo
pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
pt_model_name = 'pt_model'
pt_model = _utils.import_module_from_file(pt_model_path, pt_model_name)
ort_utils_path = os.path.join(pytorch_transformer_path,'ort_utils.py')
ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py')
ort_utils_name = 'ort_utils'
ort_utils = _utils.import_module_from_file(ort_utils_path, ort_utils_name)
utils_path = os.path.join(pytorch_transformer_path,'utils.py')
utils_path = os.path.join(pytorch_transformer_path, 'utils.py')
utils_name = 'utils'
utils = _utils.import_module_from_file(utils_path, utils_name)

Expand All @@ -468,73 +472,45 @@ def testInstantiateORTTrainer():
my_loss = ort_utils.my_loss
model_desc = ort_utils.transformer_model_description()
optim_config = optim.LambConfig()
options = orttrainer.ORTTrainerOptions({'device': {'id': 'cpu'}})

# Create ORTTrainer
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=None)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

# Prep data
# Preparing data
train_data, val_data, test_data = utils.prepare_data('cpu', 20, 20)

# Train
for batch, i in enumerate(range(0, train_data.size(0)-1, 35)):
data, targets = utils.get_batch(train_data, i)
learning_rate = 0.001
trainer.train_step(data, targets, learning_rate) # removed learning rate here and in model desc
break

# Export model to ONNX
data, targets = utils.get_batch(train_data, 0)
if step_fn == 'eval_step':
trainer.eval_step(data, targets)
elif step_fn == 'train_step':
trainer.train_step(data, targets)
else:
raise ValueError('Invalid step_fn')
assert trainer._onnx_model is not None
onnx_model = trainer._onnx_model

# Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs
sig = inspect.signature(model.forward)
sig_loss = inspect.signature(my_loss)

# element 4 should be uint16, but this is not in pytorch
# element 8 should be string
int_to_type = [None, torch.float32, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int64, type("test"), torch.bool]

# check that the first len(forward.parameters) inputs have the same name, dimensions and dtype
in_str = str(onnx_model.graph.input)
for i in range(len(sig.parameters.keys())):
input_name = model_desc['inputs'][i][0]
input_dim = model_desc['inputs'][i][1]
input_name = trainer.model_desc.inputs[i][0]
input_dim = trainer.model_desc.inputs[i][1]
input_type = trainer.model_desc.inputs[i][2]

assert in_str.find(input_name) >= 0
start_index = in_str.index(input_name)
end_index = in_str.index("name", start_index+1) if i+1 < in_str.find("name", start_index+1)>0 else in_str.index(']')
sub = in_str[in_str.index(input_name):end_index]
dims = []
elem_type = 0
for item in sub.split("\n"):
if item.find("dim_value:")>0:
temp = item.replace(" ", "").replace("dim_value:", "")
dims.append(int(temp))
if item.find("elem_type:")>0:
temp = item.replace(" ", "").replace("elem_type:", "")
elem_type = int(temp)
assert int_to_type[elem_type] == input_type
assert dims == input_dim

# check that all the outputs of model desc match the name, dimensions and dtype of the ort graph
out_str = str(onnx_model.graph.output)
for i in range(len(model_desc['outputs'])):
output_name = model_desc['outputs'][i][0]
output_dim = model_desc['outputs'][i][1]

assert trainer._onnx_model.graph.input[i].name == input_name
for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim):
assert input_dim[dim_idx] == dim.dim_value
assert input_type == _utils.dtype_onnx_to_torch(
trainer._onnx_model.graph.input[i].type.tensor_type.elem_type)

# Check name, shape and dtype of the ORT graph outputs
for i in range(len(trainer.model_desc.outputs)):
output_name = trainer.model_desc.outputs[i][0]
output_dim = trainer.model_desc.outputs[i][1]
output_type = trainer.model_desc.outputs[i][3]

assert out_str.find(output_name) >= 0
start_index = out_str.index(output_name)
end_index = out_str.index("name", start_index+1) if i+1 < out_str.find("name", start_index+1)>0 else out_str.index(']')
sub = out_str[out_str.index(output_name):end_index]
dims = []
elem_type = 0
for item in sub.split("\n"):
if item.find("dim_value:")>0:
temp = item.replace(" ", "").replace("dim_value:", "")
dims.append(int(temp))
if item.find("elem_type:")>0:
temp = item.replace(" ", "").replace("elem_type:", "")
elem_type = int(temp)
assert int_to_type[elem_type] == output_type
assert dims == output_dim


assert trainer._onnx_model.graph.output[i].name == output_name
for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
assert output_dim[dim_idx] == dim.dim_value
assert output_type == _utils.dtype_onnx_to_torch(
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

0 comments on commit 25b4adf

Please sign in to comment.