
Add binning capability to AUPRO #1145

Merged

Commits (25 total; changes shown from 22 commits):
41ea64a
Add the capability to compute binned AUPRO.
Yann-CV Jun 20, 2023
739ae0e
fix linting
Yann-CV Jun 20, 2023
4d85d32
use directly binary_roc
Yann-CV Jun 28, 2023
b2c547c
update CHANGELOG.md
Yann-CV Jun 28, 2023
8fa16bc
Merge branch 'main' into aupro/binning-from-thresholds
Yann-CV Jun 29, 2023
e6fe63e
improve test by doing 2 different ones (aupro and binned aupro) + ren…
Yann-CV Jun 29, 2023
6781574
Merge remote-tracking branch 'origin/aupro/binning-from-thresholds' i…
Yann-CV Jun 29, 2023
9ee9c55
Merge branch 'main' into aupro/binning-from-thresholds
Yann-CV Jun 30, 2023
ec8043f
Merge branch 'main' into aupro/binning-from-thresholds
Yann-CV Jul 5, 2023
18a4abb
only allow num_thresholds as input + fix tests + add threshold comput…
Yann-CV Jul 5, 2023
a382be6
add binning tests
Yann-CV Jul 5, 2023
63a2aca
use binary roc directly
Yann-CV Jul 6, 2023
a0d6fbd
remove unused import and rename some
Yann-CV Jul 6, 2023
553380a
device for thresholds
Yann-CV Jul 6, 2023
4d4c6c3
Merge branch 'main' into aupro/binning-from-thresholds
samet-akcay Jul 7, 2023
f4fd62a
fix linting
Yann-CV Jul 7, 2023
650e4e8
use torch.all
Yann-CV Jul 7, 2023
b82bedf
fix linting
Yann-CV Jul 7, 2023
c08761d
Merge branch 'main' into aupro/binning-from-thresholds
Yann-CV Jul 7, 2023
893579d
fix linting
Yann-CV Jul 10, 2023
1a87971
Merge remote-tracking branch 'origin/aupro/binning-from-thresholds' i…
Yann-CV Jul 10, 2023
6feeb0b
remove observe time
Yann-CV Jul 10, 2023
fb68d7b
Merge branch 'main' into aupro/binning-from-thresholds
Yann-CV Jul 24, 2023
16a8732
Merge branch 'main' into aupro/binning-from-thresholds
samet-akcay Aug 16, 2023
0a0e31d
Merge branch 'main' into aupro/binning-from-thresholds
Yann-CV Aug 21, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- AUPRO binning capability by @yann-cv
+
 ### Changed
 
 - Add default values to algorithms, based on the original papers.
41 changes: 38 additions & 3 deletions src/anomalib/utils/metrics/aupro.py
@@ -11,14 +11,16 @@
 from matplotlib.figure import Figure
 from torch import Tensor
 from torchmetrics import Metric
-from torchmetrics.functional import auc, roc
+from torchmetrics.functional import auc
+from torchmetrics.functional.classification import binary_roc
 from torchmetrics.utilities.data import dim_zero_cat
 
 from anomalib.utils.metrics.pro import (
     connected_components_cpu,
     connected_components_gpu,
 )
 
+from .binning import thresholds_between_0_and_1, thresholds_between_min_and_max
 from .plotting_utils import plot_figure


@@ -30,6 +32,13 @@ class AUPRO(Metric):
     full_state_update: bool = False
     preds: list[Tensor]
     target: list[Tensor]
+    # When not None, the computation is performed in constant memory by computing the roc curve
+    # with a fixed number of threshold buckets.
+    # Warning: the thresholds are evenly distributed between the min and max predictions
+    # if all predictions are inside [0, 1]; otherwise, they are evenly distributed between 0 and 1.
+    # This warning can be removed when https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
+    # and the roc curve is computed with deactivated formatting.
+    num_thresholds: int | None
 
     def __init__(
         self,
@@ -38,6 +47,7 @@ def __init__(
         process_group: Any | None = None,
         dist_sync_fn: Callable | None = None,
         fpr_limit: float = 0.3,
+        num_thresholds: int | None = None,
     ) -> None:
         super().__init__(
             compute_on_step=compute_on_step,
@@ -49,6 +59,7 @@ def __init__(
         self.add_state("preds", default=[], dist_reduce_fx="cat")  # pylint: disable=not-callable
         self.add_state("target", default=[], dist_reduce_fx="cat")  # pylint: disable=not-callable
         self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
+        self.num_thresholds = num_thresholds
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         """Update state with new values.
@@ -96,9 +107,29 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]:
         Returns:
             tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
         """
+        if self.num_thresholds is not None:
+            # binary_roc applies a sigmoid to the predictions before computing the roc curve;
+            # when some predictions are outside [0, 1], the binning between min and max
+            # predictions cannot be applied. This can be removed when
+            # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and
+            # the roc curve is computed with deactivated formatting.
+
+            if torch.all((0 <= preds) * (preds <= 1)):
+                thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device)
+            else:
+                thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device)
+
+        else:
+            thresholds = None
+
         # compute the global fpr-size
-        fpr: Tensor = roc(preds, target)[0]  # only need fpr
+        fpr: Tensor = binary_roc(
+            preds=preds,
+            target=target,
+            thresholds=thresholds,
+        )[
+            0
+        ]  # only need fpr
djdameln (Contributor) commented on lines +126 to +132:
Can you use the roc function here instead? The reason is that binary_roc maps the predictions to the [0, 1] range using sigmoid, which is exactly what we're trying to avoid. We asked the TorchMetrics developers to make the sigmoid mapping optional, but until then it would be better to use the legacy roc function, which does not remap the predictions.
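
A minimal sketch of that remapping (example tensors are made up; assumes a torchmetrics version >= 0.10, where binary_roc exists):

import torch
from torchmetrics.functional.classification import binary_roc

# Scores outside [0, 1] are passed through a sigmoid internally, so with an
# integer `thresholds` argument the bins are evenly spaced in post-sigmoid [0, 1].
preds = torch.tensor([-2.0, 0.5, 3.0])
target = torch.tensor([0, 0, 1])
fpr, tpr, thresholds = binary_roc(preds, target, thresholds=5)
# thresholds -> tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])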

Contributor:

@djdameln in fact roc() will just call binary_roc() (the former is just a wrapper dispatching on the `task: str` argument):

https://github.com/Lightning-AI/torchmetrics/blob/2a055f5594a624685e26ba64bf20ab0d12225c86/src/torchmetrics/functional/classification/roc.py#L595

    if task is not None:
        if task == "binary":
            return binary_roc(preds, target, thresholds, ignore_index, validate_args)
        if task == "multiclass":
            assert isinstance(num_classes, int)
            return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args)
        if task == "multilabel":
            assert isinstance(num_labels, int)
            return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args)
        raise ValueError(
            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
        )
    else:
        rank_zero_warn(
            "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification"
            " metric. Moving forward we recommend using these versions. This base metric will still work as it did"
            " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required"
            " and the general order of arguments may change, such that this metric will just function as an single"
            " entrypoint to calling the three specialized versions.",
            DeprecationWarning,
        )
    preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label)
    return _roc_compute(preds, target, num_classes, pos_label, sample_weights)
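
A quick check of that dispatch (a sketch; the tensor values are made up):

import torch
from torchmetrics.functional import roc
from torchmetrics.functional.classification import binary_roc

preds = torch.tensor([0.2, 0.8, 0.6])
target = torch.tensor([0, 1, 1])
# The wrapper with task="binary" and the direct call return the same (fpr, tpr, thresholds).
wrapped = roc(preds, target, task="binary")
direct = binary_roc(preds, target)
assert all(torch.equal(w, d) for w, d in zip(wrapped, direct))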

Yann-CV (Contributor, Author):

Indeed, the roc call with thresholds:

  • requires the task to be defined; otherwise the thresholds are silently ignored.
  • automatically calls binary_roc at the end.

This roc function will soon be deprecated, so I would suggest keeping it like this.

Also @djdameln, we added a comment just above to mention this issue.

         output_size = torch.where(fpr <= self.fpr_limit)[0].size(0)
 
         # compute the PRO curve by aggregating per-region tpr/fpr curves/values.
@@ -120,7 +151,11 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]:
             mask = cca == label
             # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other
             # label in labels as FPs. We also don't need to return the thresholds
-            _fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1]
+            _fpr, _tpr = binary_roc(
+                preds=preds[background | mask],
+                target=mask[background | mask],
+                thresholds=thresholds,
+            )[:-1]
 
             # catch edge-case where ROC only has fpr vals > self.fpr_limit
             if _fpr[_fpr <= self.fpr_limit].max() == 0:
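
For context, a minimal usage sketch of the binned mode added in this diff (shapes and values are illustrative, and the import path assumes AUPRO is re-exported from anomalib.utils.metrics):

import torch
from anomalib.utils.metrics import AUPRO

# num_thresholds=None (the default) keeps the exact ROC computation;
# an integer switches to the constant-memory binned variant described above.
aupro = AUPRO(fpr_limit=0.3, num_thresholds=200)

preds = torch.rand(2, 1, 64, 64)                 # anomaly maps in [0, 1]
target = (torch.rand(2, 1, 64, 64) > 0.8).int()  # binary ground-truth masks
aupro.update(preds, target)
score = aupro.compute()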
14 changes: 14 additions & 0 deletions src/anomalib/utils/metrics/binning.py
@@ -0,0 +1,14 @@
+from typing import Optional
+
+from torch import Tensor, linspace
+from torch import device as torch_device
+
+
+def thresholds_between_min_and_max(
+    preds: Tensor, num_thresholds: int = 100, device: Optional[torch_device] = None
+) -> Tensor:
+    return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device)
+
+
+def thresholds_between_0_and_1(num_thresholds: int = 100, device: Optional[torch_device] = None) -> Tensor:
+    return linspace(start=0, end=1, steps=num_thresholds, device=device)
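
A short usage sketch of these helpers (values chosen for illustration):

import torch
from anomalib.utils.metrics.binning import thresholds_between_0_and_1, thresholds_between_min_and_max

preds = torch.tensor([0.1, 0.4, 0.9])
thresholds_between_min_and_max(preds, num_thresholds=5)
# -> tensor([0.1000, 0.3000, 0.5000, 0.7000, 0.9000])
thresholds_between_0_and_1(num_thresholds=5)
# -> tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])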
102 changes: 67 additions & 35 deletions tests/pre_merge/utils/metrics/test_aupro.py
@@ -8,53 +8,85 @@
 
 
 def pytest_generate_tests(metafunc):
-    if metafunc.function is test_pro:
-        labels = [
-            torch.tensor(
-                [
-                    [
-                        [
-                            [0, 0, 0, 1, 0, 0, 0],
-                        ]
-                        * 400,
-                    ]
-                ]
-            ),
-            torch.tensor(
-                [
-                    [
-                        [
-                            [0, 1, 0, 1, 0, 1, 0],
-                        ]
-                        * 400,
-                    ]
-                ]
-            ),
-        ]
-        preds = torch.arange(2800) / 2800.0
-        preds = preds.view(1, 1, 400, 7)
-
-        preds = [preds, preds]
-
-        fpr_limit = [1 / 3, 1 / 3]
-        aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)]
-
-        # Also test that per-region aupros are averaged
-        labels.append(torch.cat(labels))
-        preds.append(torch.cat(preds))
-        fpr_limit.append(float(np.mean(fpr_limit)))
-        aupro.append(torch.tensor(np.mean(aupro)))
-
-        vals = list(zip(labels, preds, fpr_limit, aupro))
-        metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "aupro"), argvalues=vals)
+    labels = [
+        torch.tensor(
+            [
+                [
+                    [
+                        [0, 0, 0, 1, 0, 0, 0],
+                    ]
+                    * 400,
+                ]
+            ]
+        ),
+        torch.tensor(
+            [
+                [
+                    [
+                        [0, 1, 0, 1, 0, 1, 0],
+                    ]
+                    * 400,
+                ]
+            ]
+        ),
+    ]
+    preds = torch.arange(2800) / 2800.0
+    preds = preds.view(1, 1, 400, 7)
+
+    preds = [preds, preds]
+
+    fpr_limit = [1 / 3, 1 / 3]
+    expected_aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)]
+
+    # Also test that per-region aupros are averaged
+    labels.append(torch.cat(labels))
+    preds.append(torch.cat(preds))
+    fpr_limit.append(float(np.mean(fpr_limit)))
+    expected_aupro.append(torch.tensor(np.mean(expected_aupro)))
+
+    threshold_count = [
+        200,
+        200,
+        200,
+    ]
+
+    if metafunc.function is test_aupro:
+        vals = list(zip(labels, preds, fpr_limit, expected_aupro))
+        metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "expected_aupro"), argvalues=vals)
+    elif metafunc.function is test_binned_aupro:
+        vals = list(zip(labels, preds, threshold_count))
+        metafunc.parametrize(argnames=("labels", "preds", "threshold_count"), argvalues=vals)
 
 
-def test_pro(labels, preds, fpr_limit, aupro):
-    pro = AUPRO(fpr_limit=fpr_limit)
-    pro.update(preds, labels)
-    computed_aupro = pro.compute()
+def test_aupro(labels, preds, fpr_limit, expected_aupro):
+    aupro = AUPRO(fpr_limit=fpr_limit)
+    aupro.update(preds, labels)
+    computed_aupro = aupro.compute()
 
     tmp_labels = [label.squeeze().numpy() for label in labels]
     tmp_preds = [pred.squeeze().numpy() for pred in preds]
-    ref_pro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float)
+    ref_aupro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float)
 
     TOL = 0.001
-    assert torch.allclose(computed_aupro, aupro, atol=TOL)
-    assert torch.allclose(computed_aupro, ref_pro, atol=TOL)
-    assert torch.allclose(aupro, ref_pro, atol=TOL)
+    assert torch.allclose(computed_aupro, expected_aupro, atol=TOL)
+    assert torch.allclose(computed_aupro, ref_aupro, atol=TOL)
+
+
+def test_binned_aupro(labels, preds, threshold_count):
+    aupro = AUPRO()
+    computed_not_binned_aupro = aupro(preds, labels)
+
+    binned_pro = AUPRO(num_thresholds=threshold_count)
+    computed_binned_aupro = binned_pro(preds, labels)
+
+    TOL = 0.001
+    # with threshold binning the roc curve computed within the metric is more memory efficient
+    # but a bit less accurate. So we check the difference in order to validate the binning effect.
+    assert computed_binned_aupro != computed_not_binned_aupro
+    assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL)
+
+    # test with prediction higher than 1
+    preds = preds * 2
+    computed_binned_aupro = binned_pro(preds, labels)
+    computed_not_binned_aupro = aupro(preds, labels)
+
+    # with threshold binning the roc curve computed within the metric is more memory efficient
+    # but a bit less accurate. So we check the difference in order to validate the binning effect.
+    assert computed_binned_aupro != computed_not_binned_aupro
+    assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL)
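
The memory/accuracy trade-off these test comments describe comes from evaluating confusion counts only at the fixed thresholds rather than at every unique score. A rough standalone sketch of the idea (not anomalib's implementation):

import torch

def binned_roc_points(preds: torch.Tensor, target: torch.Tensor, thresholds: torch.Tensor):
    # One (fpr, tpr) point per fixed threshold: memory is O(num_thresholds),
    # independent of the number of pixels, at the cost of quantizing the curve.
    preds, target = preds.flatten(), target.flatten()
    tpr = torch.stack([(preds >= t)[target == 1].float().mean() for t in thresholds])
    fpr = torch.stack([(preds >= t)[target == 0].float().mean() for t in thresholds])
    return fpr, tpr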
13 changes: 13 additions & 0 deletions tests/pre_merge/utils/metrics/test_binning.py
@@ -0,0 +1,13 @@
+from torch import Tensor, all as torch_all
+
+from anomalib.utils.metrics.binning import thresholds_between_min_and_max, thresholds_between_0_and_1
+
+
+def test_thresholds_between_min_and_max():
+    preds = Tensor([1, 10])
+    assert torch_all(thresholds_between_min_and_max(preds, 2) == preds)
+
+
+def test_thresholds_between_0_and_1():
+    expected = Tensor([0, 1])
+    assert torch_all(thresholds_between_0_and_1(2) == expected)