Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Reverse Distillation to match official code #1389

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/source/images/reverse_distillation/results/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/images/reverse_distillation/results/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/images/reverse_distillation/results/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 13 additions & 60 deletions src/anomalib/models/reverse_distillation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,79 +20,32 @@ During testing, a similar step is followed but this time the cosine distance bet

## Benchmark

All results gathered with seed `42`.

Note: Early Stopping (with patience 3) was enabled during training.
All results gathered with seed `42`, train batch size `16`.

## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

### Image-Level AUC

| | ResNet 18 | Wide ResNet 50 |
| :--------- | --------: | -------------: |
| Bottle | 0.998 | 0.992 |
| Cable | 0.982 | 0.583 |
| Capsule | 0.864 | 0.78 |
| Carpet | 0.996 | 0.539 |
| Grid | 0.941 | 0.975 |
| Hazelnut | 0.978 | 0.817 |
| Leather | 0.878 | 1 |
| Metal_nut | 0.999 | 0.929 |
| Pill | 0.944 | 0.553 |
| Screw | 0.778 | 0.86 |
| Tile | 0.833 | 0.513 |
| Toothbrush | 0.967 | 0.7 |
| Transistor | 0.928 | 0.829 |
| Wood | 0.989 | 0.993 |
| Zipper | 0.968 | 0.787 |
| Average | 0.936 | 0.79 |
| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
abc-125 marked this conversation as resolved.
Show resolved Hide resolved
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.985 | 0.984 | 1.000 | 1.000 | 1.000 | 0.997 | 1.000 | 0.966 | 0.974 | 1.000 | 1.000 | 0.972 | 0.985 | 0.953 | 0.970 | 0.978 |

### Pixel-Level AUC

| | ResNet 18 | Wide ResNet 50 |
| :--------- | --------: | -------------: |
| Bottle | 0.981 | 0.985 |
| Cable | 0.965 | 0.794 |
| Capsule | 0.983 | 0.986 |
| Carpet | 0.989 | 0.99 |
| Grid | 0.964 | 0.99 |
| Hazelnut | 0.988 | 0.983 |
| Leather | 0.984 | 0.995 |
| Metal_nut | 0.971 | 0.979 |
| Pill | 0.975 | 0.977 |
| Screw | 0.987 | 0.989 |
| Tile | 0.867 | 0.953 |
| Toothbrush | 0.99 | 0.979 |
| Transistor | 0.84 | 0.853 |
| Wood | 0.939 | 0.958 |
| Zipper | 0.988 | 0.959 |
| Average | 0.961 | 0.958 |
| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.969 | 0.988 | 0.992 | 0.991 | 0.954 | 0.947 | 0.984 | 0.964 | 0.987 | 0.988 | 0.969 | 0.975 | 0.996 | 0.991 | 0.893 | 0.984 |

### Image F1 Score

| | ResNet 18 | Wide ResNet 50 |
| :--------- | --------: | -------------: |
| Bottle | 0.95 | 0.959 |
| Cable | 0.911 | 0.76 |
| Capsule | 0.933 | 0.905 |
| Carpet | 0.965 | 0.864 |
| Grid | 0.964 | 0.945 |
| Hazelnut | 0.909 | 0.901 |
| Leather | 0.896 | 0.989 |
| Metal_nut | 0.995 | 0.939 |
| Pill | 0.931 | 0.922 |
| Screw | 0.88 | 0.891 |
| Tile | 0.88 | 0.836 |
| Toothbrush | 0.933 | 0.833 |
| Transistor | 0.769 | 0.744 |
| Wood | 0.966 | 0.948 |
| Zipper | 0.944 | 0.926 |
| Average | 0.922 | 0.891 |
| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.976 | 0.977 | 1.000 | 1.000 | 0.994 | 0.992 | 0.984 | 0.930 | 0.982 | 1.000 | 1.000 | 0.967 | 0.963 | 0.952 | 0.927 | 0.975 |

### Sample Results

![Sample Result 1](../../../docs/source/images/reverse_distillation/results/0.png "Sample Result 1")
![Sample Result 1](../../../../docs/source/images/reverse_distillation/results/0.png "Sample Result 1")
abc-125 marked this conversation as resolved.
Show resolved Hide resolved

![Sample Result 2](../../../docs/source/images/reverse_distillation/results/1.png "Sample Result 2")
![Sample Result 2](../../../../docs/source/images/reverse_distillation/results/1.png "Sample Result 2")
abc-125 marked this conversation as resolved.
Show resolved Hide resolved

