add option to load metrics with kwargs #688

6 changes: 2 additions & 4 deletions anomalib/utils/callbacks/__init__.py
@@ -84,12 +84,10 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
callbacks.append(post_processing_callback)

# Add metric configuration to the model via MetricsConfigurationCallback
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
metrics_callback = MetricsConfigurationCallback(
config.dataset.task,
image_metric_names,
pixel_metric_names,
config.metrics.get("image", None),
config.metrics.get("pixel", None),
)
callbacks.append(metrics_callback)

6 changes: 3 additions & 3 deletions anomalib/utils/callbacks/metrics_configuration.py
@@ -12,7 +12,7 @@
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

from anomalib.models.components.base.anomaly_module import AnomalyModule
from anomalib.utils.metrics import metric_collection_from_names
from anomalib.utils.metrics import create_metric_collection

logger = logging.getLogger(__name__)

@@ -74,8 +74,8 @@ def setup(
pixel_metric_names = self.pixel_metric_names

if isinstance(pl_module, AnomalyModule):
pl_module.image_metrics = metric_collection_from_names(image_metric_names, "image_")
pl_module.pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")

pl_module.image_metrics.set_threshold(pl_module.image_threshold.value)
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value)
141 changes: 123 additions & 18 deletions anomalib/utils/metrics/__init__.py
@@ -5,7 +5,7 @@

import importlib
import warnings
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union

import torchmetrics
from omegaconf import DictConfig, ListConfig
@@ -23,23 +23,6 @@
__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AnomalyScoreThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
"""Create metric collections based on the config.

Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf

Returns:
AnomalibMetricCollection: Image-level metric collection
AnomalibMetricCollection: Pixel-level metric collection
"""
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
image_metrics = metric_collection_from_names(image_metric_names, "image_")
pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
return image_metrics, pixel_metrics


def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection:
"""Create a metric collection from a list of metric names.

@@ -68,3 +51,125 @@ def metric_collection_from_names(metric_names: List[str], prefix: Optional[str])
else:
warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
return metrics


def _validate_metrics_dict(metrics: Dict[str, Dict[str, Any]]) -> None:
"""Check the assumptions about metrics config dict.

- Keys are metric names
- Values are dictionaries.
- Internal dictionaries:
- have key "class_path" and its value is of type str
- have key init_args" and its value is of type dict).

"""
assert all(
isinstance(metric, str) for metric in metrics.keys()
), f"All keys (metric names) must be strings, found {sorted(metrics.keys())}"
assert all(
isinstance(metric, (dict, DictConfig)) for metric in metrics.values()
), f"All values must be dictionaries, found {list(metrics.values())}"
assert all("class_path" in metric and isinstance(metric["class_path"], str) for metric in metrics.values()), (
"All internal dictionaries must have a 'class_path' key whose value is of type str, "
f"found {list(metrics.values())}"
)
assert all(
"init_args" in metric and isinstance(metric["init_args"], (dict, DictConfig)) for metric in metrics.values()
), (
"All internal dictionaries must have a 'init_args' key whose value is of type dict, "
f"found {list(metrics.values())}"
)


def _get_class_from_path(class_path: str) -> Any:
"""Get a class from a module assuming the string format is `package.subpackage.module.ClassName`."""
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
assert hasattr(module, class_name), f"Class {class_name} not found in module {module_name}"
cls = getattr(module, class_name)
return cls


def metric_collection_from_dicts(metrics: Dict[str, Dict[str, Any]], prefix: Optional[str]) -> AnomalibMetricCollection:
"""Create a metric collection from a dict of "metric name" -> "metric specifications".

Example:

metrics = {
"PixelWiseF1Score": {
"class_path": "torchmetrics.F1Score",
"init_args": {},
},
"PixelWiseAUROC": {
"class_path": "anomalib.utils.metrics.AUROC",
"init_args": {
"compute_on_cpu": True,
},
},
}

In the config file, the same specifications (for pixel-wise metrics) look like:

```yaml
metrics:
pixel:
PixelWiseF1Score:
class_path: torchmetrics.F1Score
init_args: {}
PixelWiseAUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
```

Args:
metrics (Dict[str, Dict[str, Any]]): keys are metric names, values are dictionaries.
Internal Dict[str, Any] keys are "class_path" (value is string) and "init_args" (value is dict),
following the convention in Pytorch Lightning CLI.

prefix (Optional[str]): prefix to assign to the metrics in the collection.

Returns:
AnomalibMetricCollection: Collection of metrics.
"""
_validate_metrics_dict(metrics)
metrics_collection = {}
for name, dict_ in metrics.items():
class_path = dict_["class_path"]
kwargs = dict_["init_args"]
cls = _get_class_from_path(class_path)
metrics_collection[name] = cls(**kwargs)
return AnomalibMetricCollection(metrics_collection, prefix=prefix)


def create_metric_collection(
metrics: Union[List[str], Dict[str, Dict[str, Any]]], prefix: Optional[str]
) -> AnomalibMetricCollection:
"""Create a metric collection from a list of metric names or dictionaries.

This function will dispatch the actual creation to the appropriate function depending on the input type:

- if List[str] (names of metrics): see `metric_collection_from_names`
- if Dict[str, Dict[str, Any]] (path and init args of a class): see `metric_collection_from_dicts`

When given a list of names, each metric is first looked up among the metrics defined in the Anomalib
metrics module, then in the TorchMetrics package.

Args:
metrics (Union[List[str], Dict[str, Dict[str, Any]]]): list of metric names, or mapping of metric names to their class path and init args.
prefix (Optional[str]): prefix to assign to the metrics in the collection.

Returns:
AnomalibMetricCollection: Collection of metrics.
"""
# dispatch on the input type

if isinstance(metrics, (ListConfig, list)):
assert all(isinstance(metric, str) for metric in metrics), f"All metrics must be strings, found {metrics}"
return metric_collection_from_names(metrics, prefix)

if isinstance(metrics, (DictConfig, dict)):
_validate_metrics_dict(metrics)
return metric_collection_from_dicts(metrics, prefix)

raise ValueError(f"metrics must be a list or a dict, found {type(metrics)}")
40 changes: 40 additions & 0 deletions docs/source/reference_guide/api/metrics.rst
@@ -1,6 +1,46 @@
Metrics
=======

There are two ways of configuring metrics in the config file:

1. a list of metric names, or
2. a mapping of metric names to class path and init args.

Each subsection of the ``metrics`` section in the config file may use either style, but a single subsection must use one style consistently.

.. code-block:: yaml
:caption: Example of metrics configuration section in the config file.

metrics:
# imagewise metrics using the list of metric names style
image:
- F1Score
- AUROC
# pixelwise metrics using the mapping style
pixel:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true

List of metric names
--------------------

A list of strings, each matching the name of a class in ``anomalib.utils.metrics`` or ``torchmetrics`` (searched in that order of priority). The metrics are instantiated with their default arguments.

Mapping of metric names to class path and init args
---------------------------------------------------

A mapping of metric names (str) to a dictionary with two keys: ``class_path`` and ``init_args``.

``class_path`` is a string with the full path to a metric class, from the root package down to the class name (e.g. ``anomalib.utils.metrics.AUROC``).

``init_args`` is a dictionary of arguments passed to the class constructor.

.. automodule:: anomalib.utils.metrics
:members:
:undoc-members:
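To connect the documentation to the code changes above: after this PR, `get_callbacks` simply forwards whatever is under `metrics.image` / `metrics.pixel` to the callback, which then calls `create_metric_collection`. A rough sketch of that flow (the config path is hypothetical):

```python
from omegaconf import OmegaConf

from anomalib.utils.callbacks.metrics_configuration import MetricsConfigurationCallback

# Hypothetical path; any config with `dataset.task` and a `metrics` section in either style works.
config = OmegaConf.load("path/to/config.yaml")

# Mirrors the updated get_callbacks(): a missing `image` or `pixel` subsection simply becomes None.
metrics_callback = MetricsConfigurationCallback(
    config.dataset.task,
    config.metrics.get("image", None),
    config.metrics.get("pixel", None),
)
```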
45 changes: 45 additions & 0 deletions tests/helpers/dummy.py
@@ -0,0 +1,45 @@
import shutil
import tempfile
from pathlib import Path

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from anomalib.utils.loggers.tensorboard import AnomalibTensorBoardLogger


class DummyDataset(Dataset):
def __len__(self):
return 1

def __getitem__(self, idx):
return torch.ones(1)


class DummyDataModule(pl.LightningDataModule):
def train_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())

def val_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())

def test_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())


class DummyModel(nn.Module):
def __init__(self):
# pylint: disable=useless-parent-delegation
super().__init__()


class DummyLogger(AnomalibTensorBoardLogger):
def __init__(self):
self.tempdir = Path(tempfile.mkdtemp())
super().__init__(name="tensorboard_logs", save_dir=self.tempdir)

def __del__(self):
if self.tempdir.exists():
shutil.rmtree(self.tempdir)
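These helpers are presumably meant to keep metric-related tests lightweight. As a sketch of how they could be wired together (the LightningModule below is hypothetical and exists only to make the example self-contained):

```python
import pytorch_lightning as pl
import torch

from tests.helpers.dummy import DummyDataModule, DummyLogger


class _TinyModule(pl.LightningModule):
    """Hypothetical minimal LightningModule used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        # DummyDataset yields torch.ones(1), so the batch has shape (1, 1).
        return self.layer(batch).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def test_dummy_smoke():
    # fast_dev_run keeps this to a single batch; DummyLogger removes its tempdir on deletion.
    trainer = pl.Trainer(logger=DummyLogger(), fast_dev_run=True)
    trainer.fit(_TinyModule(), datamodule=DummyDataModule())
```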
27 changes: 27 additions & 0 deletions tests/helpers/metrics.py
@@ -0,0 +1,27 @@
"""Helpers for metrics tests."""

from typing import Tuple, Union

from omegaconf import DictConfig, ListConfig

from anomalib.utils.metrics import (
AnomalibMetricCollection,
metric_collection_from_names,
)


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
"""Create metric collections based on the config.

Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf

Returns:
AnomalibMetricCollection: Image-level metric collection
AnomalibMetricCollection: Pixel-level metric collection
"""
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
image_metrics = metric_collection_from_names(image_metric_names, "image_")
pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
return image_metrics, pixel_metrics
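The old `get_metrics` helper now lives here for test use only. A quick sketch of calling it with an in-line OmegaConf config (metric names chosen arbitrarily):

```python
from omegaconf import OmegaConf

from tests.helpers.metrics import get_metrics

config = OmegaConf.create({"metrics": {"image": ["F1Score", "AUROC"], "pixel": ["AUROC"]}})
image_metrics, pixel_metrics = get_metrics(config)  # two AnomalibMetricCollection objects
```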
@@ -0,0 +1,13 @@
metrics:
pixel:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
image:
- F1Score
- AUROC
@@ -0,0 +1,13 @@
metrics:
pixel:
- F1Score
- AUROC
image:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
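The two YAML fixtures above (their file names are trimmed from this view) mirror the documentation example: one uses dict-style pixel metrics with list-style image metrics, the other the reverse. A test consuming them would presumably look something like the following; the paths and test name are placeholders, not part of the PR:

```python
import pytest
from omegaconf import OmegaConf

from anomalib.utils.metrics import create_metric_collection


@pytest.mark.parametrize("config_path", ["<fixture-with-dict-pixel>.yaml", "<fixture-with-dict-image>.yaml"])
def test_metrics_config_styles(config_path):
    config = OmegaConf.load(config_path)
    # Both styles go through the same dispatcher.
    image_metrics = create_metric_collection(config.metrics.image, "image_")
    pixel_metrics = create_metric_collection(config.metrics.pixel, "pixel_")
```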