diff --git a/anomalib/deploy/optimize.py b/anomalib/deploy/optimize.py
index 721e754360..a239d5b176 100644
--- a/anomalib/deploy/optimize.py
+++ b/anomalib/deploy/optimize.py
@@ -11,6 +11,7 @@
 import numpy as np
 import torch
 from torch import Tensor
+from torch.types import Number

 from anomalib.models.components import AnomalyModule

@@ -25,16 +26,13 @@ def get_model_metadata(model: AnomalyModule) -> Dict[str, Tensor]:
         Dict[str, Tensor]: metadata
     """
     meta_data = {}
-    cached_meta_data = {
+    cached_meta_data: Dict[str, Union[Number, Tensor]] = {
         "image_threshold": model.image_threshold.cpu().value.item(),
         "pixel_threshold": model.pixel_threshold.cpu().value.item(),
-        "pixel_mean": model.training_distribution.pixel_mean.cpu(),
-        "image_mean": model.training_distribution.image_mean.cpu(),
-        "pixel_std": model.training_distribution.pixel_std.cpu(),
-        "image_std": model.training_distribution.image_std.cpu(),
-        "min": model.min_max.min.cpu().item(),
-        "max": model.min_max.max.cpu().item(),
     }
+    if hasattr(model, "normalization_metrics") and model.normalization_metrics.state_dict() is not None:
+        for key, value in model.normalization_metrics.state_dict().items():
+            cached_meta_data[key] = value.cpu()
     # Remove undefined values by copying in a new dict
     for key, val in cached_meta_data.items():
         if not np.isinf(val).all():
diff --git a/anomalib/models/components/base/anomaly_module.py b/anomalib/models/components/base/anomaly_module.py
index 6e6dc2936a..e3ff975693 100644
--- a/anomalib/models/components/base/anomaly_module.py
+++ b/anomalib/models/components/base/anomaly_module.py
@@ -10,13 +10,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from torch import Tensor, nn
+from torchmetrics import Metric

-from anomalib.utils.metrics import (
-    AdaptiveThreshold,
-    AnomalibMetricCollection,
-    AnomalyScoreDistribution,
-    MinMax,
-)
+from anomalib.utils.metrics import AdaptiveThreshold, AnomalibMetricCollection

 logger = logging.getLogger(__name__)

@@ -41,12 +37,8 @@ def __init__(self):

         self.image_threshold = AdaptiveThreshold().cpu()
         self.pixel_threshold = AdaptiveThreshold().cpu()
-        self.training_distribution = AnomalyScoreDistribution().cpu()
-        self.min_max = MinMax().cpu()
+        self.normalization_metrics: Metric

-        # Create placeholders for image and pixel metrics.
-        # If set from the config file, MetricsConfigurationCallback will
-        # create the metric collections upon setup.
         self.image_metrics: AnomalibMetricCollection
         self.pixel_metrics: AnomalibMetricCollection

diff --git a/anomalib/utils/callbacks/cdf_normalization.py b/anomalib/utils/callbacks/cdf_normalization.py
index e4145accd9..5b426fc8a5 100644
--- a/anomalib/utils/callbacks/cdf_normalization.py
+++ b/anomalib/utils/callbacks/cdf_normalization.py
@@ -15,6 +15,7 @@
 from anomalib.models import get_model
 from anomalib.models.components import AnomalyModule
 from anomalib.post_processing.normalization.cdf import normalize, standardize
+from anomalib.utils.metrics import AnomalyScoreDistribution

 logger = logging.getLogger(__name__)

@@ -27,7 +28,19 @@ def __init__(self):
         self.image_dist: Optional[LogNormal] = None
         self.pixel_dist: Optional[LogNormal] = None

-    def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
+    # pylint: disable=unused-argument
+    def setup(self, trainer: pl.Trainer, pl_module: AnomalyModule, stage: Optional[str] = None) -> None:
+        """Adds training_distribution metrics to normalization metrics."""
+        if not hasattr(pl_module, "normalization_metrics"):
+            pl_module.normalization_metrics = AnomalyScoreDistribution().cpu()
+        elif not isinstance(pl_module.normalization_metrics, AnomalyScoreDistribution):
+            raise AttributeError(
+                f"Expected normalization_metrics to be of type AnomalyScoreDistribution,"
+                f" got {type(pl_module.normalization_metrics)}"
+            )
+
+    # pylint: disable=unused-argument
+    def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
         """Called when the test begins."""
         if pl_module.image_metrics is not None:
             pl_module.image_metrics.set_threshold(0.5)
@@ -93,24 +106,25 @@ def _collect_stats(self, trainer, pl_module):
         predictions = Trainer(gpus=trainer.gpus).predict(
             model=self._create_inference_model(pl_module), dataloaders=trainer.datamodule.train_dataloader()
         )
-        pl_module.training_distribution.reset()
+        pl_module.normalization_metrics.reset()
         for batch in predictions:
             if "pred_scores" in batch.keys():
-                pl_module.training_distribution.update(anomaly_scores=batch["pred_scores"])
+                pl_module.normalization_metrics.update(anomaly_scores=batch["pred_scores"])
             if "anomaly_maps" in batch.keys():
-                pl_module.training_distribution.update(anomaly_maps=batch["anomaly_maps"])
-        pl_module.training_distribution.compute()
+                pl_module.normalization_metrics.update(anomaly_maps=batch["anomaly_maps"])
+        pl_module.normalization_metrics.compute()

     @staticmethod
     def _create_inference_model(pl_module):
         """Create a duplicate of the PL module that can be used to perform inference on the training set."""
         new_model = get_model(pl_module.hparams)
+        new_model.normalization_metrics = AnomalyScoreDistribution().cpu()
         new_model.load_state_dict(pl_module.state_dict())
         return new_model

     @staticmethod
     def _standardize_batch(outputs: STEP_OUTPUT, pl_module) -> None:
-        stats = pl_module.training_distribution.to(outputs["pred_scores"].device)
+        stats = pl_module.normalization_metrics.to(outputs["pred_scores"].device)
         outputs["pred_scores"] = standardize(outputs["pred_scores"], stats.image_mean, stats.image_std)
         if "anomaly_maps" in outputs.keys():
             outputs["anomaly_maps"] = standardize(
diff --git a/anomalib/utils/callbacks/min_max_normalization.py b/anomalib/utils/callbacks/min_max_normalization.py
index b51a27e9a5..c6dea697e8 100644
--- a/anomalib/utils/callbacks/min_max_normalization.py
+++ b/anomalib/utils/callbacks/min_max_normalization.py
@@ -3,7 +3,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import pytorch_lightning as pl
 from pytorch_lightning import Callback
@@ -12,18 +12,29 @@
 from anomalib.models.components import AnomalyModule
 from anomalib.post_processing.normalization.min_max import normalize
+from anomalib.utils.metrics import MinMax


 @CALLBACK_REGISTRY
 class MinMaxNormalizationCallback(Callback):
     """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""

-    def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
+    # pylint: disable=unused-argument
+    def setup(self, trainer: pl.Trainer, pl_module: AnomalyModule, stage: Optional[str] = None) -> None:
+        """Adds min_max metrics to normalization metrics."""
+        if not hasattr(pl_module, "normalization_metrics"):
+            pl_module.normalization_metrics = MinMax().cpu()
+        elif not isinstance(pl_module.normalization_metrics, MinMax):
+            raise AttributeError(
+                f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
+            )
+
+    # pylint: disable=unused-argument
+    def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
         """Called when the test begins."""
-        if pl_module.image_metrics is not None:
-            pl_module.image_metrics.set_threshold(0.5)
-        if pl_module.pixel_metrics is not None:
-            pl_module.pixel_metrics.set_threshold(0.5)
+        for metric in (pl_module.image_metrics, pl_module.pixel_metrics):
+            if metric is not None:
+                metric.set_threshold(0.5)

     def on_validation_batch_end(
         self,
@@ -36,9 +47,9 @@ def on_validation_batch_end(
     ) -> None:
         """Called when the validation batch ends, update the min and max observed values."""
         if "anomaly_maps" in outputs.keys():
-            pl_module.min_max(outputs["anomaly_maps"])
+            pl_module.normalization_metrics(outputs["anomaly_maps"])
         else:
-            pl_module.min_max(outputs["pred_scores"])
+            pl_module.normalization_metrics(outputs["pred_scores"])

     def on_test_batch_end(
         self,
@@ -67,7 +78,7 @@ def on_predict_batch_end(
     @staticmethod
     def _normalize_batch(outputs, pl_module):
         """Normalize a batch of predictions."""
-        stats = pl_module.min_max.cpu()
+        stats = pl_module.normalization_metrics.cpu()
         outputs["pred_scores"] = normalize(
             outputs["pred_scores"], pl_module.image_threshold.value.cpu(), stats.min, stats.max
         )
diff --git a/anomalib/utils/sweep/__init__.py b/anomalib/utils/sweep/__init__.py
index 12dc59d939..36493a3988 100644
--- a/anomalib/utils/sweep/__init__.py
+++ b/anomalib/utils/sweep/__init__.py
@@ -4,18 +4,12 @@
 # SPDX-License-Identifier: Apache-2.0

 from .config import flatten_sweep_params, get_run_config, set_in_nested_config
-from .helpers import (
-    get_meta_data,
-    get_openvino_throughput,
-    get_sweep_callbacks,
-    get_torch_throughput,
-)
+from .helpers import get_openvino_throughput, get_sweep_callbacks, get_torch_throughput

 __all__ = [
     "get_run_config",
     "set_in_nested_config",
     "get_sweep_callbacks",
-    "get_meta_data",
     "get_openvino_throughput",
     "get_torch_throughput",
     "flatten_sweep_params",
diff --git a/anomalib/utils/sweep/helpers/__init__.py b/anomalib/utils/sweep/helpers/__init__.py
index b970b59be7..27b610427d 100644
--- a/anomalib/utils/sweep/helpers/__init__.py
+++ b/anomalib/utils/sweep/helpers/__init__.py
@@ -4,6 +4,6 @@
 # SPDX-License-Identifier: Apache-2.0

 from .callbacks import get_sweep_callbacks
-from .inference import get_meta_data, get_openvino_throughput, get_torch_throughput
+from .inference import get_openvino_throughput, get_torch_throughput

-__all__ = ["get_meta_data", "get_openvino_throughput", "get_torch_throughput", "get_sweep_callbacks"]
+__all__ = ["get_openvino_throughput", "get_torch_throughput", "get_sweep_callbacks"]
diff --git a/anomalib/utils/sweep/helpers/inference.py b/anomalib/utils/sweep/helpers/inference.py
index 5a6e84c6c8..fbbdfed2d0 100644
--- a/anomalib/utils/sweep/helpers/inference.py
+++ b/anomalib/utils/sweep/helpers/inference.py
@@ -5,7 +5,7 @@

 import time
 from pathlib import Path
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Iterable, List, Union

 import numpy as np
 import torch
@@ -45,45 +45,8 @@ def __call__(self) -> Iterable[np.ndarray]:
             yield self.image


-def get_meta_data(model: AnomalyModule, input_size: Tuple[int, int]) -> Dict:
-    """Get meta data for inference.
-
-    Args:
-        model (AnomalyModule): Trained model from which the metadata is extracted.
-        input_size (Tuple[int, int]): Input size used to resize the pixel level mean and std.
-
-    Returns:
-        (Dict): Metadata as dictionary.
-    """
-    meta_data = {
-        "image_threshold": model.image_threshold.value.cpu().numpy(),
-        "pixel_threshold": model.pixel_threshold.value.cpu().numpy(),
-        "min": model.min_max.min.cpu().numpy(),
-        "max": model.min_max.max.cpu().numpy(),
-        "stats": {},
-    }
-
-    image_mean = model.training_distribution.image_mean.cpu().numpy()
-    if image_mean.size > 0:
-        meta_data["stats"]["image_mean"] = image_mean
-
-    image_std = model.training_distribution.image_std.cpu().numpy()
-    if image_std.size > 0:
-        meta_data["stats"]["image_std"] = image_std
-
-    pixel_mean = model.training_distribution.pixel_mean.cpu().numpy()
-    if pixel_mean.size > 0:
-        meta_data["stats"]["pixel_mean"] = pixel_mean.reshape(input_size)
-
-    pixel_std = model.training_distribution.pixel_std.cpu().numpy()
-    if pixel_std.size > 0:
-        meta_data["stats"]["pixel_std"] = pixel_std.reshape(input_size)
-
-    return meta_data
-
-
 def get_torch_throughput(
-    config: Union[DictConfig, ListConfig], model: AnomalyModule, test_dataset: DataLoader, meta_data: Dict
+    config: Union[DictConfig, ListConfig], model: AnomalyModule, test_dataset: DataLoader
 ) -> float:
     """Tests the model on dummy data. Images are passed sequentially to make the comparision with OpenVINO model fair.

@@ -91,7 +54,6 @@ def get_torch_throughput(
         config (Union[DictConfig, ListConfig]): Model config.
         model (Path): Model on which inference is called.
         test_dataset (DataLoader): The test dataset used as a reference for the mock dataset.
-        meta_data (Dict): Metadata used for normalization.

     Returns:
         float: Inference throughput
@@ -103,7 +65,7 @@ def get_torch_throughput(
     start_time = time.time()
     # Since we don't care about performance metrics and just the throughput, use mock data.
     for image in torch_dataloader():
-        inferencer.predict(image, meta_data=meta_data)
+        inferencer.predict(image)

     # get throughput
     inference_time = time.time() - start_time
@@ -113,26 +75,23 @@ def get_torch_throughput(
     return throughput


-def get_openvino_throughput(
-    config: Union[DictConfig, ListConfig], model_path: Path, test_dataset: DataLoader, meta_data: Dict
-) -> float:
+def get_openvino_throughput(config: Union[DictConfig, ListConfig], model_path: Path, test_dataset: DataLoader) -> float:
     """Runs the generated OpenVINO model on a dummy dataset to get throughput.

     Args:
         config (Union[DictConfig, ListConfig]): Model config.
         model_path (Path): Path to folder containing the OpenVINO models. It then searches `model.xml` in the folder.
         test_dataset (DataLoader): The test dataset used as a reference for the mock dataset.
-        meta_data (Dict): Metadata used for normalization.

     Returns:
         float: Inference throughput
     """
-    inferencer = OpenVINOInferencer(config, model_path / "model.xml")
+    inferencer = OpenVINOInferencer(config, model_path / "model.xml", model_path / "meta_data.json")
     openvino_dataloader = MockImageLoader(config.dataset.image_size, total_count=len(test_dataset))
     start_time = time.time()
     # Create test images on CPU. Since we don't care about performance metrics and just the throughput, use mock data.
     for image in openvino_dataloader():
-        inferencer.predict(image, meta_data=meta_data)
+        inferencer.predict(image)

     # get throughput
     inference_time = time.time() - start_time
diff --git a/tests/helpers/inference.py b/tests/helpers/inference.py
index 1062ef514b..4c36f43ee3 100644
--- a/tests/helpers/inference.py
+++ b/tests/helpers/inference.py
@@ -4,12 +4,10 @@
 # SPDX-License-Identifier: Apache-2.0


-from typing import Dict, Iterable, List, Tuple
+from typing import Iterable, List

 import numpy as np

-from anomalib.models.components import AnomalyModule
-

 class MockImageLoader:
     """Create mock images for inference on CPU based on the specifics of the original torch test dataset.
@@ -35,36 +33,3 @@ def __call__(self) -> Iterable[np.ndarray]:
         """
         for _ in range(self.total_count):
             yield self.image
-
-
-def get_meta_data(model: AnomalyModule, input_size: Tuple[int, int]) -> Dict:
-    """Get meta data for inference.
-    Args:
-        model (AnomalyModule): Trained model from which the metadata is extracted.
-        input_size (Tuple[int, int]): Input size used to resize the pixel level mean and std.
-    Returns:
-        (Dict): Metadata as dictionary.
-    """
-    meta_data = {
-        "image_threshold": model.image_threshold.value.cpu().numpy(),
-        "pixel_threshold": model.pixel_threshold.value.cpu().numpy(),
-        "stats": {},
-    }
-
-    image_mean = model.training_distribution.image_mean.cpu().numpy()
-    if image_mean.size > 0:
-        meta_data["stats"]["image_mean"] = image_mean
-
-    image_std = model.training_distribution.image_std.cpu().numpy()
-    if image_std.size > 0:
-        meta_data["stats"]["image_std"] = image_std
-
-    pixel_mean = model.training_distribution.pixel_mean.cpu().numpy()
-    if pixel_mean.size > 0:
-        meta_data["stats"]["pixel_mean"] = pixel_mean.reshape(input_size)
-
-    pixel_std = model.training_distribution.pixel_std.cpu().numpy()
-    if pixel_std.size > 0:
-        meta_data["stats"]["pixel_std"] = pixel_std.reshape(input_size)
-
-    return meta_data
diff --git a/tests/pre_merge/deploy/test_inferencer.py b/tests/pre_merge/deploy/test_inferencer.py
index b33027afc6..6f51f2a821 100644
--- a/tests/pre_merge/deploy/test_inferencer.py
+++ b/tests/pre_merge/deploy/test_inferencer.py
@@ -18,7 +18,7 @@
 from anomalib.models import get_model
 from anomalib.utils.callbacks import get_callbacks
 from tests.helpers.dataset import TestDataset, get_dataset_path
-from tests.helpers.inference import MockImageLoader, get_meta_data
+from tests.helpers.inference import MockImageLoader


 def get_model_config(
@@ -74,10 +74,9 @@ def test_torch_inference(self, model_name: str, category: str = "shapes", path:
         # Test torch inferencer
         torch_inferencer = TorchInferencer(model_config, model)
         torch_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
-        meta_data = get_meta_data(model, model_config.dataset.image_size)
         with torch.no_grad():
             for image in torch_dataloader():
-                torch_inferencer.predict(image, meta_data=meta_data)
+                torch_inferencer.predict(image)

     @pytest.mark.parametrize(
         "model_name",
@@ -111,8 +110,9 @@ def test_openvino_inference(self, model_name: str, category: str = "shapes", path:
         )

         # Test OpenVINO inferencer
-        openvino_inferencer = OpenVINOInferencer(model_config, export_path / "model.xml")
+        openvino_inferencer = OpenVINOInferencer(
+            model_config, export_path / "model.xml", export_path / "meta_data.json"
+        )
         openvino_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
-        meta_data = get_meta_data(model, model_config.dataset.image_size)
         for image in openvino_dataloader():
-            openvino_inferencer.predict(image, meta_data=meta_data)
+            openvino_inferencer.predict(image)
diff --git a/tools/benchmarking/benchmark.py b/tools/benchmarking/benchmark.py
index 9f2e7c5869..67a1ccaac5 100644
--- a/tools/benchmarking/benchmark.py
+++ b/tools/benchmarking/benchmark.py
@@ -28,7 +28,6 @@
 from anomalib.models import get_model
 from anomalib.utils.loggers import configure_logger
 from anomalib.utils.sweep import (
-    get_meta_data,
     get_openvino_throughput,
     get_run_config,
     get_sweep_callbacks,
@@ -108,9 +107,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi
     # get testing time
     testing_time = time.time() - start_time

-    meta_data = get_meta_data(model, model_config.model.input_size)
-
-    throughput = get_torch_throughput(model_config, model, datamodule.test_dataloader().dataset, meta_data)
+    throughput = get_torch_throughput(model_config, model, datamodule.test_dataloader().dataset)

     # Get OpenVINO metrics
     openvino_throughput = float("nan")
@@ -120,7 +117,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi
         openvino_export_path.mkdir(parents=True, exist_ok=True)
         convert_to_openvino(model, openvino_export_path, model_config.model.input_size)
         openvino_throughput = get_openvino_throughput(
-            model_config, openvino_export_path, datamodule.test_dataloader().dataset, meta_data
+            model_config, openvino_export_path, datamodule.test_dataloader().dataset
         )

     # arrange the data
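
Note: the sketch below is illustrative only and is not part of the change set. It mimics the pattern introduced by this diff: a normalization callback attaches a torchmetrics.Metric to the module as `normalization_metrics`, and metadata export simply iterates over that metric's state_dict(). The ToyMinMax and ToyModel classes and the collect_metadata helper are hypothetical stand-ins, not anomalib's own MinMax/AnomalyScoreDistribution or AnomalyModule.

# Minimal sketch, assuming only torch and torchmetrics are installed.
from typing import Dict, Union

import torch
from torch import Tensor, nn
from torchmetrics import Metric


class ToyMinMax(Metric):
    """Stand-in for a min-max normalization metric: tracks observed score range."""

    def __init__(self) -> None:
        super().__init__()
        # persistent=True so the states show up in state_dict() and survive export.
        self.add_state("min", default=torch.tensor(float("inf")), persistent=True)
        self.add_state("max", default=torch.tensor(float("-inf")), persistent=True)

    def update(self, scores: Tensor) -> None:
        # Fold the new batch of anomaly scores into the running min/max.
        self.min = torch.minimum(self.min, scores.min())
        self.max = torch.maximum(self.max, scores.max())

    def compute(self) -> Tensor:
        return torch.stack([self.min, self.max])


class ToyModel(nn.Module):
    """Stand-in for an AnomalyModule exposing a `normalization_metrics` attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.normalization_metrics: Metric = ToyMinMax().cpu()


def collect_metadata(model: ToyModel) -> Dict[str, Union[float, Tensor]]:
    """Mirror the pattern of the updated get_model_metadata(): read the metric's state_dict()."""
    meta_data: Dict[str, Union[float, Tensor]] = {}
    if hasattr(model, "normalization_metrics") and model.normalization_metrics.state_dict() is not None:
        for key, value in model.normalization_metrics.state_dict().items():
            meta_data[key] = value.cpu()
    return meta_data


if __name__ == "__main__":
    model = ToyModel()
    model.normalization_metrics.update(torch.tensor([0.1, 0.7, 0.3]))
    print(collect_metadata(model))  # e.g. {'min': tensor(0.1000), 'max': tensor(0.7000)}

The design point this illustrates: because each normalization callback now owns and attaches its metric in setup(), downstream consumers (export, benchmarking, inference) no longer need model-specific keys such as `min`/`max` or `pixel_mean`; they just serialize whatever states the attached metric exposes.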