Skip to content

Commit

Permalink
naming
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Jul 5, 2023
1 parent d97126c commit c39f061
Showing 1 changed file with 11 additions and 21 deletions.
32 changes: 11 additions & 21 deletions src/deepsparse/benchmark/torchscript_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@

import numpy

from deepsparse.utils import (
model_to_path,
override_onnx_batch_size,
override_onnx_input_shapes,
)


try:
import torch
Expand Down Expand Up @@ -96,20 +90,16 @@ class TorchScriptEngine(object):
# Note 1: Engines are compiled for a specific batch size
# :param model: Either a path to the model's onnx file, a SparseZoo model stub
# prefixed by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File
# object that defines the neural network
# :param model: Either a path to the model's .pt file or the loaded model
# :param batch_size: The batch size of the inputs to be used with the engine
# :param num_cores: The number of physical cores to run the model on.
# :param input_shapes: The list of shapes to set the inputs to. Pass None to use model as-is.
# :param providers: The list of execution providers executing with. Pass None to use all available.
# :param device: Hardware to run the engine on. Either cpu or cuda
"""

def __init__(
self,
model: Union[str, "Model", "File"], # pt file or Module pytorch
model: Union[str, "Model"],
batch_size: int = 1,
device: str = "cpu", # or cuda
device: str = "cpu",
**kwargs,
):
if torch is None:
Expand Down Expand Up @@ -137,10 +127,10 @@ def __call__(
val_inp: bool = True,
) -> List[numpy.ndarray]:
"""
Convenience function for ORTEngine.run(), see @run for more details
Convenience function for TorchScriptEngine.run(), see @run for more details
| Example:
| engine = ORTEngine("path/to/onnx", batch_size=1, num_cores=None)
| engine = TorchScriptEngine("path/to/.pt", batch_size=1)
| inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
| out = engine(inp)
| assert isinstance(out, List)
Expand Down Expand Up @@ -235,13 +225,13 @@ def run(
Returns the result as a list of numpy arrays corresponding to
the outputs of the model as defined in the ONNX graph.
Note 1: the input dimensions must match what is defined in the ONNX graph.
Note 1: the input dimensions must match what is defined in the torch.nn.Module.
To avoid extra time in memory shuffles, the best use case
is to format both the onnx and the input into channels first format;
is to format both the Module and the input into channels first format;
ex: [batch, height, width, channels] => [batch, channels, height, width]
Note 2: the input type for the numpy arrays must match
what is defined in the ONNX graph.
what is defined in torch.nn.Module
Generally float32 is most common,
but int8 and int16 are used for certain layer and input types
such as with quantized models.
Expand All @@ -250,7 +240,7 @@ def run(
use numpy.ascontiguousarray(array) to fix if not.
| Example:
| engine = ORTEngine("path/to/onnx", batch_size=1, num_cores=None)
| engine = TorchScriptEngine("path/to/.pt", batch_size=1, num_cores=None)
| inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
| out = engine.run(inp)
| assert isinstance(out, List)
Expand Down Expand Up @@ -330,6 +320,6 @@ def _validate_inputs(self, inp: List[numpy.ndarray]):

def _properties_dict(self) -> Dict:
return {
"onnx_file_path": self.model_path,
"model_file_path": self.model_path,
"batch_size": self.batch_size,
}

0 comments on commit c39f061

Please sign in to comment.