-
Notifications
You must be signed in to change notification settings - Fork 259
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
migrate export to 2x and 3x from deprecated (#1845)
Signed-off-by: xin3he <[email protected]>
- Loading branch information
Showing
15 changed files
with
657 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2021 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Intel Neural Compressor Export.""" | ||
|
||
from .torch2onnx import torch_to_fp32_onnx, torch_to_int8_onnx | ||
from .qlinear2qdq import onnx_qlinear_to_qdq | ||
from .tf2onnx import tf_to_fp32_onnx, tf_to_int8_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2021 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Helper functions to export onnx model from QLinear ops to QDQ.""" | ||
from neural_compressor.adaptor.ox_utils.util import find_by_name | ||
from neural_compressor.utils import logger | ||
from neural_compressor.utils.utility import LazyImport | ||
|
||
numpy_helper = LazyImport("onnx.numpy_helper") | ||
|
||
|
||
def check_model(model):
    """Inspect the optypes of an ONNX model to decide whether QDQ conversion applies.

    Args:
        model (ModelProto): onnx model.

    Returns:
        bool: True when the model contains QLinear-style quantized ops and is
            worth converting; False when it has none (the original model
            should be saved unchanged).
    """
    found_integer_op = False
    found_qlinear_op = False
    for node in model.graph.node:
        op = node.op_type
        if op.endswith("Integer"):
            found_integer_op = True
        elif op.startswith("QLinear") or op in ("QAttention", "QGemm", "QEmbedLayerNormalization"):
            found_qlinear_op = True
        elif op == "Gather":
            # A Gather fed by an int8/uint8 initializer is treated as a quantized op.
            weight = find_by_name(node.input[0], model.graph.initializer)
            if weight is not None and numpy_helper.to_array(weight).dtype in ["int8", "uint8"]:
                found_qlinear_op = True
    if found_integer_op:
        logger.info("This model has Integer ops, these ops will be skipped.")
    if found_qlinear_op:
        return True
    logger.info("This model has no QLinear ops, save the original model.")
    return False
|
||
|
||
def onnx_qlinear_to_qdq(
    model,
    input_name_to_nodes,
):
    """Export ONNX QLinearops model into QDQ model.

    Args:
        model (ModelProto): int8 onnx model.
        input_name_to_nodes (dict): the mapping of tensor name and its destination nodes.

    Returns:
        tuple: (add_nodes, remove_nodes, inits) — nodes to insert, QLinear nodes to
            drop, and new initializers produced by the per-op converters. All three
            are empty lists when the model has no QLinear ops.
    """
    from neural_compressor.adaptor.ox_utils.operators import QOPERATORS

    add_nodes = []
    remove_nodes = []
    inits = []
    if check_model(model):
        for node in model.graph.node:
            if node.op_type not in QOPERATORS:
                continue
            if node.output[0] not in input_name_to_nodes:
                continue
            children = []
            # Collect the consumers of every output tensor. The original code
            # extended with input_name_to_nodes[node.output[0]] on each loop
            # iteration, duplicating the first output's consumers and ignoring
            # the consumers of any other output.
            for out in node.output:
                if out in input_name_to_nodes:
                    children.extend(input_name_to_nodes[out])
            converter = QOPERATORS[node.op_type](node, children, model.graph.initializer)
            done, add_node, init = converter.convert()
            if done:
                add_nodes.extend(add_node)
                inits.extend(init)
                remove_nodes.append(node)
    return add_nodes, remove_nodes, inits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2022 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Helper functions to export model from TensorFlow to ONNX.""" | ||
|
||
import re | ||
|
||
from neural_compressor.utils import logger | ||
from neural_compressor.utils.utility import LazyImport | ||
|
||
t2o = LazyImport("tf2onnx") | ||
|
||
|
||
def _split_nodename_and_shape(name): | ||
"""Split input name with shape into name and shape.""" | ||
# pattern for a node name | ||
inputs = [] | ||
shapes = {} | ||
# input takes in most cases the format name:0, where 0 is the output number | ||
# in some cases placeholders don't have a rank which onnx can't handle so we let uses override the shape | ||
# by appending the same, ie : [1,28,28,3] | ||
name_pattern = r"(?:([\w\d/\-\._:]+)(\[[\-\d,]+\])?),?" | ||
splits = re.split(name_pattern, name) | ||
for i in range(1, len(splits), 3): | ||
inputs.append(splits[i] + ":0") | ||
if splits[i + 1] is not None: | ||
shape = [int(n) for n in splits[i + 1][1:-1].split(",")] | ||
shape = [n if n >= 0 else None for n in shape] | ||
shapes[splits[i] + ":0"] = shape | ||
if not shapes: | ||
shapes = None | ||
return inputs, shapes | ||
|
||
|
||
def tf_to_fp32_onnx(graph_def, save_path, opset_version=14, input_names=None, output_names=None, inputs_as_nchw=None):
    """Export FP32 Tensorflow model into FP32 ONNX model using tf2onnx tool.

    Args:
        graph_def (graph_def to convert): fp32 graph_def.
        save_path (str): save path of ONNX model.
        opset_version (int, optional): opset version. Defaults to 14.
        input_names (str or list, optional): input names; a string may embed
            shapes, e.g. ``"x[1,28,28,3],y"``. Defaults to None.
        output_names (list, optional): output names. Defaults to None.
        inputs_as_nchw (list, optional): transpose the input. Defaults to None.
    """
    shape_override = None
    if isinstance(input_names, str):
        # String form may carry inline shape overrides: "name[1,28,28,3],..."
        input_names, shape_override = _split_nodename_and_shape(input_names)
    else:
        # Guard against the None defaults: the original slice-assignment
        # raised TypeError when input_names/output_names were not supplied.
        if input_names:
            input_names[:] = [o + ":0" for o in input_names]
        if output_names:
            output_names[:] = [o + ":0" for o in output_names]
    t2o.convert.from_graph_def(
        graph_def=graph_def,
        input_names=input_names,
        output_names=output_names,
        inputs_as_nchw=inputs_as_nchw,
        shape_override=shape_override,
        opset=opset_version,
        output_path=save_path,
    )
    info = "The FP32 ONNX Model exported to path: {0}".format(save_path)
    logger.info("*" * len(info))
    logger.info(info)
    logger.info("*" * len(info))
|
||
|
||
def tf_to_int8_onnx(
    int8_model, save_path, opset_version: int = 14, input_names=None, output_names=None, inputs_as_nchw=None
):
    """Export INT8 Tensorflow model into INT8 ONNX model.

    Args:
        int8_model (tensorflow ITEX QDQ model): int8 model.
        save_path (str): save path of ONNX model.
        opset_version (int, optional): opset version. Defaults to 14.
        input_names (str or list, optional): input names; a string may embed
            shapes, e.g. ``"x[1,28,28,3],y"``. Defaults to None.
        output_names (list, optional): output names. Defaults to None.
        inputs_as_nchw (list, optional): transpose the input. Defaults to None.
    """
    shape_override = None
    if isinstance(input_names, str):
        # String form may carry inline shape overrides: "name[1,28,28,3],..."
        input_names, shape_override = _split_nodename_and_shape(input_names)
    else:
        # Guard against the None defaults: the original slice-assignment
        # raised TypeError when input_names/output_names were not supplied.
        if input_names:
            input_names[:] = [o + ":0" for o in input_names]
        if output_names:
            output_names[:] = [o + ":0" for o in output_names]
    onnx_convert_graph = "./converted_graph.onnx"
    from neural_compressor.adaptor.tf_utils.tf2onnx_converter import TensorflowQDQToOnnxQDQConverter

    TensorflowQDQToOnnxQDQConverter(
        int8_model, input_names, output_names, shape_override, inputs_as_nchw, opset_version
    ).convert(onnx_convert_graph)

    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Building a session with optimized_model_filepath set writes the
    # ORT-optimized model to save_path as a side effect.
    sess_options.optimized_model_filepath = save_path
    import onnx

    model = onnx.load(onnx_convert_graph)
    ort.InferenceSession(model.SerializeToString(), sess_options)
    info = "The INT8 ONNX Model is exported to path: {0}".format(save_path)
    logger.info("*" * len(info))
    logger.info(info)
    logger.info("*" * len(info))
Oops, something went wrong.