diff --git a/anomalib/models/patchcore/model.py b/anomalib/models/patchcore/model.py index d98e062a25..c1ca1dd072 100644 --- a/anomalib/models/patchcore/model.py +++ b/anomalib/models/patchcore/model.py @@ -74,9 +74,10 @@ def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Image-level anomaly scores """ - confidence = patch_scores[torch.argmax(patch_scores[:, 0])] + max_scores = torch.argmax(patch_scores[:, 0]) + confidence = torch.index_select(patch_scores, 0, max_scores) weights = 1 - (torch.max(torch.exp(confidence)) / torch.sum(torch.exp(confidence))) - score = weights * max(patch_scores[:, 0]) + score = weights * torch.max(patch_scores[:, 0]) return score def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: