Skip to content

Commit

Permalink
Optimize anomaly score calculation for PatchCore for both num_neighb… (
Browse files Browse the repository at this point in the history
…#633)

* Optimized anomaly score calculation for PatchCore for both num_neighbors ==1 and > 1

* Optimized anomaly score calculation for PatchCore for both num_neighbors ==1 and > 1

* Optimized anomaly score calculation for PatchCore for both num_neighbors ==1 and > 1

* Optimized anomaly score calculation for PatchCore for both num_neighbors ==1 and > 1

* Optimized anomaly score calculation for PatchCore for both num_neighbors ==1 and > 1
  • Loading branch information
VdLMV authored Nov 2, 2022
1 parent 2f0a87c commit 573fbb3
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 13 deletions.
8 changes: 4 additions & 4 deletions anomalib/data/utils/generators/perlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def generate_perlin_noise_2d(shape, res):
"""Fractal perlin noise."""

def f(t):
return 6 * t**5 - 15 * t**4 + 10 * t**3
return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3

delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
Expand All @@ -68,7 +68,7 @@ def f(t):
def random_2d_perlin(
shape: Tuple,
res: Tuple[Union[int, Tensor], Union[int, Tensor]],
fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3,
fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3,
) -> Union[np.ndarray, Tensor]:
"""Returns a random 2d perlin noise array.
Expand All @@ -90,7 +90,7 @@ def random_2d_perlin(
return result


def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
"""Generate a random image containing Perlin noise. Numpy version."""
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
Expand All @@ -116,7 +116,7 @@ def dot(grad, shift):
return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])


def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
"""Generate a random image containing Perlin noise. PyTorch version."""
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/cflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor)
torch.Tensor: Log probability
"""
ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi)) # ln(sqrt(2*pi))
logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j
logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u ** 2, 1) + logdet_j
return logp


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def johnson_lindenstrauss_min_dim(self, n_samples: int, eps: float = 0.1):
eps (float, optional): Minimum distortion rate. Defaults to 0.1.
"""

denominator = (eps**2 / 2) - (eps**3 / 3)
denominator = (eps ** 2 / 2) - (eps ** 3 / 3)
return (4 * np.log(n_samples) / denominator).astype(np.int64)

def fit(self, embedding: Tensor) -> "SparseRandomProjection":
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/components/stats/kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def fit(self, dataset: Tensor) -> None:

cov_mat = self.cov(dataset.T)
inv_cov_mat = torch.linalg.inv(cov_mat)
inv_cov = inv_cov_mat / factor**2
inv_cov = inv_cov_mat / factor ** 2

# transform data to account for bandwidth
bw_transform = torch.linalg.cholesky(inv_cov)
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/fastflow/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def forward(self, hidden_variables: List[Tensor]) -> Tensor:
"""
flow_maps: List[Tensor] = []
for hidden_variable in hidden_variables:
log_prob = -torch.mean(hidden_variable**2, dim=1, keepdim=True) * 0.5
log_prob = -torch.mean(hidden_variable ** 2, dim=1, keepdim=True) * 0.5
prob = torch.exp(log_prob)
flow_map = F.interpolate(
input=-prob,
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/fastflow/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ def forward(self, hidden_variables: List[Tensor], jacobians: List[Tensor]) -> Te
"""
loss = torch.tensor(0.0, device=hidden_variables[0].device) # pylint: disable=not-callable
for (hidden_variable, jacobian) in zip(hidden_variables, jacobians):
loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian)
loss += torch.mean(0.5 * torch.sum(hidden_variable ** 2, dim=(1, 2, 3)) - jacobian)
return loss
2 changes: 1 addition & 1 deletion anomalib/models/ganomaly/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(

# Calculate input channel size to recreate inverse pyramid
exp_factor = math.ceil(math.log(min(input_size) // 2, 2)) - 2
n_input_features = n_features * (2**exp_factor)
n_input_features = n_features * (2 ** exp_factor)

# CNN layer for latent vector input
self.latent_input.add_module(
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dataset:
category: bottle
image_size: 224
train_batch_size: 32
test_batch_size: 1
test_batch_size: 32
num_workers: 8
transform_config:
train: null
Expand Down
11 changes: 9 additions & 2 deletions anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> Tuple[Tensor
Tensor: Locations of the nearest neighbor(s).
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
if n_neighbors == 1:
# when n_neighbors is 1, speed up computation by using min instead of topk
patch_scores, locations = distances.min(1)
else:
patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores, locations

def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embedding: Tensor) -> Tensor:
Expand All @@ -168,6 +172,9 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
Tensor: Image-level anomaly scores
"""

# Don't need to compute weights if num_neighbors is 1
if self.num_neighbors == 1:
return patch_scores.amax(1)
# 1. Find the patch with the largest distance to it's nearest neighbor in each image
max_patches = torch.argmax(patch_scores, dim=1) # (m^test,* in the paper)
# 2. Find the distance of the patch to it's nearest neighbor, and the location of the nn in the membank
Expand All @@ -179,7 +186,7 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
# 4. Find the distance of the patch features to each of the support samples
distances = torch.cdist(embedding[max_patches].unsqueeze(1), self.memory_bank[support_samples], p=2.0)
# 5. Apply softmax to find the weights
weights = (1 - F.softmax(distances.squeeze()))[..., 0]
weights = (1 - F.softmax(distances.squeeze(), 1))[..., 0]
# 6. Apply the weight factor to the score
score = weights * score # S^* in the paper
return score

0 comments on commit 573fbb3

Please sign in to comment.