![Sample Result 3](../../../docs/source/images/reverse_distillation/results/2.png "Sample Result 3")
![Sample Result 3](../../../../docs/source/images/reverse_distillation/results/2.png "Sample Result 3")
abc-125 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def __init__(
self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2)
self.bn3 = norm_layer(256 * block.expansion)

# This is present in the paper but not in the original code. With some initial experiments, removing this leads
# to better results
# self.conv4 = conv1x1(256 * block.expansion * 3, 256 * block.expansion * 3, 1) # x3 as we concatenate 3 layers
# self.bn4 = norm_layer(256 * block.expansion * 3)
# self.conv4 and self.bn4 are from the original code:
# https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/resnet.py#L412
self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1)
self.bn4 = norm_layer(512 * block.expansion)

for module in self.modules():
if isinstance(module, nn.Conv2d):
Expand Down
17 changes: 8 additions & 9 deletions src/anomalib/models/reverse_distillation/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dataset:
path: ./datasets/MVTec
category: bottle
task: segmentation
train_batch_size: 32
train_batch_size: 16
eval_batch_size: 32
inference_batch_size: 32
num_workers: 8
Expand Down Expand Up @@ -35,21 +35,20 @@ model:
- layer1
- layer2
- layer3
early_stopping:
patience: 3
metric: pixel_AUROC
mode: max
# early_stopping: # optional
# patience: 3
# metric: pixel_AUROC
# mode: max
abc-125 marked this conversation as resolved.
Show resolved Hide resolved
beta1: 0.5
beta2: 0.99
beta2: 0.999
normalization_method: min_max # options: [null, min_max, cdf]
anomaly_map_mode: multiply
anomaly_map_mode: add # options: [add, multiply]

metrics:
image:
- F1Score
- AUROC
pixel:
- F1Score
- AUROC
threshold:
method: adaptive #options: [adaptive, manual]
Expand Down Expand Up @@ -85,7 +84,7 @@ trainer:
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 2 # Don't validate before extracting features.
check_val_every_n_epoch: 200 # Don't validate before extracting features.
fast_dev_run: false
accumulate_grad_batches: 1
max_epochs: 200
Expand Down
17 changes: 10 additions & 7 deletions src/anomalib/models/reverse_distillation/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,20 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
self.save_hyperparameters(hparams)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
"""Configure model-specific non-mandatory callbacks.

Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure callback method will be
deprecated, and callbacks will be configured from either
config.yaml file or from CLI.
"""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]
if "early_stopping" in self.hparams.model:
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]
else:
return []
abc-125 marked this conversation as resolved.
Show resolved Hide resolved
15 changes: 12 additions & 3 deletions src/anomalib/models/reverse_distillation/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class ReverseDistillationLoss(nn.Module):
def forward(self, encoder_features: list[Tensor], decoder_features: list[Tensor]) -> Tensor:
"""Computes cosine similarity loss based on features from encoder and decoder.

Based on the official code:
https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/main.py#L33C25-L33C25
Calculates loss from flattened arrays of features, see https://github.com/hq-deng/RD4AD/issues/22

Args:
encoder_features (list[Tensor]): List of features extracted from encoder
decoder_features (list[Tensor]): List of features extracted from decoder
Expand All @@ -23,8 +27,13 @@ def forward(self, encoder_features: list[Tensor], decoder_features: list[Tensor]
Tensor: Cosine similarity loss
"""
cos_loss = torch.nn.CosineSimilarity()
losses = list(map(cos_loss, encoder_features, decoder_features))
loss_sum = 0
for loss in losses:
loss_sum += torch.mean(1 - loss) # mean of cosine distance
for item in range(len(encoder_features)):
loss_sum += torch.mean(
1
- cos_loss(
encoder_features[item].view(encoder_features[item].shape[0], -1),
decoder_features[item].view(decoder_features[item].shape[0], -1),
)
)
abc-125 marked this conversation as resolved.
Show resolved Hide resolved
return loss_sum
3 changes: 3 additions & 0 deletions src/anomalib/models/reverse_distillation/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
class ReverseDistillationModel(nn.Module):
"""Reverse Distillation Model.

To reproduce results in the paper, use torchvision model for the encoder:
self.encoder = torchvision.models.wide_resnet50_2(pretrained=True)

Args:
backbone (str): Name of the backbone used for encoder and decoder
input_size (tuple[int, int]): Size of input image
Expand Down