Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving changes from master_320 to master #1428

Merged
merged 6 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions documentation/source/models_export.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ export_result

import onnxruntime
import numpy as np
session = onnxruntime.InferenceSession("yolo_nas_s.onnx")
session = onnxruntime.InferenceSession("yolo_nas_s.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
example_input_image = np.zeros(1, 3, 640, 640).astype(np.uint8)
example_input_image = np.zeros((1, 3, 640, 640)).astype(np.uint8)
predictions = session.run(outputs, {inputs[0]: example_input_image})

Exported model has predictions in batch format:
Expand Down Expand Up @@ -117,7 +117,7 @@ image = load_image("https://deci-pretrained-models.s3.amazonaws.com/sample_image
image = cv2.resize(image, (export_result.input_image_shape[1], export_result.input_image_shape[0]))
image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -337,10 +337,10 @@ export_result

import onnxruntime
import numpy as np
session = onnxruntime.InferenceSession("yolo_nas_s.onnx")
session = onnxruntime.InferenceSession("yolo_nas_s.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
example_input_image = np.zeros(1, 3, 640, 640).astype(np.uint8)
example_input_image = np.zeros((1, 3, 640, 640)).astype(np.uint8)
predictions = session.run(outputs, {inputs[0]: example_input_image})

Exported model has predictions in flat format:
Expand All @@ -359,7 +359,7 @@ Now we exported a model that produces predictions in `flat` format. Let's run th


```python
session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -437,7 +437,7 @@ export_result = model.export(
output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT
)

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -471,7 +471,7 @@ export_result = model.export(
quantization_mode=ExportQuantizationMode.INT8 # or ExportQuantizationMode.FP16
)

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -514,15 +514,15 @@ export_result = model.export(
calibration_loader=dummy_calibration_loader
)

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})

show_predictions_from_flat_format(image, result)
```

25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.87s/it]
25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.90s/it]



Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified documentation/source/models_export_files/models_export_30_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 5 additions & 5 deletions src/super_gradients/examples/model_export/models_export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
"image = cv2.resize(image, (export_result.input_image_shape[1], export_result.input_image_shape[0]))\n",
"image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -486,7 +486,7 @@
}
],
"source": [
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -605,7 +605,7 @@
" output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT\n",
")\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -659,7 +659,7 @@
" quantization_mode=ExportQuantizationMode.INT8 # or ExportQuantizationMode.FP16\n",
")\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -729,7 +729,7 @@
" calibration_loader=dummy_calibration_loader\n",
")\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down
11 changes: 9 additions & 2 deletions src/super_gradients/module_interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ModelHasNoPreprocessingParamsException

__all__ = ["HasPredict", "HasPreprocessingParams", "SupportsReplaceNumClasses", "ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule"]
__all__ = [
"HasPredict",
"HasPreprocessingParams",
"SupportsReplaceNumClasses",
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelHasNoPreprocessingParamsException",
]
82 changes: 62 additions & 20 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install
from super_gradients.training.utils.export_utils import infer_format_from_file_name, infer_image_shape_from_model, infer_image_input_channels
from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules

from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules, infer_model_dtype

logger = get_logger(__name__)

__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult"]
__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult", "ModelHasNoPreprocessingParamsException"]


class ModelHasNoPreprocessingParamsException(Exception):
"""
Exception that is raised when model does not have preprocessing parameters.
"""

pass


class AbstractObjectDetectionDecodingModule(nn.Module):
Expand Down Expand Up @@ -50,6 +57,19 @@ def forward(self, predictions: Any) -> Tuple[Tensor, Tensor]:
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

@torch.jit.ignore
def infer_total_number_of_predictions(self, predictions: Any) -> int:
"""
This method is used to infer the total number of predictions for a given input resolution.
The function takes raw predictions from the model and returns the total number of predictions.
It is needed to check whether max_predictions_per_image and num_pre_nms_predictions are not greater than
the total number of predictions for a given resolution.

:param predictions: Predictions from the model itself.
:return: A total number of predictions for a given resolution
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

def get_output_names(self) -> List[str]:
"""
Returns the names of the outputs of the module.
Expand Down Expand Up @@ -122,7 +142,7 @@ def export(
confidence_threshold: Optional[float] = None,
nms_threshold: Optional[float] = None,
engine: Optional[ExportTargetBackend] = None,
quantization_mode: ExportQuantizationMode = Optional[None],
quantization_mode: Optional[ExportQuantizationMode] = None,
selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
calibration_loader: Optional[DataLoader] = None,
calibration_method: str = "percentile",
Expand Down Expand Up @@ -295,7 +315,18 @@ def export(
if isinstance(preprocessing, nn.Module):
preprocessing_module = preprocessing
elif preprocessing is True:
preprocessing_module = model.get_preprocessing_callback()
try:
preprocessing_module = model.get_preprocessing_callback()
except ModelHasNoPreprocessingParamsException:
raise ValueError(
"It looks like your model does not have dataset preprocessing params properly set.\n"
"This may happen if you instantiated model from scratch and not trained it yet. \n"
"Here are what you can do to fix this:\n"
"1. Manually fill up dataset processing params via model.set_dataset_processing_params(...).\n"
"2. Train your model first and then export it. Trainer will set_dataset_processing_params(...) for you.\n"
'3. Instantiate a model using pretrained weights: models.get(..., pretrained_weights="coco") \n'
"4. Disable preprocessing by passing model.export(..., preprocessing=False). \n"
)
if isinstance(preprocessing_module, nn.Sequential):
preprocessing_module = nn.Sequential(CastTensorTo(model_type), *iter(preprocessing_module))
else:
Expand Down Expand Up @@ -325,6 +356,27 @@ def export(
num_pre_nms_predictions = postprocessing_module.num_pre_nms_predictions
max_predictions_per_image = max_predictions_per_image or num_pre_nms_predictions

dummy_input = torch.randn(input_shape).to(device=infer_model_device(model), dtype=infer_model_dtype(model))
with torch.no_grad():
number_of_predictions = postprocessing_module.infer_total_number_of_predictions(model.eval()(dummy_input))

if num_pre_nms_predictions > number_of_predictions:
logger.warning(
f"num_pre_nms_predictions ({num_pre_nms_predictions}) is greater than the total number of predictions ({number_of_predictions}) for input"
f"shape {input_shape}. Setting num_pre_nms_predictions to {number_of_predictions}"
)
num_pre_nms_predictions = number_of_predictions
# We have to re-created the postprocessing_module with the new value of num_pre_nms_predictions
postprocessing_kwargs["num_pre_nms_predictions"] = num_pre_nms_predictions
postprocessing_module: AbstractObjectDetectionDecodingModule = model.get_decoding_module(**postprocessing_kwargs)

if max_predictions_per_image > num_pre_nms_predictions:
logger.warning(
f"max_predictions_per_image ({max_predictions_per_image}) is greater than num_pre_nms_predictions ({num_pre_nms_predictions}). "
f"Setting max_predictions_per_image to {num_pre_nms_predictions}"
)
max_predictions_per_image = num_pre_nms_predictions

nms_threshold = nms_threshold or getattr(model, "_default_nms_iou", None)
if nms_threshold is None:
raise ValueError(
Expand All @@ -339,12 +391,6 @@ def export(
"Please specify the confidence_threshold explicitly: model.export(..., confidence_threshold=0.5)"
)

if max_predictions_per_image > num_pre_nms_predictions:
raise ValueError(
f"max_predictions_per_image={max_predictions_per_image} is greater than "
f"num_pre_nms_predictions={num_pre_nms_predictions}. "
f"Please specify max_predictions_per_image <= {num_pre_nms_predictions}."
)
else:
attach_nms_postprocessing = False
postprocessing_module = None
Expand Down Expand Up @@ -523,19 +569,15 @@ def export(
usage_instructions.append("")
usage_instructions.append(" import onnxruntime")
usage_instructions.append(" import numpy as np")
usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}")')
usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])')
usage_instructions.append(" inputs = [o.name for o in session.get_inputs()]")
usage_instructions.append(" outputs = [o.name for o in session.get_outputs()]")

dtype_name = np.dtype(torch_dtype_to_numpy_dtype(input_image_dtype)).name
if preprocessing:
usage_instructions.append(
f" example_input_image = np.zeros({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}).astype(np.{dtype_name})" # noqa
)
else:
usage_instructions.append(
f" example_input_image = np.zeros({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}).astype(np.{dtype_name})" # noqa
)
usage_instructions.append(
f" example_input_image = np.zeros(({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]})).astype(np.{dtype_name})" # noqa
)

usage_instructions.append(" predictions = session.run(outputs, {inputs[0]: example_input_image})")
usage_instructions.append("")

Expand Down
19 changes: 9 additions & 10 deletions src/super_gradients/module_interfaces/module_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Callable, Optional
from typing import Callable, Optional, TYPE_CHECKING

from torch import nn
from typing_extensions import Protocol, runtime_checkable

from super_gradients.training.processing.processing import Processing
if TYPE_CHECKING:
# This is a hack to avoid circular imports while still having type hints.
from super_gradients.training.processing.processing import Processing


@runtime_checkable
class HasPreprocessingParams(Protocol):
class HasPreprocessingParams:
"""
Protocol interface for torch datasets that support getting preprocessing params, later to be passed to a model
that obeys NeedsPreprocessingParams. This interface class serves a purpose of explicitly indicating whether a torch dataset has
Expand All @@ -16,7 +16,7 @@ class HasPreprocessingParams(Protocol):
"""

def get_dataset_preprocessing_params(self):
...
raise NotImplementedError(f"get_dataset_preprocessing_params is not implemented in the derived class {self.__class__.__name__}")


class HasPredict:
Expand All @@ -43,12 +43,11 @@ def get_input_channels(self) -> int:
"""
raise NotImplementedError(f"get_input_channels is not implemented in the derived class {self.__class__.__name__}")

def get_processing_params(self) -> Optional[Processing]:
def get_processing_params(self) -> Optional["Processing"]:
raise NotImplementedError(f"get_processing_params is not implemented in the derived class {self.__class__.__name__}")


@runtime_checkable
class SupportsReplaceNumClasses(Protocol):
class SupportsReplaceNumClasses:
"""
Protocol interface for modules that support replacing the number of classes.
Derived classes should implement the `replace_num_classes` method.
Expand All @@ -69,4 +68,4 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable
It takes existing nn.Module and returns a new one.
:return: None
"""
...
raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}")
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.utils.detection_utils import get_class_index_in_target
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat
Expand All @@ -30,7 +31,7 @@


@register_dataset(Datasets.DETECTION_DATASET)
class DetectionDataset(Dataset):
class DetectionDataset(Dataset, HasPreprocessingParams):
"""Detection dataset.

This is a boilerplate class to facilitate the implementation of datasets.
Expand Down
Loading