From 41ea64afd4193c996e154cb05083f6356ba1965b Mon Sep 17 00:00:00 2001 From: yann-cv Date: Tue, 20 Jun 2023 14:22:48 +0200 Subject: [PATCH 01/15] Add the capability to compute binned AUPRO. --- src/anomalib/utils/metrics/aupro.py | 19 +++++++++++++++++-- tests/pre_merge/utils/metrics/test_aupro.py | 20 +++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index ba8608fd3d..5decefc9dc 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -5,6 +5,7 @@ from __future__ import annotations +from functools import partial from typing import Any, Callable import torch @@ -30,6 +31,11 @@ class AUPRO(Metric): full_state_update: bool = False preds: list[Tensor] target: list[Tensor] + # The threshold used to compute the binned version of the AUPRO which is less accurate + # but more memory efficient. Warning: Contrary to AUROC or AUPR, here the predictions + # are not scaled between 0 and 1 by self calling the sigmoid function if some predictions + # are greater than 1. + thresholds: Tensor | None def __init__( self, @@ -38,6 +44,7 @@ def __init__( process_group: Any | None = None, dist_sync_fn: Callable | None = None, fpr_limit: float = 0.3, + thresholds: int | list[float] | Tensor | None = None, ) -> None: super().__init__( compute_on_step=compute_on_step, @@ -50,6 +57,11 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable self.register_buffer("fpr_limit", torch.tensor(fpr_limit)) + if thresholds is None: + self.thresholds = thresholds + else: + self.register_buffer("thresholds", thresholds) + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with new values. @@ -96,9 +108,12 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso Returns: tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. """ + # initialize the roc curve with the specified thresholds + # target being forced to be 0 or 1, we can use the binary roc curve + roc_in_pro = partial(roc, task="binary", thresholds=self.thresholds) # compute the global fpr-size - fpr: Tensor = roc(preds, target)[0] # only need fpr + fpr: Tensor = roc_in_pro(preds, target)[0] # only need fpr output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) # compute the PRO curve by aggregating per-region tpr/fpr curves/values. @@ -120,7 +135,7 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso mask = cca == label # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other # label in labels as FPs. We also don't need to return the thresholds - _fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1] + _fpr, _tpr = roc_in_pro(preds[background | mask], mask[background | mask])[:-1] # catch edge-case where ROC only has fpr vals > self.fpr_limit if _fpr[_fpr <= self.fpr_limit].max() == 0: diff --git a/tests/pre_merge/utils/metrics/test_aupro.py b/tests/pre_merge/utils/metrics/test_aupro.py index 88466d0566..75692fa200 100644 --- a/tests/pre_merge/utils/metrics/test_aupro.py +++ b/tests/pre_merge/utils/metrics/test_aupro.py @@ -41,11 +41,16 @@ def pytest_generate_tests(metafunc): fpr_limit.append(float(np.mean(fpr_limit))) aupro.append(torch.tensor(np.mean(aupro))) - vals = list(zip(labels, preds, fpr_limit, aupro)) - metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "aupro"), argvalues=vals) + thresholds = [ + torch.linspace(0, 1, steps=50), + torch.linspace(0, 1, steps=50), + ] + vals = list(zip(labels, preds, thresholds, fpr_limit, aupro)) + + metafunc.parametrize(argnames=("labels", "preds", "thresholds", "fpr_limit", "aupro"), argvalues=vals) -def test_pro(labels, preds, fpr_limit, aupro): +def test_pro(labels, preds, thresholds, fpr_limit, aupro): pro = AUPRO(fpr_limit=fpr_limit) pro.update(preds, labels) computed_aupro = pro.compute() @@ -58,3 +63,12 @@ def test_pro(labels, preds, fpr_limit, aupro): assert torch.allclose(computed_aupro, aupro, atol=TOL) assert torch.allclose(computed_aupro, ref_pro, atol=TOL) assert torch.allclose(aupro, ref_pro, atol=TOL) + + binned_pro = AUPRO(fpr_limit=fpr_limit, thresholds=thresholds) + binned_pro.update(preds, labels) + computed_binned_aupro = binned_pro.compute() + + assert computed_binned_aupro != computed_aupro + assert torch.allclose(computed_aupro, aupro, atol=TOL) + + From 739ae0edbdb4b2f9b95d36bf96c09ad6d988f51e Mon Sep 17 00:00:00 2001 From: yann-cv Date: Tue, 20 Jun 2023 14:24:44 +0200 Subject: [PATCH 02/15] fix linting --- tests/pre_merge/utils/metrics/test_aupro.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/pre_merge/utils/metrics/test_aupro.py b/tests/pre_merge/utils/metrics/test_aupro.py index 75692fa200..8e75e4e7ed 100644 --- a/tests/pre_merge/utils/metrics/test_aupro.py +++ b/tests/pre_merge/utils/metrics/test_aupro.py @@ -70,5 +70,3 @@ def test_pro(labels, preds, thresholds, fpr_limit, aupro): assert computed_binned_aupro != computed_aupro assert torch.allclose(computed_aupro, aupro, atol=TOL) - - From 4d85d329e986441313d008213a629b39b81ac793 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Wed, 28 Jun 2023 17:23:02 +0200 Subject: [PATCH 03/15] use directly binary_roc --- src/anomalib/utils/metrics/aupro.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 5decefc9dc..345708e004 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -12,7 +12,8 @@ from matplotlib.figure import Figure from torch import Tensor from torchmetrics import Metric -from torchmetrics.functional import auc, roc +from torchmetrics.functional import auc +from torchmetrics.functional.classification import binary_roc from torchmetrics.utilities.data import dim_zero_cat from anomalib.utils.metrics.pro import ( @@ -110,7 +111,7 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso """ # initialize the roc curve with the specified thresholds # target being forced to be 0 or 1, we can use the binary roc curve - roc_in_pro = partial(roc, task="binary", thresholds=self.thresholds) + roc_in_pro = partial(binary_roc, thresholds=self.thresholds) # compute the global fpr-size fpr: Tensor = roc_in_pro(preds, target)[0] # only need fpr From b2c547c7736d2a1f0d7e0a0e97820a9b69aa426f Mon Sep 17 00:00:00 2001 From: yann-cv Date: Wed, 28 Jun 2023 17:23:20 +0200 Subject: [PATCH 04/15] update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92e765fe22..c21296d831 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- AUPRO binning capability by @yann-cv + ### Changed ### Deprecated From e6fe63e955941b52a3685a74c3a46d2ad8254782 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Thu, 29 Jun 2023 09:49:34 +0200 Subject: [PATCH 05/15] improve test by doing 2 different ones (aupro and binned aupro) + renamed few variables in the tests --- tests/pre_merge/utils/metrics/test_aupro.py | 99 ++++++++++++--------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/tests/pre_merge/utils/metrics/test_aupro.py b/tests/pre_merge/utils/metrics/test_aupro.py index 8e75e4e7ed..1c7f41d327 100644 --- a/tests/pre_merge/utils/metrics/test_aupro.py +++ b/tests/pre_merge/utils/metrics/test_aupro.py @@ -8,65 +8,76 @@ def pytest_generate_tests(metafunc): - if metafunc.function is test_pro: - labels = [ - torch.tensor( + labels = [ + torch.tensor( + [ [ - [ - [0, 0, 0, 1, 0, 0, 0], - ] - * 400, + [0, 0, 0, 1, 0, 0, 0], ] - ), - torch.tensor( + * 400, + ] + ), + torch.tensor( + [ [ - [ - [0, 1, 0, 1, 0, 1, 0], - ] - * 400, + [0, 1, 0, 1, 0, 1, 0], ] - ), - ] - preds = torch.arange(2800) / 2800.0 - preds = preds.view(1, 1, 400, 7) + * 400, + ] + ), + ] + preds = torch.arange(2800) / 2800.0 + preds = preds.view(1, 1, 400, 7) - preds = [preds, preds] + preds = [preds, preds] - fpr_limit = [1 / 3, 1 / 3] - aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)] + fpr_limit = [1 / 3, 1 / 3] + expected_aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)] - # Also test that per-region aupros are averaged - labels.append(torch.cat(labels)) - preds.append(torch.cat(preds)) - fpr_limit.append(float(np.mean(fpr_limit))) - aupro.append(torch.tensor(np.mean(aupro))) + # Also test that per-region aupros are averaged + labels.append(torch.cat(labels)) + preds.append(torch.cat(preds)) + fpr_limit.append(float(np.mean(fpr_limit))) + expected_aupro.append(torch.tensor(np.mean(expected_aupro))) - thresholds = [ - torch.linspace(0, 1, steps=50), - torch.linspace(0, 1, steps=50), - ] - vals = list(zip(labels, preds, thresholds, fpr_limit, aupro)) + thresholds = [ + torch.linspace(0, 1, steps=200), + torch.linspace(0, 1, steps=200), + torch.linspace(0, 1, steps=200), + ] - metafunc.parametrize(argnames=("labels", "preds", "thresholds", "fpr_limit", "aupro"), argvalues=vals) + if metafunc.function is test_aupro: + vals = list(zip(labels, preds, fpr_limit, expected_aupro)) + metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "aupro"), argvalues=vals) + elif metafunc.function is test_binned_aupro: + vals = list(zip(labels, preds, thresholds)) + metafunc.parametrize(argnames=("labels", "preds", "thresholds"), argvalues=vals) -def test_pro(labels, preds, thresholds, fpr_limit, aupro): - pro = AUPRO(fpr_limit=fpr_limit) - pro.update(preds, labels) - computed_aupro = pro.compute() +def test_aupro(labels, preds, fpr_limit, expected_aupro): + aupro = AUPRO(fpr_limit=fpr_limit) + aupro.update(preds, labels) + computed_aupro = aupro.compute() tmp_labels = [label.squeeze().numpy() for label in labels] tmp_preds = [pred.squeeze().numpy() for pred in preds] - ref_pro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float) + ref_aupro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float) TOL = 0.001 - assert torch.allclose(computed_aupro, aupro, atol=TOL) - assert torch.allclose(computed_aupro, ref_pro, atol=TOL) - assert torch.allclose(aupro, ref_pro, atol=TOL) + assert torch.allclose(computed_aupro, expected_aupro, atol=TOL) + assert torch.allclose(computed_aupro, ref_aupro, atol=TOL) + assert torch.allclose(aupro, ref_aupro, atol=TOL) - binned_pro = AUPRO(fpr_limit=fpr_limit, thresholds=thresholds) - binned_pro.update(preds, labels) - computed_binned_aupro = binned_pro.compute() - assert computed_binned_aupro != computed_aupro - assert torch.allclose(computed_aupro, aupro, atol=TOL) +def test_binned_aupro(labels, preds, thresholds): + aupro = AUPRO() + computed_not_binned_aupro = aupro(preds, labels) + + binned_pro = AUPRO(thresholds=thresholds) + computed_binned_aupro = binned_pro(preds, labels) + + TOL = 0.001 + # with threshold binning the roc curve computed within the metric is more memory efficient + # but a bit less accurate. So we check the difference in order to validate the binning effect. + assert computed_binned_aupro != computed_not_binned_aupro + assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL) \ No newline at end of file From 18a4abbaaaacfb4a4ac6bc6d64e695e85fd92aeb Mon Sep 17 00:00:00 2001 From: yann-cv Date: Wed, 5 Jul 2023 16:49:36 +0200 Subject: [PATCH 06/15] only allow num_thresholds as input + fix tests + add threshold computing utilities --- src/anomalib/utils/metrics/aupro.py | 40 +++++++++++++-------- src/anomalib/utils/metrics/binning.py | 17 +++++++++ tests/pre_merge/utils/metrics/test_aupro.py | 31 ++++++++++------ 3 files changed, 63 insertions(+), 25 deletions(-) create mode 100644 src/anomalib/utils/metrics/binning.py diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 345708e004..80d2bd3b2d 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -20,6 +20,7 @@ connected_components_cpu, connected_components_gpu, ) +from .binning import thresholds_between_min_and_max, thresholds_between_0_and_1 from .plotting_utils import plot_figure @@ -32,11 +33,13 @@ class AUPRO(Metric): full_state_update: bool = False preds: list[Tensor] target: list[Tensor] - # The threshold used to compute the binned version of the AUPRO which is less accurate - # but more memory efficient. Warning: Contrary to AUROC or AUPR, here the predictions - # are not scaled between 0 and 1 by self calling the sigmoid function if some predictions - # are greater than 1. - thresholds: Tensor | None + # When not None, the computation is performed in constant-memory by computing the roc curve + # for fixed thresholds buckets/thresholds. + # Warning: The thresholds are evenly distributed between the min and max predictions + # if all predictions are inside [0, 1]. Otherwise, the thresholds are evenly distributed between 0 and 1. + # This warning can be removed when https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed + # and the roc curve is computed with deactivated formatting + num_thresholds: int | None def __init__( self, @@ -45,7 +48,7 @@ def __init__( process_group: Any | None = None, dist_sync_fn: Callable | None = None, fpr_limit: float = 0.3, - thresholds: int | list[float] | Tensor | None = None, + num_thresholds: int | None = None, ) -> None: super().__init__( compute_on_step=compute_on_step, @@ -57,11 +60,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable self.add_state("target", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable self.register_buffer("fpr_limit", torch.tensor(fpr_limit)) - - if thresholds is None: - self.thresholds = thresholds - else: - self.register_buffer("thresholds", thresholds) + self.num_thresholds = num_thresholds def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with new values. @@ -109,9 +108,22 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso Returns: tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. """ - # initialize the roc curve with the specified thresholds - # target being forced to be 0 or 1, we can use the binary roc curve - roc_in_pro = partial(binary_roc, thresholds=self.thresholds) + if self.num_thresholds is not None: + # binary_roc is applying a sigmoid on the predictions before computing the roc curve + # when some predictions are out of [0, 1], the binning between min and max predictions + # cannot be applied in that case. This can be removed when + # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and + # the roc curve is computed with deactivated formatting. + + if all((0 <= preds) * (preds <= 1)): + thresholds = thresholds_between_min_and_max(preds, self.num_thresholds) + else: + thresholds = thresholds_between_0_and_1(self.num_thresholds) + else: + thresholds = None + + # initialize the roc curve from the specified threshold count + roc_in_pro = partial(binary_roc, thresholds=thresholds,) # compute the global fpr-size fpr: Tensor = roc_in_pro(preds, target)[0] # only need fpr diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py new file mode 100644 index 0000000000..69017fcd28 --- /dev/null +++ b/src/anomalib/utils/metrics/binning.py @@ -0,0 +1,17 @@ +from torch import Tensor, linspace, all + + +def thresholds_between_min_and_max(preds: Tensor, num_thresholds: int = 100) -> Tensor: + return linspace( + start=preds.min(), + end=preds.max(), + steps=num_thresholds, + ) + + +def thresholds_between_0_and_1(num_thresholds: int = 100) -> Tensor: + return linspace( + start=0, + end=1, + steps=num_thresholds, + ) \ No newline at end of file diff --git a/tests/pre_merge/utils/metrics/test_aupro.py b/tests/pre_merge/utils/metrics/test_aupro.py index 1c7f41d327..5e17394926 100644 --- a/tests/pre_merge/utils/metrics/test_aupro.py +++ b/tests/pre_merge/utils/metrics/test_aupro.py @@ -40,18 +40,18 @@ def pytest_generate_tests(metafunc): fpr_limit.append(float(np.mean(fpr_limit))) expected_aupro.append(torch.tensor(np.mean(expected_aupro))) - thresholds = [ - torch.linspace(0, 1, steps=200), - torch.linspace(0, 1, steps=200), - torch.linspace(0, 1, steps=200), + threshold_count = [ + 200, + 200, + 200, ] if metafunc.function is test_aupro: vals = list(zip(labels, preds, fpr_limit, expected_aupro)) - metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "aupro"), argvalues=vals) + metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "expected_aupro"), argvalues=vals) elif metafunc.function is test_binned_aupro: - vals = list(zip(labels, preds, thresholds)) - metafunc.parametrize(argnames=("labels", "preds", "thresholds"), argvalues=vals) + vals = list(zip(labels, preds, threshold_count)) + metafunc.parametrize(argnames=("labels", "preds", "threshold_count"), argvalues=vals) def test_aupro(labels, preds, fpr_limit, expected_aupro): @@ -66,18 +66,27 @@ def test_aupro(labels, preds, fpr_limit, expected_aupro): TOL = 0.001 assert torch.allclose(computed_aupro, expected_aupro, atol=TOL) assert torch.allclose(computed_aupro, ref_aupro, atol=TOL) - assert torch.allclose(aupro, ref_aupro, atol=TOL) -def test_binned_aupro(labels, preds, thresholds): +def test_binned_aupro(labels, preds, threshold_count): aupro = AUPRO() computed_not_binned_aupro = aupro(preds, labels) - binned_pro = AUPRO(thresholds=thresholds) + binned_pro = AUPRO(num_thresholds=threshold_count) computed_binned_aupro = binned_pro(preds, labels) TOL = 0.001 # with threshold binning the roc curve computed within the metric is more memory efficient # but a bit less accurate. So we check the difference in order to validate the binning effect. assert computed_binned_aupro != computed_not_binned_aupro - assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL) \ No newline at end of file + assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL) + + # test with prediction higher than 1 + preds = preds * 2 + computed_binned_aupro = binned_pro(preds, labels) + computed_not_binned_aupro = aupro(preds, labels) + + # with threshold binning the roc curve computed within the metric is more memory efficient + # but a bit less accurate. So we check the difference in order to validate the binning effect. + assert computed_binned_aupro != computed_not_binned_aupro + assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL) From a382be659d93765eb5d241c0dff27c8f5ba500a4 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Wed, 5 Jul 2023 16:53:15 +0200 Subject: [PATCH 07/15] add binning tests --- tests/pre_merge/utils/metrics/test_binning.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/pre_merge/utils/metrics/test_binning.py diff --git a/tests/pre_merge/utils/metrics/test_binning.py b/tests/pre_merge/utils/metrics/test_binning.py new file mode 100644 index 0000000000..ab6831ef7d --- /dev/null +++ b/tests/pre_merge/utils/metrics/test_binning.py @@ -0,0 +1,14 @@ +from torch import Tensor, linspace, all + +from anomalib.utils.metrics.binning import thresholds_between_min_and_max, \ + thresholds_between_0_and_1 + + +def test_thresholds_between_min_and_max(): + preds = Tensor([1, 10]) + assert all(thresholds_between_min_and_max(preds, 2) == preds) + + +def test_thresholds_between_0_and_1(): + expected = Tensor([0, 1]) + assert all(thresholds_between_0_and_1(2) == expected) From 63a2acac688d74b8277c40681598d3f09aabecec Mon Sep 17 00:00:00 2001 From: yann-cv Date: Thu, 6 Jul 2023 11:02:15 +0200 Subject: [PATCH 08/15] use binary roc directly --- src/anomalib/utils/metrics/aupro.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 80d2bd3b2d..21ef548953 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -122,11 +122,8 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso else: thresholds = None - # initialize the roc curve from the specified threshold count - roc_in_pro = partial(binary_roc, thresholds=thresholds,) - # compute the global fpr-size - fpr: Tensor = roc_in_pro(preds, target)[0] # only need fpr + fpr: Tensor = binary_roc(preds, target, thresholds=thresholds,)[0] # only need fpr output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) # compute the PRO curve by aggregating per-region tpr/fpr curves/values. @@ -148,7 +145,7 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso mask = cca == label # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other # label in labels as FPs. We also don't need to return the thresholds - _fpr, _tpr = roc_in_pro(preds[background | mask], mask[background | mask])[:-1] + _fpr, _tpr = binary_roc(preds[background | mask], mask[background | mask], thresholds=thresholds,)[:-1] # catch edge-case where ROC only has fpr vals > self.fpr_limit if _fpr[_fpr <= self.fpr_limit].max() == 0: From a0d6fbd6103396dc49486230dba28092307c2b30 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Thu, 6 Jul 2023 11:05:35 +0200 Subject: [PATCH 09/15] remove unused import and rename some --- src/anomalib/utils/metrics/binning.py | 2 +- tests/pre_merge/utils/metrics/test_binning.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py index 69017fcd28..69e6c9a7dd 100644 --- a/src/anomalib/utils/metrics/binning.py +++ b/src/anomalib/utils/metrics/binning.py @@ -1,4 +1,4 @@ -from torch import Tensor, linspace, all +from torch import Tensor, linspace def thresholds_between_min_and_max(preds: Tensor, num_thresholds: int = 100) -> Tensor: diff --git a/tests/pre_merge/utils/metrics/test_binning.py b/tests/pre_merge/utils/metrics/test_binning.py index ab6831ef7d..c49d1a4236 100644 --- a/tests/pre_merge/utils/metrics/test_binning.py +++ b/tests/pre_merge/utils/metrics/test_binning.py @@ -1,14 +1,16 @@ -from torch import Tensor, linspace, all +from torch import Tensor, all as torch_all -from anomalib.utils.metrics.binning import thresholds_between_min_and_max, \ +from anomalib.utils.metrics.binning import ( + thresholds_between_min_and_max, thresholds_between_0_and_1 +) def test_thresholds_between_min_and_max(): preds = Tensor([1, 10]) - assert all(thresholds_between_min_and_max(preds, 2) == preds) + assert torch_all(thresholds_between_min_and_max(preds, 2) == preds) def test_thresholds_between_0_and_1(): expected = Tensor([0, 1]) - assert all(thresholds_between_0_and_1(2) == expected) + assert torch_all(thresholds_between_0_and_1(2) == expected) From 553380afde2abbeee5da0378db615a8cc98e1dae Mon Sep 17 00:00:00 2001 From: yann-cv Date: Thu, 6 Jul 2023 11:53:18 +0200 Subject: [PATCH 10/15] device for thresholds --- src/anomalib/utils/metrics/aupro.py | 16 +++++++++++----- src/anomalib/utils/metrics/binning.py | 21 ++++++++++----------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 21ef548953..41a3568654 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -5,7 +5,6 @@ from __future__ import annotations -from functools import partial from typing import Any, Callable import torch @@ -116,14 +115,17 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso # the roc curve is computed with deactivated formatting. if all((0 <= preds) * (preds <= 1)): - thresholds = thresholds_between_min_and_max(preds, self.num_thresholds) + thresholds = thresholds_between_min_and_max( + preds, self.num_thresholds, self.device + ) else: - thresholds = thresholds_between_0_and_1(self.num_thresholds) + thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device) + else: thresholds = None # compute the global fpr-size - fpr: Tensor = binary_roc(preds, target, thresholds=thresholds,)[0] # only need fpr + fpr: Tensor = binary_roc(preds=preds, target=target, thresholds=thresholds,)[0] # only need fpr output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) # compute the PRO curve by aggregating per-region tpr/fpr curves/values. @@ -145,7 +147,11 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso mask = cca == label # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other # label in labels as FPs. We also don't need to return the thresholds - _fpr, _tpr = binary_roc(preds[background | mask], mask[background | mask], thresholds=thresholds,)[:-1] + _fpr, _tpr = binary_roc( + preds=preds[background | mask], + target=mask[background | mask], + thresholds=thresholds, + )[:-1] # catch edge-case where ROC only has fpr vals > self.fpr_limit if _fpr[_fpr <= self.fpr_limit].max() == 0: diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py index 69e6c9a7dd..be1299aa64 100644 --- a/src/anomalib/utils/metrics/binning.py +++ b/src/anomalib/utils/metrics/binning.py @@ -1,17 +1,16 @@ -from torch import Tensor, linspace +from torch import Tensor, linspace, device as torch_device -def thresholds_between_min_and_max(preds: Tensor, num_thresholds: int = 100) -> Tensor: +def thresholds_between_min_and_max( + preds: Tensor, num_thresholds: int = 100, device: None | torch_device = None +) -> Tensor: return linspace( - start=preds.min(), - end=preds.max(), - steps=num_thresholds, + start=preds.min(), end=preds.max(), steps=num_thresholds, device=device ) -def thresholds_between_0_and_1(num_thresholds: int = 100) -> Tensor: - return linspace( - start=0, - end=1, - steps=num_thresholds, - ) \ No newline at end of file + +def thresholds_between_0_and_1( + num_thresholds: int = 100, device: None | torch_device = None +) -> Tensor: + return linspace(start=0, end=1, steps=num_thresholds, device=device) \ No newline at end of file From f4fd62a49683f131bc68d3cf6345e38094b815f1 Mon Sep 17 00:00:00 2001 From: Yann-CV Date: Fri, 7 Jul 2023 17:25:02 +0200 Subject: [PATCH 11/15] fix linting --- src/anomalib/utils/metrics/aupro.py | 14 +++++++++----- src/anomalib/utils/metrics/binning.py | 14 +++++--------- tests/pre_merge/utils/metrics/test_binning.py | 5 +---- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 41a3568654..1c806c21f1 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -19,8 +19,8 @@ connected_components_cpu, connected_components_gpu, ) -from .binning import thresholds_between_min_and_max, thresholds_between_0_and_1 +from .binning import thresholds_between_0_and_1, thresholds_between_min_and_max from .plotting_utils import plot_figure @@ -115,9 +115,7 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso # the roc curve is computed with deactivated formatting. if all((0 <= preds) * (preds <= 1)): - thresholds = thresholds_between_min_and_max( - preds, self.num_thresholds, self.device - ) + thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device) else: thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device) @@ -125,7 +123,13 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso thresholds = None # compute the global fpr-size - fpr: Tensor = binary_roc(preds=preds, target=target, thresholds=thresholds,)[0] # only need fpr + fpr: Tensor = binary_roc( + preds=preds, + target=target, + thresholds=thresholds, + )[ + 0 + ] # only need fpr output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) # compute the PRO curve by aggregating per-region tpr/fpr curves/values. diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py index be1299aa64..f2201cf409 100644 --- a/src/anomalib/utils/metrics/binning.py +++ b/src/anomalib/utils/metrics/binning.py @@ -1,16 +1,12 @@ -from torch import Tensor, linspace, device as torch_device +from torch import Tensor, linspace +from torch import device as torch_device def thresholds_between_min_and_max( preds: Tensor, num_thresholds: int = 100, device: None | torch_device = None ) -> Tensor: - return linspace( - start=preds.min(), end=preds.max(), steps=num_thresholds, device=device - ) + return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device) - -def thresholds_between_0_and_1( - num_thresholds: int = 100, device: None | torch_device = None -) -> Tensor: - return linspace(start=0, end=1, steps=num_thresholds, device=device) \ No newline at end of file +def thresholds_between_0_and_1(num_thresholds: int = 100, device: None | torch_device = None) -> Tensor: + return linspace(start=0, end=1, steps=num_thresholds, device=device) diff --git a/tests/pre_merge/utils/metrics/test_binning.py b/tests/pre_merge/utils/metrics/test_binning.py index c49d1a4236..d256d95a7a 100644 --- a/tests/pre_merge/utils/metrics/test_binning.py +++ b/tests/pre_merge/utils/metrics/test_binning.py @@ -1,9 +1,6 @@ from torch import Tensor, all as torch_all -from anomalib.utils.metrics.binning import ( - thresholds_between_min_and_max, - thresholds_between_0_and_1 -) +from anomalib.utils.metrics.binning import thresholds_between_min_and_max, thresholds_between_0_and_1 def test_thresholds_between_min_and_max(): From 650e4e86cc5ea5d1bb6bd6918035484286da29e9 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Fri, 7 Jul 2023 17:35:21 +0200 Subject: [PATCH 12/15] use torch.all --- src/anomalib/utils/metrics/aupro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 1c806c21f1..ad2a615e5c 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -114,7 +114,7 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and # the roc curve is computed with deactivated formatting. - if all((0 <= preds) * (preds <= 1)): + if torch.all((0 <= preds) * (preds <= 1)): thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device) else: thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device) From b82bedf864d94f26c79095125b1ddc93bca0a3f7 Mon Sep 17 00:00:00 2001 From: Yann-CV Date: Fri, 7 Jul 2023 17:55:04 +0200 Subject: [PATCH 13/15] fix linting --- src/anomalib/utils/metrics/aupro.py | 24 +++++++++++++++++++----- src/anomalib/utils/metrics/binning.py | 4 +++- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index ad2a615e5c..247b07f5d3 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -5,6 +5,8 @@ from __future__ import annotations +import time +from contextlib import contextmanager from typing import Any, Callable import torch @@ -24,6 +26,14 @@ from .plotting_utils import plot_figure +@contextmanager +def observe_execution_time(description): + start = time.monotonic() + yield + delta = time.monotonic() - start + print(f"{description}: took {delta} seconds") + + class AUPRO(Metric): """Area under per region overlap (AUPRO) Metric.""" @@ -201,12 +211,15 @@ def _compute(self) -> tuple[Tensor, Tensor]: Returns: tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. """ - - cca = self.perform_cca().flatten() + with observe_execution_time("cca"): + cca = self.perform_cca().flatten() target = dim_zero_cat(self.target).flatten() preds = dim_zero_cat(self.preds).flatten() - return self.compute_pro(cca=cca, target=target, preds=preds) + with observe_execution_time("pro"): + pro = self.compute_pro(cca=cca, target=target, preds=preds) + + return pro def compute(self) -> Tensor: """Fist compute PRO curve, then compute and scale area under the curve. @@ -216,8 +229,9 @@ def compute(self) -> Tensor: """ fpr, tpr = self._compute() - aupro = auc(fpr, tpr, reorder=True) - aupro = aupro / fpr[-1] # normalize the area + with observe_execution_time("aupro"): + aupro = auc(fpr, tpr, reorder=True) + aupro = aupro / fpr[-1] # normalize the area return aupro diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py index f2201cf409..81878fbbbf 100644 --- a/src/anomalib/utils/metrics/binning.py +++ b/src/anomalib/utils/metrics/binning.py @@ -1,9 +1,11 @@ +from typing import Optional + from torch import Tensor, linspace from torch import device as torch_device def thresholds_between_min_and_max( - preds: Tensor, num_thresholds: int = 100, device: None | torch_device = None + preds: Tensor, num_thresholds: int = 100, device: Optional[torch_device] = None ) -> Tensor: return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device) From 893579db1c3d9c8f370f24d4919d88963fd12f86 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Mon, 10 Jul 2023 13:44:03 +0200 Subject: [PATCH 14/15] fix linting --- src/anomalib/utils/metrics/binning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py index 81878fbbbf..f92e0b20ab 100644 --- a/src/anomalib/utils/metrics/binning.py +++ b/src/anomalib/utils/metrics/binning.py @@ -10,5 +10,5 @@ def thresholds_between_min_and_max( return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device) -def thresholds_between_0_and_1(num_thresholds: int = 100, device: None | torch_device = None) -> Tensor: +def thresholds_between_0_and_1(num_thresholds: int = 100, device: Optional[torch_device] = None) -> Tensor: return linspace(start=0, end=1, steps=num_thresholds, device=device) From 6feeb0b7a089adc2287c66815cb35b6a0dbacbf8 Mon Sep 17 00:00:00 2001 From: yann-cv Date: Mon, 10 Jul 2023 13:45:26 +0200 Subject: [PATCH 15/15] remove observe time --- src/anomalib/utils/metrics/aupro.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 247b07f5d3..ad2a615e5c 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -5,8 +5,6 @@ from __future__ import annotations -import time -from contextlib import contextmanager from typing import Any, Callable import torch @@ -26,14 +24,6 @@ from .plotting_utils import plot_figure -@contextmanager -def observe_execution_time(description): - start = time.monotonic() - yield - delta = time.monotonic() - start - print(f"{description}: took {delta} seconds") - - class AUPRO(Metric): """Area under per region overlap (AUPRO) Metric.""" @@ -211,15 +201,12 @@ def _compute(self) -> tuple[Tensor, Tensor]: Returns: tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. """ - with observe_execution_time("cca"): - cca = self.perform_cca().flatten() + + cca = self.perform_cca().flatten() target = dim_zero_cat(self.target).flatten() preds = dim_zero_cat(self.preds).flatten() - with observe_execution_time("pro"): - pro = self.compute_pro(cca=cca, target=target, preds=preds) - - return pro + return self.compute_pro(cca=cca, target=target, preds=preds) def compute(self) -> Tensor: """Fist compute PRO curve, then compute and scale area under the curve. @@ -229,9 +216,8 @@ def compute(self) -> Tensor: """ fpr, tpr = self._compute() - with observe_execution_time("aupro"): - aupro = auc(fpr, tpr, reorder=True) - aupro = aupro / fpr[-1] # normalize the area + aupro = auc(fpr, tpr, reorder=True) + aupro = aupro / fpr[-1] # normalize the area return aupro