From 834d45ab1761841ba4041eb4472f01fb63d344a6 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Tue, 22 Mar 2022 19:55:36 +0300 Subject: [PATCH] Update for nncf_task (#145) * Update for nncf_task * linter * update nncf commit * Rename CompressionCallback to OpenVINOCallback * struct * rename compression to openvino in config * rm --- README.md | 4 +- anomalib/models/__init__.py | 4 + anomalib/models/ganomaly/config.yaml | 18 +- anomalib/models/padim/config.yaml | 14 +- anomalib/models/padim/model.py | 3 +- anomalib/models/stfpm/config.yaml | 18 +- anomalib/utils/callbacks/__init__.py | 30 +-- anomalib/utils/callbacks/nncf/__init__.py | 15 ++ anomalib/utils/callbacks/nncf/callback.py | 98 ++++++++ anomalib/utils/callbacks/nncf/utils.py | 209 ++++++++++++++++++ anomalib/utils/callbacks/nncf_callback.py | 144 ------------ .../callbacks/{compress.py => openvino.py} | 2 +- requirements/base.txt | 2 +- tests/pre_merge/models/test_model_premerge.py | 1 - .../__init__.py | 0 .../dummy_config.yml | 2 +- .../dummy_lightning_model.py | 0 .../test_openvino.py} | 14 +- 18 files changed, 358 insertions(+), 220 deletions(-) create mode 100644 anomalib/utils/callbacks/nncf/__init__.py create mode 100644 anomalib/utils/callbacks/nncf/callback.py create mode 100644 anomalib/utils/callbacks/nncf/utils.py delete mode 100644 anomalib/utils/callbacks/nncf_callback.py rename anomalib/utils/callbacks/{compress.py => openvino.py} (98%) rename tests/pre_merge/utils/callbacks/{compress_callback => openvino_callback}/__init__.py (100%) rename tests/pre_merge/utils/callbacks/{compress_callback => openvino_callback}/dummy_config.yml (95%) rename tests/pre_merge/utils/callbacks/{compress_callback => openvino_callback}/dummy_lightning_model.py (100%) rename tests/pre_merge/utils/callbacks/{compress_callback/test_compress.py => openvino_callback/test_openvino.py} (68%) diff --git a/README.md b/README.md index 6f5a463658..c960c95aa3 100644 --- a/README.md +++ b/README.md @@ -126,11 +126,11 @@ python tools/inference.py \ --image_path datasets/MVTec/bottle/test/broken_large/000.png ``` -If you want to run OpenVINO model, ensure that `compression` `apply` is set to `True` in the respective model `config.yaml`. +If you want to run OpenVINO model, ensure that `openvino` `apply` is set to `True` in the respective model `config.yaml`. ```yaml optimization: - compression: + openvino: apply: true ``` diff --git a/anomalib/models/__init__.py b/anomalib/models/__init__.py index 1a8c72eee8..a450724f24 100644 --- a/anomalib/models/__init__.py +++ b/anomalib/models/__init__.py @@ -23,6 +23,10 @@ from anomalib.models.components import AnomalyModule +# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF. +# Can't not wrap `spatial_softmax2d` if use import_module. +from anomalib.models.padim.model import PadimLightning # noqa: F401 + def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule: """Load model from the configuration file. diff --git a/anomalib/models/ganomaly/config.yaml b/anomalib/models/ganomaly/config.yaml index 0fcacf309f..5d5ab2c3db 100644 --- a/anomalib/models/ganomaly/config.yaml +++ b/anomalib/models/ganomaly/config.yaml @@ -46,24 +46,8 @@ project: save_to_csv: false optimization: - compression: + openvino: apply: false - nncf: - apply: false - input_info: - sample_size: null - compression: - algorithm: quantization - initializer: - range: - num_init_samples: 256 - update_config: - init_weights: snapshot.ckpt - hyperparameter_search: - parameters: - lr: - min: 1e-4 - max: 1e-2 # PL Trainer Args. Don't add extra parameter here. trainer: diff --git a/anomalib/models/padim/config.yaml b/anomalib/models/padim/config.yaml index 86cb2d1aac..1f15c37a0e 100644 --- a/anomalib/models/padim/config.yaml +++ b/anomalib/models/padim/config.yaml @@ -39,20 +39,8 @@ project: save_to_csv: false optimization: - compression: + openvino: apply: false - nncf: - apply: false - input_info: - sample_size: [1, 3, 256, 256] - compression: - algorithm: quantization - initializer: - range: - num_init_samples: 256 - ignored_scopes: [] - update_config: - init_weights: snapshot.ckpt # PL Trainer Args. Don't add extra parameter here. trainer: diff --git a/anomalib/models/padim/model.py b/anomalib/models/padim/model.py index e35c33a215..8999545278 100644 --- a/anomalib/models/padim/model.py +++ b/anomalib/models/padim/model.py @@ -219,7 +219,8 @@ def smooth_anomaly_map(self, anomaly_map: Tensor) -> Tensor: """ kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1 - anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(self.sigma, self.sigma)) + sigma = torch.as_tensor(self.sigma).to(anomaly_map.device) + anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(sigma, sigma)) return anomaly_map diff --git a/anomalib/models/stfpm/config.yaml b/anomalib/models/stfpm/config.yaml index 958d7d6a58..e19ac04cbc 100644 --- a/anomalib/models/stfpm/config.yaml +++ b/anomalib/models/stfpm/config.yaml @@ -46,24 +46,8 @@ project: save_to_csv: false optimization: - compression: + openvino: apply: false - nncf: - apply: false - input_info: - sample_size: null - compression: - algorithm: quantization - initializer: - range: - num_init_samples: 256 - update_config: - init_weights: snapshot.ckpt - hyperparameter_search: - parameters: - lr: - min: 1e-4 - max: 1e-2 # PL Trainer Args. Don't add extra parameter here. trainer: diff --git a/anomalib/utils/callbacks/__init__.py b/anomalib/utils/callbacks/__init__.py index a179930859..f31b08148c 100644 --- a/anomalib/utils/callbacks/__init__.py +++ b/anomalib/utils/callbacks/__init__.py @@ -18,19 +18,20 @@ from importlib import import_module from typing import List, Union -from omegaconf import DictConfig, ListConfig +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf from pytorch_lightning.callbacks import Callback, ModelCheckpoint from .cdf_normalization import CdfNormalizationCallback -from .compress import CompressModelCallback from .min_max_normalization import MinMaxNormalizationCallback from .model_loader import LoadModelCallback +from .openvino import OpenVINOCallback from .save_to_csv import SaveToCSVCallback from .timer import TimerCallback from .visualizer_callback import VisualizerCallback __all__ = [ - "CompressModelCallback", + "OpenVINOCallback", "LoadModelCallback", "TimerCallback", "VisualizerCallback", @@ -69,10 +70,9 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]: if "normalization_method" in config.model.keys() and not config.model.normalization_method == "none": if config.model.normalization_method == "cdf": if config.model.name in ["padim", "stfpm"]: - if not config.optimization.nncf.apply: - callbacks.append(CdfNormalizationCallback()) - else: + if "nncf" in config.optimization and config.optimization.nncf.apply: raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.") + callbacks.append(CdfNormalizationCallback()) else: raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.") elif config.model.normalization_method == "min_max": @@ -84,24 +84,24 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]: callbacks.append(VisualizerCallback(inputs_are_normalized=not config.model.normalization_method == "none")) if "optimization" in config.keys(): - if config.optimization.nncf.apply: + if "nncf" in config.optimization and config.optimization.nncf.apply: # NNCF wraps torch's jit which conflicts with kornia's jit calls. # Hence, nncf is imported only when required - nncf_module = import_module("anomalib.utils.callbacks.nncf_callback") + nncf_module = import_module("anomalib.utils.callbacks.nncf.callback") nncf_callback = getattr(nncf_module, "NNCFCallback") + nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf)) callbacks.append( nncf_callback( - config=config, - dirpath=os.path.join(config.project.path, "compressed"), - filename="compressed_model", + config=nncf_config, + export_dir=os.path.join(config.project.path, "compressed"), ) ) - if config.optimization.compression.apply: + if "openvino" in config.optimization and config.optimization.openvino.apply: callbacks.append( - CompressModelCallback( + OpenVINOCallback( input_size=config.model.input_size, - dirpath=os.path.join(config.project.path, "compressed"), - filename="compressed_model", + dirpath=os.path.join(config.project.path, "openvino"), + filename="openvino_model", ) ) diff --git a/anomalib/utils/callbacks/nncf/__init__.py b/anomalib/utils/callbacks/nncf/__init__.py new file mode 100644 index 0000000000..37ff514a8d --- /dev/null +++ b/anomalib/utils/callbacks/nncf/__init__.py @@ -0,0 +1,15 @@ +"""Integration NNCF.""" + +# Copyright (C) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. diff --git a/anomalib/utils/callbacks/nncf/callback.py b/anomalib/utils/callbacks/nncf/callback.py new file mode 100644 index 0000000000..3efaf7e1e0 --- /dev/null +++ b/anomalib/utils/callbacks/nncf/callback.py @@ -0,0 +1,98 @@ +"""Callbacks for NNCF optimization.""" + +# Copyright (C) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +import os +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from nncf import NNCFConfig +from nncf.api.compression import CompressionAlgorithmController +from nncf.torch import register_default_init_args +from pytorch_lightning import Callback + +from anomalib.utils.callbacks.nncf.utils import InitLoader, wrap_nncf_model + + +class NNCFCallback(Callback): + """Callback for NNCF compression. + + Assumes that the pl module contains a 'model' attribute, which is + the PyTorch module that must be compressed. + + Args: + config (Dict): NNCF Configuration + export_dir (Str): Path where the export `onnx` and the OpenVINO `xml` and `bin` IR are saved. + If None model will not be exported. + """ + + def __init__(self, nncf_config: Dict, export_dir: str = None): + self.export_dir = export_dir + self.nncf_config = NNCFConfig(nncf_config) + self.nncf_ctrl: Optional[CompressionAlgorithmController] = None + + # pylint: disable=unused-argument + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: + """Call when fit or test begins. + + Takes the pytorch model and wraps it using the compression controller + so that it is ready for nncf fine-tuning. + """ + if self.nncf_ctrl is not None: + return + + init_loader = InitLoader(trainer.datamodule.train_dataloader()) # type: ignore + nncf_config = register_default_init_args(self.nncf_config, init_loader) + + self.nncf_ctrl, pl_module.model = wrap_nncf_model( + model=pl_module.model, config=nncf_config, dataloader=trainer.datamodule.train_dataloader() # type: ignore + ) + + def on_train_batch_start( + self, + trainer: pl.Trainer, + _pl_module: pl.LightningModule, + _batch: Any, + _batch_idx: int, + _unused: Optional[int] = 0, + ) -> None: + """Call when the train batch begins. + + Prepare compression method to continue training the model in the next step. + """ + if self.nncf_ctrl: + self.nncf_ctrl.scheduler.step() + + def on_train_epoch_start(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None: + """Call when the train epoch starts. + + Prepare compression method to continue training the model in the next epoch. + """ + if self.nncf_ctrl: + self.nncf_ctrl.scheduler.epoch_step() + + def on_train_end(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None: + """Call when the train ends. + + Exports onnx model and if compression controller is not None, uses the onnx model to generate the OpenVINO IR. + """ + if self.export_dir is None or self.nncf_ctrl is None: + return + + os.makedirs(self.export_dir, exist_ok=True) + onnx_path = os.path.join(self.export_dir, "model_nncf.onnx") + self.nncf_ctrl.export_model(onnx_path) + optimize_command = "mo --input_model " + onnx_path + " --output_dir " + self.export_dir + os.system(optimize_command) diff --git a/anomalib/utils/callbacks/nncf/utils.py b/anomalib/utils/callbacks/nncf/utils.py new file mode 100644 index 0000000000..2f60fa591a --- /dev/null +++ b/anomalib/utils/callbacks/nncf/utils.py @@ -0,0 +1,209 @@ +"""Utils for NNCf optimization.""" + +# Copyright (C) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +import logging +from copy import copy +from typing import Any, Dict, Iterator, List, Tuple + +from nncf import NNCFConfig +from nncf.api.compression import CompressionAlgorithmController +from nncf.torch import create_compressed_model, load_state, register_default_init_args +from nncf.torch.initialization import PTInitializingDataLoader +from nncf.torch.nncf_network import NNCFNetwork +from torch import nn +from torch.utils.data.dataloader import DataLoader + +logger = logging.getLogger(name="NNCF compression") + + +class InitLoader(PTInitializingDataLoader): + """Initializing data loader for NNCF to be used with unsupervised training algorithms.""" + + def __init__(self, data_loader: DataLoader): + super().__init__(data_loader) + self._data_loader_iter: Iterator + + def __iter__(self): + """Create iterator for dataloader.""" + self._data_loader_iter = iter(self._data_loader) + return self + + def __next__(self) -> Any: + """Return next item from dataloader iterator.""" + loaded_item = next(self._data_loader_iter) + return loaded_item["image"] + + def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]: + """Get input to model. + + Returns: + (dataloader_output,), {}: Tuple[Tuple, Dict]: The current model call to be made during + the initialization process + """ + return (dataloader_output,), {} + + def get_target(self, _): + """Return structure for ground truth in loss criterion based on dataloader output. + + This implementation does not do anything and is a placeholder. + + Returns: + None + """ + return None + + +def wrap_nncf_model( + model: nn.Module, config: Dict, dataloader: DataLoader = None, init_state_dict: Dict = None +) -> Tuple[CompressionAlgorithmController, NNCFNetwork]: + """Wrap model by NNCF. + + :param model: Anomalib model. + :param config: NNCF config. + :param dataloader: Dataloader for initialization of NNCF model. + :param init_state_dict: Opti + :return: compression controller, compressed model + """ + nncf_config = NNCFConfig.from_dict(config) + + if not dataloader and not init_state_dict: + logger.warning( + "Either dataloader or NNCF pre-trained " + "model checkpoint should be set. Without this, " + "quantizers will not be initialized" + ) + + compression_state = None + resuming_state_dict = None + if init_state_dict: + resuming_state_dict = init_state_dict.get("model") + compression_state = init_state_dict.get("compression_state") + + if dataloader: + init_loader = InitLoader(dataloader) # type: ignore + nncf_config = register_default_init_args(nncf_config, init_loader) + + nncf_ctrl, nncf_model = create_compressed_model( + model=model, config=nncf_config, dump_graphs=False, compression_state=compression_state + ) + + if resuming_state_dict: + load_state(nncf_model, resuming_state_dict, is_resume=True) + + return nncf_ctrl, nncf_model + + +def is_state_nncf(state: Dict) -> bool: + """The function to check if sate is the result of NNCF-compressed model.""" + return bool(state.get("meta", {}).get("nncf_enable_compression", False)) + + +def compose_nncf_config(nncf_config: Dict, enabled_options: List[str]) -> Dict: + """Compose NNCf config by selected options. + + :param nncf_config: + :param enabled_options: + :return: config + """ + optimisation_parts = nncf_config + optimisation_parts_to_choose = [] + if "order_of_parts" in optimisation_parts: + # The result of applying the changes from optimisation parts + # may depend on the order of applying the changes + # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`, + # but for sparsity it is required `total_epochs=50`) + # So, user can define `order_of_parts` in the optimisation_config + # to specify the order of applying the parts. + order_of_parts = optimisation_parts["order_of_parts"] + assert isinstance(order_of_parts, list), 'The field "order_of_parts" in optimisation config should be a list' + + for part in enabled_options: + assert part in order_of_parts, ( + f"The part {part} is selected, " "but it is absent in order_of_parts={order_of_parts}" + ) + + optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options] + + assert "base" in optimisation_parts, 'Error: the optimisation config does not contain the "base" part' + nncf_config_part = optimisation_parts["base"] + + for part in optimisation_parts_to_choose: + assert part in optimisation_parts, f'Error: the optimisation config does not contain the part "{part}"' + optimisation_part_dict = optimisation_parts[part] + try: + nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict) + except AssertionError as cur_error: + err_descr = ( + f"Error during merging the parts of nncf configs:\n" + f"the current part={part}, " + f"the order of merging parts into base is {optimisation_parts_to_choose}.\n" + f"The error is:\n{cur_error}" + ) + raise RuntimeError(err_descr) from None + + return nncf_config_part + + +# pylint: disable=invalid-name +def merge_dicts_and_lists_b_into_a(a, b): + """The function to merge dict configs.""" + return _merge_dicts_and_lists_b_into_a(a, b, "") + + +def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None): + """The function is inspired by mmcf.Config._merge_a_into_b. + + * works with usual dicts and lists and derived types + * supports merging of lists (by concatenating the lists) + * makes recursive merging for dict + dict case + * overwrites when merging scalar into scalar + Note that we merge b into a (whereas Config makes merge a into b), + since otherwise the order of list merging is counter-intuitive. + """ + + def _err_str(_a, _b, _key): + if _key is None: + _key_str = "of whole structures" + else: + _key_str = f"during merging for key=`{_key}`" + return ( + f"Error in merging parts of config: different types {_key_str}," + f" type(a) = {type(_a)}," + f" type(b) = {type(_b)}" + ) + + assert isinstance(a, (dict, list)), f"Can merge only dicts and lists, whereas type(a)={type(a)}" + assert isinstance(b, (dict, list)), _err_str(a, b, cur_key) + assert isinstance(a, list) == isinstance(b, list), _err_str(a, b, cur_key) + if isinstance(a, list): + # the main diff w.r.t. mmcf.Config -- merging of lists + return a + b + + a = copy(a) + for k in b.keys(): + if k not in a: + a[k] = copy(b[k]) + continue + new_cur_key = cur_key + "." + k if cur_key else k + if isinstance(a[k], (dict, list)): + a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key) + continue + + assert not isinstance(b[k], (dict, list)), _err_str(a[k], b[k], new_cur_key) + + # suppose here that a[k] and b[k] are scalars, just overwrite + a[k] = b[k] + return a diff --git a/anomalib/utils/callbacks/nncf_callback.py b/anomalib/utils/callbacks/nncf_callback.py deleted file mode 100644 index fba7371c4d..0000000000 --- a/anomalib/utils/callbacks/nncf_callback.py +++ /dev/null @@ -1,144 +0,0 @@ -"""NNCF Callback.""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - -import os -from typing import Any, Dict, Iterator, Optional, Tuple, Union - -import pytorch_lightning as pl -import yaml -from nncf import NNCFConfig -from nncf.api.compression import CompressionAlgorithmController, CompressionScheduler -from nncf.torch import create_compressed_model, register_default_init_args -from nncf.torch.initialization import PTInitializingDataLoader -from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Callback -from torch.utils.data.dataloader import DataLoader - - -def criterion_fn(outputs, criterion): - """Calls the criterion function on outputs.""" - return criterion(outputs) - - -class InitLoader(PTInitializingDataLoader): - """Initializing data loader for NNCF to be used with unsupervised training algorithms.""" - - def __init__(self, data_loader: DataLoader): - super().__init__(data_loader) - self._data_loader_iter: Iterator - - def __iter__(self): - """Create iterator for dataloader.""" - self._data_loader_iter = iter(self._data_loader) - return self - - def __next__(self) -> Any: - """Return next item from dataloader iterator.""" - loaded_item = next(self._data_loader_iter) - return loaded_item["image"] - - def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]: - """Get input to model. - - Returns: - (dataloader_output,), {}: Tuple[Tuple, Dict]: The current model call to be made during - the initialization process - """ - return (dataloader_output,), {} - - def get_target(self, _): - """Return structure for ground truth in loss criterion based on dataloader output. - - This implementation does not do anything and is a placeholder. - - Returns: - None - """ - return None - - -class NNCFCallback(Callback): - """Callback for NNCF compression. - - Assumes that the pl module contains a 'model' attribute, which is - the PyTorch module that must be compressed. - - Args: - config (Union[ListConfig, DictConfig]): NNCF Configuration - dirpath (str): Path where the export `onnx` and the OpenVINO `xml` and `bin` IR are saved. - filename (str): Name of the generated model files. - """ - - def __init__(self, config: Union[ListConfig, DictConfig], dirpath: str, filename: str): - config_dict = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf)) - self.nncf_config = NNCFConfig.from_dict(config_dict) - self.dirpath = dirpath - self.filename = filename - - self.comp_ctrl: Optional[CompressionAlgorithmController] = None - self.compression_scheduler: CompressionScheduler - - def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: - # pylint: disable=unused-argument - """Call when fit or test begins. - - Takes the pytorch model and wraps it using the compression controller so that it is ready for nncf fine-tuning. - """ - if self.comp_ctrl is None: - # NOTE: trainer.datamodule returns the following error - # "Trainer" has no attribute "datamodule" [attr-defined] - init_loader = InitLoader(trainer.datamodule.train_dataloader()) # type: ignore - nncf_config = register_default_init_args( - self.nncf_config, init_loader, pl_module.model.loss, criterion_fn=criterion_fn - ) - # if dump_graphs is not set to False, nncf will generate intermediate .dot files in the current dir - self.comp_ctrl, pl_module.model = create_compressed_model(pl_module.model, nncf_config, dump_graphs=False) - self.compression_scheduler = self.comp_ctrl.scheduler - - def on_train_batch_start( - self, - trainer: pl.Trainer, - _pl_module: pl.LightningModule, - _batch: Any, - _batch_idx: int, - _unused: Optional[int] = 0, - ) -> None: - """Call when the train batch begins. - - Prepare compression method to continue training the model in the next step. - """ - self.compression_scheduler.step() - if self.comp_ctrl is not None: - trainer.model.loss_val = self.comp_ctrl.loss() - - def on_train_end(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None: - """Call when the train ends. - - Exports onnx model and if compression controller is not None, uses the onnx model to generate the OpenVINO IR. - """ - os.makedirs(self.dirpath, exist_ok=True) - onnx_path = os.path.join(self.dirpath, self.filename + ".onnx") - if self.comp_ctrl is not None: - self.comp_ctrl.export_model(onnx_path) - optimize_command = "mo --input_model " + onnx_path + " --output_dir " + self.dirpath - os.system(optimize_command) - - def on_train_epoch_start(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None: - """Call when the train epoch starts. - - Prepare compression method to continue training the model in the next epoch. - """ - self.compression_scheduler.epoch_step() diff --git a/anomalib/utils/callbacks/compress.py b/anomalib/utils/callbacks/openvino.py similarity index 98% rename from anomalib/utils/callbacks/compress.py rename to anomalib/utils/callbacks/openvino.py index 48d9042582..b07cd7de03 100644 --- a/anomalib/utils/callbacks/compress.py +++ b/anomalib/utils/callbacks/openvino.py @@ -23,7 +23,7 @@ from anomalib.models.components import AnomalyModule -class CompressModelCallback(Callback): +class OpenVINOCallback(Callback): """Callback to compresses a trained model. Model is first exported to ``.onnx`` format, and then converted to OpenVINO IR. diff --git a/requirements/base.txt b/requirements/base.txt index 13762a497b..e13202d2a4 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -7,7 +7,7 @@ kornia==0.5.6 lxml==4.6.5 matplotlib==3.4.3 networkx~=2.5 -nncf==2.1.0 +nncf@ git+https://github.com/openvinotoolkit/nncf@37a830a412e60ec2fd2d84d7f00e2524e5f62777#egg=nncf numpy~=1.19.5 omegaconf==2.1.1 pillow==9.0.0 diff --git a/tests/pre_merge/models/test_model_premerge.py b/tests/pre_merge/models/test_model_premerge.py index 5af1f12510..f99adbad95 100644 --- a/tests/pre_merge/models/test_model_premerge.py +++ b/tests/pre_merge/models/test_model_premerge.py @@ -32,7 +32,6 @@ class TestModel: ("dfkde", False), ("dfm", False), ("stfpm", False), - ("stfpm", True), ("patchcore", False), ("cflow", False), ("ganomaly", False), diff --git a/tests/pre_merge/utils/callbacks/compress_callback/__init__.py b/tests/pre_merge/utils/callbacks/openvino_callback/__init__.py similarity index 100% rename from tests/pre_merge/utils/callbacks/compress_callback/__init__.py rename to tests/pre_merge/utils/callbacks/openvino_callback/__init__.py diff --git a/tests/pre_merge/utils/callbacks/compress_callback/dummy_config.yml b/tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml similarity index 95% rename from tests/pre_merge/utils/callbacks/compress_callback/dummy_config.yml rename to tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml index 8939e6cb4b..cec6be1412 100644 --- a/tests/pre_merge/utils/callbacks/compress_callback/dummy_config.yml +++ b/tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml @@ -18,7 +18,7 @@ project: path: ./results optimization: - compression: + openvino: apply: true trainer: diff --git a/tests/pre_merge/utils/callbacks/compress_callback/dummy_lightning_model.py b/tests/pre_merge/utils/callbacks/openvino_callback/dummy_lightning_model.py similarity index 100% rename from tests/pre_merge/utils/callbacks/compress_callback/dummy_lightning_model.py rename to tests/pre_merge/utils/callbacks/openvino_callback/dummy_lightning_model.py diff --git a/tests/pre_merge/utils/callbacks/compress_callback/test_compress.py b/tests/pre_merge/utils/callbacks/openvino_callback/test_openvino.py similarity index 68% rename from tests/pre_merge/utils/callbacks/compress_callback/test_compress.py rename to tests/pre_merge/utils/callbacks/openvino_callback/test_openvino.py index a18d61b440..aca4bff357 100644 --- a/tests/pre_merge/utils/callbacks/compress_callback/test_compress.py +++ b/tests/pre_merge/utils/callbacks/openvino_callback/test_openvino.py @@ -5,26 +5,26 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from anomalib.config import get_configurable_parameters -from anomalib.utils.callbacks import CompressModelCallback -from tests.pre_merge.utils.callbacks.compress_callback.dummy_lightning_model import ( +from anomalib.utils.callbacks import OpenVINOCallback +from tests.pre_merge.utils.callbacks.openvino_callback.dummy_lightning_model import ( DummyLightningModule, FakeDataModule, ) -def test_compress_model_callback(): +def test_openvino_model_callback(): """Tests if an optimized model is created.""" config = get_configurable_parameters( - model_config_path="tests/pre_merge/utils/callbacks/compress_callback/dummy_config.yml" + model_config_path="tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml" ) with tempfile.TemporaryDirectory() as tmp_dir: config.project.path = tmp_dir model = DummyLightningModule(hparams=config) model.callbacks = [ - CompressModelCallback( - input_size=config.model.input_size, dirpath=os.path.join(tmp_dir), filename="compressed_model" + OpenVINOCallback( + input_size=config.model.input_size, dirpath=os.path.join(tmp_dir), filename="openvino_model" ), EarlyStopping(monitor=config.model.metric), ] @@ -39,4 +39,4 @@ def test_compress_model_callback(): ) trainer.fit(model, datamodule=datamodule) - assert os.path.exists(os.path.join(tmp_dir, "compressed_model.bin")), "Failed to generate OpenVINO model" + assert os.path.exists(os.path.join(tmp_dir, "openvino_model.bin")), "Failed to generate OpenVINO model"