From 25b4adf7ab271ca557dda8f10a7274fd2cff3192 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Wed, 29 Jul 2020 15:31:28 -0700
Subject: [PATCH] Refactor name, shape and dtype for inputs/outputs on ONNX
 models

---
 .../orttraining/python/training/_utils.py     |  91 +++++++++++++++
 .../orttraining/python/training/orttrainer.py |  47 ++++----
 .../orttraining_test_orttrainer_frontend.py   | 110 +++++++-----------
 3 files changed, 157 insertions(+), 91 deletions(-)

diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py
index 513b4ebddcb89..8ebe82b477c54 100644
--- a/orttraining/orttraining/python/training/_utils.py
+++ b/orttraining/orttraining/python/training/_utils.py
@@ -1,6 +1,97 @@
 import importlib.util
 import os
 import sys
+import numpy as np
+import torch
+
+
+def dtype_torch_to_numpy(torch_dtype):
+    '''Converts PyTorch types to Numpy types
+
+    Also must map to types accepted by:
+    MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type)
+
+    References:
+    https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html
+    https://pytorch.org/docs/stable/tensors.html
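+
+    Example (illustrative doctest; this module imports numpy as np):
+        >>> dtype_torch_to_numpy(torch.float32) == np.float32
+        True
+        >>> dtype_torch_to_numpy(torch.bfloat16) == np.float16  # numpy has no bfloat16
+        True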
+    '''
+    if torch_dtype == torch.float64 or torch_dtype == torch.double:
+        return np.float64
+    elif torch_dtype == torch.float32 or torch_dtype == torch.float:
+        return np.float32
+    elif torch_dtype == torch.float16 or torch_dtype == torch.half or torch_dtype == torch.bfloat16:
+        # NOTE: numpy doesn't support bfloat16
+        return np.float16
+    elif torch_dtype == torch.int64 or torch_dtype == torch.long:
+        return np.int64
+    elif torch_dtype == torch.int32 or torch_dtype == torch.int:
+        return np.int32
+    elif torch_dtype == torch.int16 or torch_dtype == torch.short:
+        return np.int16
+    elif torch_dtype == torch.int8:
+        return np.int8
+    elif torch_dtype == torch.uint8:
+        return np.uint8
+    elif torch_dtype == torch.complex32 or torch_dtype == torch.complex64:
+        # NOTE: numpy doesn't support complex32
+        return np.complex64
+    elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble:
+        return np.complex128
+    elif torch_dtype == torch.bool:
+        return np.bool_
+    else:
+        raise ValueError(
+            f'torch_dtype ({str(torch_dtype)}) type is not supported by Numpy')
+
+
+def dtype_onnx_to_torch(onnx_type):
+    '''Converts ONNX types to PyTorch types
+
+    References:
+    https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto (enum DataType)
+    https://pytorch.org/docs/stable/tensors.html
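+
+    Example (illustrative doctest, not exhaustive):
+        >>> dtype_onnx_to_torch(1) == torch.float  # 1 is ONNX FLOAT
+        True
+        >>> dtype_onnx_to_torch('DOUBLE') == torch.double
+        True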
+    '''
+    onnx_types = ['UNDEFINED', 'FLOAT', 'UINT8', 'INT8', 'UINT16', 'INT16', 'INT32', 'INT64', 'STRING',
+                  'BOOL', 'FLOAT16', 'DOUBLE', 'UINT32', 'UINT64', 'COMPLEX64', 'COMPLEX128', 'BFLOAT16']
+
+    if isinstance(onnx_type, int):
+        assert onnx_type < len(onnx_types), "Invalid onnx_type integer"
+    elif isinstance(onnx_type, str):
+        onnx_type = onnx_type.upper()
+        assert onnx_type in onnx_types, "Invalid onnx_type string"
+        onnx_type = onnx_types.index(onnx_type)
+    else:
+        raise ValueError(
+            "'onnx_type' must be an ONNX type represented by either a string or integer")
+
+    if onnx_type == 0:
+        return None
+    elif onnx_type == 1:
+        return torch.float
+    elif onnx_type == 2:
+        return torch.uint8
+    elif onnx_type == 3:
+        return torch.int8
+    elif onnx_type >= 4 and onnx_type <= 5:
+        # NOTE: PyTorch doesn't support uint16
+        return torch.int16
+    elif onnx_type == 6 or onnx_type == 12:
+        # NOTE: PyTorch doesn't support uint32
+        return torch.int32
+    elif onnx_type == 7 or onnx_type == 13:
+        # NOTE: PyTorch doesn't support uint64
+        return torch.int64
+    elif onnx_type == 8:
+        return str
+    elif onnx_type == 9:
+        return torch.bool
+    elif onnx_type == 10:
+        return torch.float16
+    elif onnx_type == 11:
+        return torch.double
+    elif onnx_type == 14:
+        return torch.complex64
+    elif onnx_type == 15:
+        return torch.complex128
+    elif onnx_type == 16:
+        return torch.bfloat16
 
 
 def static_vars(**kwargs):
diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py
index 839ca833eea95..e171c60f2986a 100644
--- a/orttraining/orttraining/python/training/orttrainer.py
+++ b/orttraining/orttraining/python/training/orttrainer.py
@@ -1,14 +1,12 @@
 import io
 import onnx
 import torch
-
-from distutils.version import LooseVersion
 from inspect import signature
-import onnxruntime.capi.postprocess as postprocess
 
 from . import ORTTrainerOptions
 from . import optim
 from .model_desc_validation import _ORTTrainerModelDesc
+from .. import postprocess
 
 
 class TrainStepInfo(object):
@@ -146,11 +144,7 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
 
         self.model_desc = _ORTTrainerModelDesc(model_desc)
         self.optim_config = optim_config
-
-        if options:
-            self.options = ORTTrainerOptions(options)
-        else:
-            self.options = ORTTrainerOptions()
+        self.options = options if options else ORTTrainerOptions()
 
     def eval_step(self, *input, **kwargs):
         r"""Evaluation step method
@@ -162,7 +156,12 @@ def eval_step(self, *input, **kwargs):
         Returns:
             ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc`
         """
-        pass
+        # Export model to ONNX
+        if self._onnx_model is None:
+            sample_input = self._prepare_model_input(
+                self.model_desc.inputs, None, None, *input, **kwargs)
+            self._init_onnx_model(sample_input)
+
 
     def save_as_onnx(self, path):
         r"""Persists ONNX model into :py:attr:`path`
@@ -192,11 +191,11 @@ def train_step(self, *input, **kwargs):
 
         # Export model to ONNX
         if self._onnx_model is None:
-            sample_input = self._prepare_input_and_fetches(
+            sample_input = self._prepare_model_input(
                 self.model_desc.inputs, None, None, *input, **kwargs)
             self._init_onnx_model(sample_input)
 
-    def _combine_torch_model_with_loss(self):
+    def _combine_torch_model_with_loss_fn(self):
         # Don't need to wrap model when loss_fn is not set
         if not self.loss_fn:
             return self._torch_model
@@ -236,12 +235,11 @@ def __init__(self, model, loss_fn, input_names):
                 self.input_names = input_names
 
             def forward(self, *inputs):
-                # *inputs is given by torch trace. It is in the order of input_names.
-                # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names.
+                # '*inputs' is given by torch trace and matches the order of 'input_names'
+                # The model's forward signature may order its inputs differently than 'input_names'
                 if is_list_input:
                     input, label = inputs[:-1], inputs[-1]
                     preds = self.model(*input)
                     return self.loss_fn(preds, label), preds
                 else:
                     sig = signature(self.model.forward)
@@ -275,10 +273,9 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
             raise RuntimeError(
                 "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
 
-        # pytorch onnx exporter/trace does not try to match argument names.
-        # e.g. for models with optional inputs, it requires all inputs be present.
-        # this is a problem because the model graph depends on inputs provided.
-        model = self._combine_torch_model_with_loss()
+        # The PyTorch ONNX exporter does not match arguments by name, so all
+        # inputs must be provided; the exported graph depends on which inputs are given
+        model = self._combine_torch_model_with_loss_fn()
 
         # Do an inference to grab output types
         model.eval()
@@ -288,7 +285,7 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
         if isinstance(sample_outputs, torch.Tensor):
             sample_outputs = [sample_outputs]
 
-        # Append 'dtypes' for model description inputs/outputs
+        # Append 'dtype' to the model description's inputs/outputs
         for i, sample_input in enumerate(sample_inputs):
             if i < len(self.model_desc.inputs):
                 self.model_desc.add_type_to_input_description(
@@ -310,7 +307,7 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
                           training=torch.onnx.TrainingMode.TRAINING)
         onnx_model = onnx.load_model_from_string(f.getvalue())
 
-        # Remove 'model.' prefix introduced by model wrapper for initializers.
+        # Remove the 'model.' prefix that the CombineTorchModelLossFn wrapper adds to initializers
         replace_name_dict = {}
         for n in onnx_model.graph.initializer:
             if n.name.startswith('model.'):
@@ -321,8 +318,8 @@ def _convert_torch_model_loss_fn_to_onnx(self, device, inputs):
                 if name in replace_name_dict:
                     n.input[i] = replace_name_dict[name]
 
-        # onnx model initializer may contain non-trainable registered buffers that are not part
-        # of pytorch model named parameteres.
+        # ONNX model initializers may contain non-trainable registered buffers
+        # that are not part of the PyTorch model's named parameters
         named_parameters = model.model.named_parameters() if hasattr(model, 'model') else model.named_parameters()
         assert set([n for n, t in named_parameters]).issubset(
             set([n.name for n in onnx_model.graph.initializer])), \
@@ -352,21 +349,23 @@ def _init_session(self):
         if self._onnx_model is None:
             return
 
+        # Perform internal post-processing
         if self.options._internal_use.enable_internal_postprocess:
             self._onnx_model = postprocess.run_postprocess(self._onnx_model)
 
+        # Perform user-specified post-processing
         if self.options._internal_use.extra_postprocess:
             self.options._internal_use.extra_postprocess(self._onnx_model)
 
         return
 
-    def _prepare_input_and_fetches(self, inputs_desc, lr, loss_scale, *args, **kwargs):
+    def _prepare_model_input(self, inputs_desc, lr, loss_scale, *args, **kwargs):
         # Normalize input to tuple of samples
         if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
             input = tuple(args[0])
         else:
             input = args
 
-        # Append input from kwargs
+        # Append input from 'kwargs'
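+        # e.g. (illustrative) given inputs_desc == [('data', ...), ('targets', ...)],
+        # a call like train_step(data, targets=targets) yields input == (data, targets)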
         for input_desc in inputs_desc:
             if input_desc[0] in kwargs:
                 input = input + (kwargs[input_desc[0]],)
diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
index 835dd1919621e..8d355616b1140 100644
--- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
+++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
@@ -9,7 +9,6 @@
 from onnxruntime.capi.training import orttrainer_options as orttrainer_options
 from onnxruntime.capi.training import model_desc_validation as md_val
 from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils
-#from pt_model import TransformerModel
 
 
 @pytest.mark.parametrize("test_input", [
@@ -449,17 +448,22 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
                         expected_values[step], rtol=rtol, err_msg="lr mismatch")
 
 
-def testInstantiateORTTrainer():
-
-    # Loading external samples for testing
-    pytorch_transformer_path = os.path.join('..','..','..','samples','python','pytorch_transformer')
-    pt_model_path = os.path.join(pytorch_transformer_path,'pt_model.py')
+@pytest.mark.parametrize("step_fn", [
+    'train_step',
+    'eval_step'
+])
+def testInstantiateORTTrainer(step_fn):
+    # Load the external TransformerModel sample for testing
+    # A manual import is done because this example is not part of the onnxruntime
+    # package, but resides in the onnxruntime repo
+    pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
+    pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
     pt_model_name = 'pt_model'
     pt_model = _utils.import_module_from_file(pt_model_path, pt_model_name)
-    ort_utils_path = os.path.join(pytorch_transformer_path,'ort_utils.py')
+    ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py')
     ort_utils_name = 'ort_utils'
     ort_utils = _utils.import_module_from_file(ort_utils_path, ort_utils_name)
-    utils_path = os.path.join(pytorch_transformer_path,'utils.py')
+    utils_path = os.path.join(pytorch_transformer_path, 'utils.py')
     utils_name = 'utils'
     utils = _utils.import_module_from_file(utils_path, utils_name)
 
@@ -468,73 +472,45 @@ def testInstantiateORTTrainer():
     my_loss = ort_utils.my_loss
     model_desc = ort_utils.transformer_model_description()
     optim_config = optim.LambConfig()
+    options = orttrainer.ORTTrainerOptions({'device': {'id': 'cpu'}})
 
     # Create ORTTrainer
-    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=None)
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
 
-    # Prep data
+    # Prepare data
     train_data, val_data, test_data = utils.prepare_data('cpu', 20, 20)
 
-    # Train
-    for batch, i in enumerate(range(0, train_data.size(0)-1, 35)):
-        data, targets = utils.get_batch(train_data, i)
-        learning_rate = 0.001
-        trainer.train_step(data, targets, learning_rate) # removed learning rate here and in model desc
-        break
-
+    # Export model to ONNX
+    data, targets = utils.get_batch(train_data, 0)
+    if step_fn == 'eval_step':
+        trainer.eval_step(data, targets)
+    elif step_fn == 'train_step':
+        trainer.train_step(data, targets)
+    else:
+        raise ValueError('Invalid step_fn')
     assert trainer._onnx_model is not None
-    onnx_model = trainer._onnx_model
 
+    # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs
     sig = inspect.signature(model.forward)
-    sig_loss = inspect.signature(my_loss)
-
-    # element 4 should be uint16, but this is not in pytorch
-    # element 8 should be string
-    int_to_type = [None, torch.float32, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int64, type("test"), torch.bool]
-
-    # check that the first len(forward.parameters) inputs have the same name, dimensions and dtype
-    in_str = str(onnx_model.graph.input)
     for i in range(len(sig.parameters.keys())):
-        input_name = model_desc['inputs'][i][0]
-        input_dim = model_desc['inputs'][i][1]
+        input_name = trainer.model_desc.inputs[i][0]
+        input_dim = trainer.model_desc.inputs[i][1]
         input_type = trainer.model_desc.inputs[i][2]
-
-        assert in_str.find(input_name) >= 0
-        start_index = in_str.index(input_name)
-        end_index = in_str.index("name", start_index+1) if i+1 < in_str.find("name", start_index+1)>0 else in_str.index(']')
-        sub = in_str[in_str.index(input_name):end_index]
-        dims = []
-        elem_type = 0
-        for item in sub.split("\n"):
-            if item.find("dim_value:")>0:
-                temp = item.replace(" ", "").replace("dim_value:", "")
-                dims.append(int(temp))
-            if item.find("elem_type:")>0:
-                temp = item.replace(" ", "").replace("elem_type:", "")
-                elem_type = int(temp)
-        assert int_to_type[elem_type] == input_type
-        assert dims == input_dim
-
-    # check that all the outputs of model desc match the name, dimensions and dtype of the ort graph
-    out_str = str(onnx_model.graph.output)
-    for i in range(len(model_desc['outputs'])):
-        output_name = model_desc['outputs'][i][0]
-        output_dim = model_desc['outputs'][i][1]
+
+        assert trainer._onnx_model.graph.input[i].name == input_name
+        for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim):
+            assert input_dim[dim_idx] == dim.dim_value
+        assert input_type == _utils.dtype_onnx_to_torch(
+            trainer._onnx_model.graph.input[i].type.tensor_type.elem_type)
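+        # NOTE (illustrative): the graph stores dtypes as ONNX TensorProto.DataType
+        # integers, while model_desc records torch dtypes, hence the conversion above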
+
+    # Check name, shape and dtype of the ORT graph outputs
+    for i in range(len(trainer.model_desc.outputs)):
+        output_name = trainer.model_desc.outputs[i][0]
+        output_dim = trainer.model_desc.outputs[i][1]
         output_type = trainer.model_desc.outputs[i][3]
-
-        assert out_str.find(output_name) >= 0
-        start_index = out_str.index(output_name)
-        end_index = out_str.index("name", start_index+1) if i+1 < out_str.find("name", start_index+1)>0 else out_str.index(']')
-        sub = out_str[out_str.index(output_name):end_index]
-        dims = []
-        elem_type = 0
-        for item in sub.split("\n"):
-            if item.find("dim_value:")>0:
-                temp = item.replace(" ", "").replace("dim_value:", "")
-                dims.append(int(temp))
-            if item.find("elem_type:")>0:
-                temp = item.replace(" ", "").replace("elem_type:", "")
-                elem_type = int(temp)
-        assert int_to_type[elem_type] == output_type
-        assert dims == output_dim
-
+
+        assert trainer._onnx_model.graph.output[i].name == output_name
+        for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
+            assert output_dim[dim_idx] == dim.dim_value
+        assert output_type == _utils.dtype_onnx_to_torch(
+            trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)
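
Illustrative appendix (not part of the patch): the name/shape/dtype verification
pattern used in testInstantiateORTTrainer can be reproduced on any ONNX model.
The one-node Identity graph and the 'inputs_desc'/'outputs_desc' lists below are
hypothetical and exist only to make the sketch self-contained:

    import onnx
    from onnx import TensorProto, helper

    # Hypothetical descriptions: (name, shape, ONNX elem_type) per graph input/output
    inputs_desc = [('data', [2, 3], TensorProto.FLOAT)]
    outputs_desc = [('out', [2, 3], TensorProto.FLOAT)]

    # Tiny one-node model whose graph inputs/outputs should match the descriptions
    graph = helper.make_graph(
        [helper.make_node('Identity', ['data'], ['out'])], 'shape_dtype_check',
        [helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 3])],
        [helper.make_tensor_value_info('out', TensorProto.FLOAT, [2, 3])])
    model = helper.make_model(graph)
    onnx.checker.check_model(model)

    # Same traversal the test performs on trainer._onnx_model
    for (name, shape, elem_type), value_info in zip(inputs_desc, model.graph.input):
        assert value_info.name == name
        assert [d.dim_value for d in value_info.type.tensor_type.shape.dim] == shape
        assert value_info.type.tensor_type.elem_type == elem_type
    for (name, shape, elem_type), value_info in zip(outputs_desc, model.graph.output):
        assert value_info.name == name
        assert [d.dim_value for d in value_info.type.tensor_type.shape.dim] == shape
        assert value_info.type.tensor_type.elem_type == elem_type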