diff --git a/setup.py b/setup.py
index 8e425bd816..ef639fa786 100644
--- a/setup.py
+++ b/setup.py
@@ -164,6 +164,9 @@ def _parse_requirements_file(file_path):
 
 _haystack_integration_deps = _parse_requirements_file(_haystack_requirements_file_path)
 
+_torch_deps = ["torch>=1.7.0,<=2.0"]
+
+
 def _check_supported_system():
     if sys.platform.startswith("linux"):
         # linux is supported, allow install to go through
@@ -276,6 +279,7 @@ def _setup_extras() -> Dict:
         "openpifpaf": _openpifpaf_integration_deps,
         "yolov8": _yolov8_integration_deps,
         "transformers": _transformers_integration_deps,
+        "torch": _torch_deps,
     }
diff --git a/src/deepsparse/benchmark/__init__.py b/src/deepsparse/benchmark/__init__.py
index 432d48cf44..9c4febd9f5 100644
--- a/src/deepsparse/benchmark/__init__.py
+++ b/src/deepsparse/benchmark/__init__.py
@@ -18,6 +18,7 @@
 from .ort_engine import *
 from .results import *
+from .torchscript_engine import *
 
 _analytics.send_event("python__benchmark__init")
diff --git a/src/deepsparse/benchmark/torchscript_engine.py b/src/deepsparse/benchmark/torchscript_engine.py
new file mode 100644
index 0000000000..6bba616ce6
--- /dev/null
+++ b/src/deepsparse/benchmark/torchscript_engine.py
@@ -0,0 +1,306 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from typing import Dict, List, Tuple, Union
+
+import numpy
+
+
+try:
+    import torch
+
+    torch_import_error = None
+except Exception as torch_import_err:
+    torch_import_error = torch_import_err
+    torch = None
+
+try:
+    # flake8: noqa
+    from deepsparse.cpu import cpu_architecture
+except ImportError:
+    raise ImportError(
+        "Unable to import deepsparse python apis. "
+        "Please contact support@neuralmagic.com"
+    )
+
+__all__ = ["TorchScriptEngine"]
+
+_LOGGER = logging.getLogger(__name__)
+
+ARCH = cpu_architecture()
+NUM_CORES = ARCH.num_available_physical_cores
+
+
+def _validate_torch_import():
+    if torch_import_error is not None:
+        raise ImportError(
+            "An exception occurred when importing torch. Please verify that "
+            "torch is installed in order to use the TorchScript inference "
+            "engine.\n\n`torch` can be installed by running the command "
+            "`pip install deepsparse[torch]`"
+            f"\n\nException info: {torch_import_error}"
+        )
+
+
+def _validate_batch_size(batch_size: int) -> int:
+    if batch_size < 1:
+        raise ValueError("batch_size must be greater than 0")
+
+    return batch_size
+
+
+def _select_device(device: str):
+    device = str(device).lower()
+    if device == "cuda":
+        if torch.cuda.is_available():
+            return "cuda"
+        raise ValueError(
+            "CUDA is not available in the local environment. "
Please select 'cpu'" + ) + return "cpu" + + +def _validate_jit_model(model): + if isinstance(model, torch.jit.ScriptModule): + return + raise ValueError(f"{model} is not a torch.jit model") + + +class TorchScriptEngine(object): + """ + Given a loaded Torchscript(.pt) model or its saved file path, create an + that compiles the given pytorch file, + + # Note 1: Engines are compiled for a specific batch size + + # :param model: Either a path to the model's .pt file or the loaded model + # :param batch_size: The batch size of the inputs to be used with the engine + # :param device: Hardware to run the engine on. Either cpu or cuda + """ + + def __init__( + self, + model: Union[str, "Model"], + batch_size: int = 1, + device: str = "cpu", + **kwargs, + ): + if torch is None: + raise ImportError(f"Unable to import torch, error: {torch_import_error}") + + _validate_torch_import() + + self._batch_size = _validate_batch_size(batch_size) + self._device = _select_device(device) + + if isinstance(model, torch.nn.Module): + self._model = model + else: + self._model = torch.jit.load(model).eval() + _validate_jit_model(self._model) + + self._model.to(self.device) + + def __call__( + self, + inp: List[numpy.ndarray], + val_inp: bool = True, + ) -> List[numpy.ndarray]: + """ + Convenience function for TorchScriptEngine.run(), see @run for more details + + | Example: + | engine = TorchScriptEngine("path/to/.pt", batch_size=1) + | inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)] + | out = engine(inp) + | assert isinstance(out, List) + + :param inp: The list of inputs to pass to the engine for inference. + The expected order is the inputs order as defined in the ONNX graph. + :param val_inp: Validate the input to the model to ensure numpy array inputs + are setup correctly for inference. + :return: The list of outputs from the model after executing over the inputs + """ + return self.run(inp, val_inp) + + def __repr__(self): + """ + :return: Unambiguous representation of the current model instance + """ + return "{}({})".format(self.__class__, self._properties_dict()) + + def __str__(self): + """ + :return: Human readable form of the current model instance + """ + formatted_props = [ + "\t{}: {}".format(key, val) for key, val in self._properties_dict().items() + ] + + return "{}:\n{}".format( + self.__class__.__qualname__, + "\n".join(formatted_props), + ) + + @property + def model_path(self) -> str: + """ + :return: The local path to the model file the current instance was compiled from + """ + return self._model_path + + @property + def batch_size(self) -> int: + """ + :return: The batch size of the inputs to be used with the model + """ + return self._batch_size + + @property + def device(self) -> str: + return self._device + + @property + def scheduler(self) -> None: + """ + :return: The kind of scheduler to execute with + """ + return None + + @property + def input_names(self) -> List[str]: + """ + :return: The ordered names of the inputs. + """ + return [node_arg.name for node_arg in self._eng_net.get_inputs()] + + @property + def input_shapes(self) -> List[Tuple]: + """ + :return: The ordered shapes of the inputs. + """ + return [tuple(node_arg.shape) for node_arg in self._eng_net.get_inputs()] + + @property + def output_names(self) -> List[str]: + """ + :return: The ordered names of the outputs. 
+        """
+        raise NotImplementedError(
+            "TorchScript models do not expose ONNX-style output names"
+        )
+
+    def run(
+        self,
+        inp: List[numpy.ndarray],
+        val_inp: bool = True,
+    ) -> List[numpy.ndarray]:
+        """
+        Run given inputs through the model for inference.
+        Returns the result as a list of numpy arrays corresponding to
+        the outputs of the model's forward() call.
+
+        Note 1: the input dimensions must match what is defined in the
+        torch.nn.Module. To avoid extra time in memory shuffles, the best
+        use case is to format both the Module and the input into channels
+        first format;
+        ex: [batch, height, width, channels] => [batch, channels, height, width]
+
+        Note 2: the input type for the numpy arrays must match
+        what is defined in the torch.nn.Module.
+        Generally float32 is most common,
+        but int8 and int16 are used for certain layer and input types
+        such as with quantized models.
+
+        Note 3: the numpy arrays must be contiguous in memory,
+        use numpy.ascontiguousarray(array) to fix if not.
+
+        | Example:
+        |    engine = TorchScriptEngine("path/to/model.pt", batch_size=1)
+        |    inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
+        |    out = engine.run(inp)
+        |    assert isinstance(out, List)
+
+        :param inp: The list of inputs to pass to the engine for inference.
+            The expected order matches the model's forward() signature.
+        :param val_inp: Validate the input to the model to ensure numpy array inputs
+            are setup correctly for inference.
+        :return: The list of outputs from the model after executing over the inputs
+        """
+        torch_inputs = [torch.from_numpy(arr).to(self.device) for arr in inp]
+
+        tensors = self._model(*torch_inputs)
+
+        if isinstance(tensors, torch.Tensor):
+            tensors = [tensors]
+        return [tensor.cpu().detach().numpy() for tensor in tensors]
+
+    def timed_run(
+        self, inp: List[numpy.ndarray], val_inp: bool = False
+    ) -> Tuple[List[numpy.ndarray], float]:
+        """
+        Convenience method for timing a model inference run.
+        Returns the result as a tuple containing (the outputs from @run, time taken).
+        See @run for more details.
+
+        :param inp: The list of inputs to pass to the engine for inference.
+            The expected order matches the model's forward() signature.
+        :param val_inp: Validate the input to the model to ensure numpy array inputs
+            are setup correctly for inference.
+        :return: A tuple of the outputs from the model and the run time in seconds
+        """
+        start = time.perf_counter()
+        out = self.run(inp, val_inp)
+        end = time.perf_counter()
+
+        return out, end - start
+
+    def mapped_run(
+        self,
+        inp: List[numpy.ndarray],
+        val_inp: bool = True,
+    ) -> Dict[str, numpy.ndarray]:
+        """
+        Run given inputs through the model for inference.
+        Returns the result as a dictionary of numpy arrays keyed by output index
+        ("output_0", "output_1", ...), since TorchScript models do not carry
+        ONNX-style output names.
+
+        Note 1: this function can add a performance hit in certain cases.
+        If using, please validate that you do not incur a performance hit
+        by comparing with the regular run function
+
+        :param inp: The list of inputs to pass to the engine for inference.
+            The expected order matches the model's forward() signature.
+        :param val_inp: Validate the input to the model to ensure numpy array inputs
+            are setup correctly for inference.
+        :return: The dictionary of outputs from the model after executing
+            over the inputs
+        """
+        out = self.run(inp, val_inp)
+        return {f"output_{idx}": arr for idx, arr in enumerate(out)}
+
+    def _properties_dict(self) -> Dict:
+        return {
+            "model_file_path": self.model_path,
+            "batch_size": self.batch_size,
+        }
diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py
index 6ae0b59b1a..d55a5d138d 100644
--- a/src/deepsparse/image_classification/pipelines.py
+++ b/src/deepsparse/image_classification/pipelines.py
@@ -73,6 +73,7 @@ def __init__(
         self,
         *,
         class_names: Union[None, str, Dict[str, str]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
         top_k: int = 1,
         **kwargs,
     ):
@@ -85,7 +86,7 @@ def __init__(
         else:
             self._class_names = None
 
-        self._image_size = self._infer_image_size()
+        self._image_size = image_size or self._infer_image_size()
         self.top_k = top_k
 
         # torchvision transforms for raw inputs
diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 8a88dacbca..765e6ff413 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -27,7 +27,7 @@
 
 from deepsparse import Context, Engine, MultiModelEngine, Scheduler
 from deepsparse.base_pipeline import _REGISTERED_PIPELINES, BasePipeline, SupportedTasks
-from deepsparse.benchmark import ORTEngine
+from deepsparse.benchmark import ORTEngine, TorchScriptEngine
 from deepsparse.cpu import cpu_details
 from deepsparse.loggers.base_logger import BaseLogger
 from deepsparse.loggers.constants import MetricCategories, SystemGroups
@@ -43,6 +43,7 @@
 __all__ = [
     "DEEPSPARSE_ENGINE",
     "ORT_ENGINE",
+    "TORCHSCRIPT_ENGINE",
     "SUPPORTED_PIPELINE_ENGINES",
     "Pipeline",
     "BasePipeline",
@@ -62,6 +63,7 @@
 
 DEEPSPARSE_ENGINE = "deepsparse"
 ORT_ENGINE = "onnxruntime"
+TORCHSCRIPT_ENGINE = "torchscript"
 
 SUPPORTED_PIPELINE_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE]
 
@@ -151,7 +153,7 @@ def __init__(
         context: Optional[Context] = None,
         executor: Optional[Union[ThreadPoolExecutor, int]] = None,
         benchmark: bool = False,
-        _delay_engine_initialize: bool = False,
+        _delay_engine_initialize: bool = False,  # internal use only
         **kwargs,
     ):
         self._benchmark = benchmark
@@ -191,7 +193,6 @@ def __init__(
             self.engine = None
         else:
             self.engine = self._initialize_engine()
-
         self._batch_size = self._batch_size or 1
 
         self.log(
@@ -506,7 +507,9 @@ def log_inference_times(self, timer: StagedTimer):
             category=MetricCategories.SYSTEM,
         )
 
-    def _initialize_engine(self) -> Union[Engine, MultiModelEngine, ORTEngine]:
+    def _initialize_engine(
+        self,
+    ) -> Union[Engine, MultiModelEngine, ORTEngine, TorchScriptEngine]:
         return create_engine(
             self.onnx_file_path, self.engine_type, self._engine_args, self.context
         )
@@ -714,6 +717,9 @@ def create_engine(
     if engine_type == ORT_ENGINE:
         return ORTEngine(onnx_file_path, **engine_args)
 
+    if engine_type == TORCHSCRIPT_ENGINE:
+        return TorchScriptEngine(onnx_file_path, **engine_args)
+
     raise ValueError(
         f"Unknown engine_type {engine_type}. Supported values include: "
         f"{SUPPORTED_PIPELINE_ENGINES}"
diff --git a/tests/conftest.py b/tests/conftest.py
index a9042a80da..e3ddf8be6c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,13 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
+import shutil
 import tempfile
 from subprocess import Popen
 from typing import List
 
 import pytest
 from tests.helpers import delete_file
+from tests.utils.torch import find_file_with_pattern
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+try:
+    import torch
+
+    torch_import_error = None
+except Exception as torch_import_err:
+    torch_import_error = torch_import_err
+    torch = None
 
 
 def _get_files(directory: str) -> List[str]:
@@ -108,3 +123,54 @@ def check_for_created_files():
         f"megabytes of temp files created in temp directory during pytest run. "
         f"Created files: {set(end_files_temp) - set(start_files_temp)}"
     )
+
+
+@pytest.fixture
+def torchvision_fixture():
+    try:
+        import torchvision
+
+        return torchvision
+    except ImportError:
+        logger.error("Failed to import torchvision")
+        raise
+
+
+@pytest.fixture(scope="function")
+def torchvision_model_fixture(torchvision_fixture):
+    def get(return_jit: bool = False, **kwargs):
+        # [TODO]: Make a model factory if needed
+        model = torchvision_fixture.models.resnet50(**kwargs)
+
+        if return_jit:
+            return torch.jit.script(model)
+
+        return model
+
+    return get
+
+
+@pytest.fixture(scope="function")
+def torchscript_test_setup(torchvision_model_fixture):
+    path = os.path.expanduser(os.path.join("~/.cache/torch", "hub", "checkpoints"))
+
+    # Download the pretrained weights into the torch hub cache
+    torchvision_model_fixture(pretrained=True, return_jit=False)
+    expr = r"^resnet50-[0-9a-z]+\.pt[h]?$"
+    resnet50_nn_module_path = find_file_with_pattern(path, expr)
+    resnet50_nn_module = torchvision_model_fixture(pretrained=True)
+
+    resnet50_jit = torchvision_model_fixture(pretrained=True, return_jit=True)
+    resnet50_jit_path = resnet50_nn_module_path.replace(".pth", ".pt")
+    torch.jit.save(resnet50_jit, resnet50_jit_path)
+
+    yield {
+        "jit_model": resnet50_jit,
+        "jit_model_path": resnet50_jit_path,
+        "nn_module_model": resnet50_nn_module,
+    }
+
+    cache_dir = os.path.expanduser("~/.cache/torch")
+    shutil.rmtree(cache_dir)
+    assert not os.path.exists(cache_dir)
diff --git a/tests/test_torchscript.py b/tests/test_torchscript.py
new file mode 100644
index 0000000000..d0f72d2fba
--- /dev/null
+++ b/tests/test_torchscript.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import numpy
+
+import pytest
+from deepsparse import Pipeline
+from deepsparse.benchmark.torchscript_engine import TorchScriptEngine
+from deepsparse.image_classification.schemas import ImageClassificationOutput
+
+
+try:
+    import torch
+
+    torch_import_error = None
+except Exception as torch_import_err:
+    torch_import_error = torch_import_err
+    torch = None
+
+
+@pytest.mark.skipif(torch is None, reason="torch is not installed")
+def test_cpu_torchscript(torchscript_test_setup):
+    models = torchscript_test_setup
+    for model in models.values():
+        engine = TorchScriptEngine(model, device="cpu")
+        inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
+        out = engine(inp)
+        assert isinstance(out, List) and all(
+            isinstance(arr, numpy.ndarray) for arr in out
+        )
+
+
+@pytest.mark.skipif(
+    torch is None or not torch.cuda.is_available(), reason="CUDA is not available"
+)
+def test_gpu_torchscript(torchscript_test_setup):
+    models = torchscript_test_setup
+    for model in models.values():
+        engine = TorchScriptEngine(model, device="cuda")
+        inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
+        out = engine(inp)
+        assert isinstance(out, List) and all(
+            isinstance(arr, numpy.ndarray) for arr in out
+        )
+
+
+@pytest.mark.skipif(torch is None, reason="torch is not installed")
+def test_cpu_torchscript_pipeline(torchscript_test_setup):
+    models = torchscript_test_setup
+
+    torchscript_pipeline = Pipeline.create(
+        task="image_classification",
+        model_path=models["jit_model_path"],
+        engine_type="torchscript",
+        image_size=(224, 224),
+    )
+
+    inp = [numpy.random.rand(3, 224, 224).astype(numpy.float32)]
+    pipeline_outputs = torchscript_pipeline(images=inp)
+    assert isinstance(pipeline_outputs, ImageClassificationOutput)
diff --git a/tests/utils/torch.py b/tests/utils/torch.py
new file mode 100644
index 0000000000..ea1b39dd29
--- /dev/null
+++ b/tests/utils/torch.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from typing import Optional
+
+
+try:
+    import torch
+    import torchvision.models as models
+
+    torch_import_error = None
+except Exception as torch_import_err:
+    torch_import_error = torch_import_err
+    torch = None
+    models = None
+
+
+def find_file_with_pattern(folder_path: str, pattern: str) -> Optional[str]:
+    """
+    Return the full path of the first file under folder_path whose name matches
+    the given regex pattern, or None if no file matches
+    """
+    folder_path = os.path.expanduser(folder_path)
+    for root, _dirs, files in os.walk(folder_path):
+        for file in files:
+            if re.match(pattern, file):
+                return os.path.join(root, file)
+    return None
+
+
+def save_pth_to_pt(model_path_pth: str, model_name: str = "resnet50") -> str:
+    """
+    Given a .pth model path, load the weights into the named torchvision model,
+    script it, and save it next to the original as a .pt file
+    """
+    model_func = getattr(models, model_name)
+    model = model_func()
+    # Load the saved weights into the model
+    model.load_state_dict(torch.load(model_path_pth))
+    scripted_model = torch.jit.script(model)
+
+    base_path = model_path_pth.rsplit(".pth", 1)[0]
+    model_path_pt = f"{base_path}.pt"
+    torch.jit.save(scripted_model, model_path_pt)
+    return model_path_pt
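
Usage sketch (not part of the diff): a minimal end-to-end example of the TorchScript path this change adds, for reviewers trying it locally. The model path and input shape are illustrative assumptions mirroring the tests above, not values pinned by this patch.

    import numpy

    from deepsparse import Pipeline
    from deepsparse.benchmark.torchscript_engine import TorchScriptEngine

    # Standalone engine: accepts a TorchScript (.pt) file path or a loaded module
    engine = TorchScriptEngine("path/to/model.pt", batch_size=1, device="cpu")
    out = engine([numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)])

    # Pipeline integration via the new "torchscript" engine_type; image_size
    # must be passed explicitly since there is no ONNX graph to infer it from
    pipeline = Pipeline.create(
        task="image_classification",
        model_path="path/to/model.pt",
        engine_type="torchscript",
        image_size=(224, 224),
    )

The explicit image_size kwarg is why pipelines.py gains that parameter: _infer_image_size() reads input shapes from the ONNX graph, which does not exist for a .pt file.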