From 2a0fcb6e71ff1ec937079e1e7f577eb500123a20 Mon Sep 17 00:00:00 2001 From: sanjaypavo <93761297+sanjaypavo@users.noreply.github.com> Date: Tue, 7 Jun 2022 11:51:17 +0530 Subject: [PATCH] changing the onnxwrapper script for gpu issue (#532) * changing the onnxwrapper script * gpu_issue * Update wrapper.py * Update wrapper.py * Update runtime.txt * Update runtime.txt * Update wrapper.py --- mmdeploy/backend/onnxruntime/wrapper.py | 11 +++++------ requirements/runtime.txt | 1 + 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mmdeploy/backend/onnxruntime/wrapper.py b/mmdeploy/backend/onnxruntime/wrapper.py index 4239853e2d..daac6bf515 100644 --- a/mmdeploy/backend/onnxruntime/wrapper.py +++ b/mmdeploy/backend/onnxruntime/wrapper.py @@ -50,9 +50,9 @@ def __init__(self, logger.warning(f'The library of onnxruntime custom ops does \ not exist: {ort_custom_op_path}') device_id = parse_device_id(device) - is_cuda_available = ort.get_device() == 'GPU' - providers = [('CUDAExecutionProvider', {'device_id': device_id})] \ - if is_cuda_available else ['CPUExecutionProvider'] + providers = ['CPUExecutionProvider'] \ + if device == 'cpu' else \ + [('CUDAExecutionProvider', {'device_id': device_id})] sess = ort.InferenceSession( onnx_file, session_options, providers=providers) if output_names is None: @@ -60,8 +60,7 @@ def __init__(self, self.sess = sess self.io_binding = sess.io_binding() self.device_id = device_id - self.is_cuda_available = is_cuda_available - self.device_type = 'cuda' if is_cuda_available else 'cpu' + self.device_type = 'cpu' if device == 'cpu' else 'cuda' super().__init__(output_names) def forward(self, inputs: Dict[str, @@ -77,7 +76,7 @@ def forward(self, inputs: Dict[str, for name, input_tensor in inputs.items(): # set io binding for inputs/outputs input_tensor = input_tensor.contiguous() - if not self.is_cuda_available: + if self.device_type == 'cpu': input_tensor = input_tensor.cpu() # Avoid unnecessary data transfer between host and device element_type = input_tensor.new_zeros( diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 6114dfc58f..aa7aec20ea 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -2,5 +2,6 @@ h5py matplotlib numpy onnx>=1.8.0 +protobuf==3.20.0 six terminaltables