[Metrics] class based embedding similarity + tests #3358

Merged: Borda merged 16 commits into Lightning-AI:master from SkafteNicki:metrics/semi_supervised on Sep 11, 2020

Changes shown below are from 14 of the 16 commits:
08791c0  embedding similarity class + test (SkafteNicki)
28f278a  fix tests (SkafteNicki)
a7cb2f4  fix pep8 (SkafteNicki)
ba9b421  add docs (SkafteNicki)
c96a5df  noindex (awaelchli)
b9f17db  Update docs/source/metrics.rst (awaelchli)
dd009c1  Update pytorch_lightning/metrics/self_supervised.py (justusschock)
9a7b342  Update pytorch_lightning/metrics/self_supervised.py (justusschock)
345e383  suggestions (SkafteNicki)
39824ba  changes to init (SkafteNicki)
123b768  merge + changelog (SkafteNicki)
299d2a8  move __all__ (SkafteNicki)
9832a8e  fix imports (SkafteNicki)
78a8d37  Apply suggestions from code review (Borda)
2e8fa6d  assert typo (awaelchli)
366e835  change import
pytorch_lightning/metrics/__init__.py

@@ -1,59 +1,18 @@
-from pytorch_lightning.metrics.classification import (
-    Accuracy,
-    AveragePrecision,
-    ConfusionMatrix,
-    F1,
-    FBeta,
-    Recall,
-    ROC,
-    AUROC,
-    DiceCoefficient,
-    MulticlassPrecisionRecallCurve,
-    MulticlassROC,
-    Precision,
-    PrecisionRecallCurve,
-    IoU,
-)
-from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
-from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
-from pytorch_lightning.metrics.nlp import BLEUScore
-from pytorch_lightning.metrics.regression import (
-    MAE,
-    MSE,
-    PSNR,
-    RMSE,
-    RMSLE,
-    SSIM
-)
-from pytorch_lightning.metrics.sklearns import (
-    AUC,
-    SklearnMetric,
-)
+from pytorch_lightning.metrics.metric import *
+from pytorch_lightning.metrics.metric import __all__ as __base_metrics
+from pytorch_lightning.metrics.classification import *
+from pytorch_lightning.metrics.classification import __all__ as __classification_metrics
+from pytorch_lightning.metrics.nlp import *
+from pytorch_lightning.metrics.nlp import __all__ as __nlp_metrics
+from pytorch_lightning.metrics.regression import *
+from pytorch_lightning.metrics.regression import __all__ as __regression_metrics
+from pytorch_lightning.metrics.self_supervised import *
+from pytorch_lightning.metrics.self_supervised import __all__ as __selfsupervised_metrics
 
-__classification_metrics = [
-    "AUC",
-    "AUROC",
-    "Accuracy",
-    "AveragePrecision",
-    "ConfusionMatrix",
-    "DiceCoefficient",
-    "F1",
-    "FBeta",
-    "MulticlassPrecisionRecallCurve",
-    "MulticlassROC",
-    "Precision",
-    "PrecisionRecallCurve",
-    "ROC",
-    "Recall",
-    "IoU",
-]
-__regression_metrics = [
-    "MAE",
-    "MSE",
-    "PSNR",
-    "RMSE",
-    "RMSLE",
-    "SSIM"
-]
-__sequence_metrics = ["BLEUScore"]
-__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
+__all__ = __classification_metrics \
+    + __base_metrics \
+    + __nlp_metrics \
+    + __regression_metrics \
+    + __selfsupervised_metrics
pytorch_lightning/metrics/self_supervised.py (new file)

@@ -0,0 +1,76 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity
from pytorch_lightning.metrics.metric import TensorMetric


class EmbeddingSimilarity(TensorMetric):
    """
    Computes similarity between embeddings

    Example:
        >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
        >>> metric = EmbeddingSimilarity(reduction='none')
        >>> metric(embeddings)
        tensor([[0.0000, 1.0000, 0.9759],
                [1.0000, 0.0000, 0.9759],
                [0.9759, 0.9759, 0.0000]])

    """
    def __init__(
            self,
            similarity: str = 'cosine',
            zero_diagonal: bool = True,
            reduction: str = 'mean',
            reduce_group: Any = None
    ):
        """
        Args:
            similarity: 'dot' or 'cosine'
            zero_diagonal: if True, the diagonal is set to zero
            reduction: 'none', 'sum' or 'mean' (all along dim -1)
            reduce_group: the process group to reduce metric results from DDP

        """
        super().__init__(name='embedding_similarity',
                         reduce_group=reduce_group)
        assert similarity in ('dot', 'cosine')
        self.similarity = similarity

        assert isinstance(zero_diagonal, bool)
        self.zero_diagonal = zero_diagonal

        assert reduction in ('none', 'sum', 'mean')
        self.reduction = reduction

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            batch: tensor containing embeddings with shape (batch_size, dim)

        Return:
            A square matrix (batch, batch) with the similarity scores between all elements.
            If 'sum' or 'mean' reduction is used, a reduced value per row is returned instead.
        """
        return embedding_similarity(batch,
                                    similarity=self.similarity,
                                    zero_diagonal=self.zero_diagonal,
                                    reduction=self.reduction)


__all__ = ['EmbeddingSimilarity']
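The class delegates to the functional `embedding_similarity`. For intuition, the cosine variant with a zeroed diagonal can be reproduced with plain `torch` operations; this is a sketch of the underlying math (the helper name `cosine_similarity_matrix` is made up for illustration), not the library's implementation:

```python
import torch


def cosine_similarity_matrix(batch: torch.Tensor, zero_diagonal: bool = True) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of `batch` (batch_size, dim)."""
    normed = batch / batch.norm(dim=1, keepdim=True)  # unit-normalise each row
    sim = normed @ normed.t()                         # (batch, batch) similarity matrix
    if zero_diagonal:
        sim.fill_diagonal_(0.0)                       # mirror zero_diagonal=True
    return sim


embeddings = torch.tensor([[1., 2., 3., 4.],
                           [1., 2., 3., 4.],
                           [4., 5., 6., 7.]])
# Should reproduce the matrix shown in the class docstring example above
print(cosine_similarity_matrix(embeddings))
```

Rows 0 and 1 are identical, hence the off-diagonal 1.0000; the diagonal is zeroed so that each row's most similar *other* element stands out.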
New test file:

@@ -0,0 +1,35 @@
import pytest
import torch
from sklearn.metrics import pairwise

from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity


@pytest.mark.parametrize('similarity', ['cosine', 'dot'])
@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
def test_against_sklearn(similarity, reduction):
    """Compare PL metrics to sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    batch = torch.randn(5, 10, device=device)  # 5 samples in 10 dimensions

    pl_dist = embedding_similarity(batch, similarity=similarity,
                                   reduction=reduction, zero_diagonal=False)

    def sklearn_embedding_distance(batch, similarity, reduction):

        metric_func = {'cosine': pairwise.cosine_similarity,
                       'dot': pairwise.linear_kernel}[similarity]

        dist = metric_func(batch, batch)
        if reduction == 'mean':
            return dist.mean(axis=-1)
        if reduction == 'sum':
            return dist.sum(axis=-1)
        return dist

    sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(),
                                         similarity=similarity, reduction=reduction)
    sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)

    assert torch.allclose(sk_dist, pl_dist)
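The test maps `'dot'` to `sklearn.metrics.pairwise.linear_kernel`, which for a single input is just the Gram matrix `X @ X.T`, and `'cosine'` to the Gram matrix of the row-normalised input. A quick numpy check of those two equivalences, assuming scikit-learn is installed:

```python
import numpy as np
from sklearn.metrics import pairwise

rng = np.random.default_rng(0)
batch = rng.normal(size=(5, 10))  # 5 samples in 10 dimensions, as in the test

# linear_kernel(X, X) is the plain Gram matrix of pairwise dot products
gram = batch @ batch.T
assert np.allclose(pairwise.linear_kernel(batch, batch), gram)

# cosine_similarity(X, X) is the Gram matrix of the unit-normalised rows
normed = batch / np.linalg.norm(batch, axis=1, keepdims=True)
assert np.allclose(pairwise.cosine_similarity(batch, batch), normed @ normed.T)

print("sklearn pairwise kernels match the explicit Gram-matrix forms")
```

Note that the test passes `zero_diagonal=False`, since the sklearn reference keeps the self-similarity entries on the diagonal.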
Review discussion (on placing `__all__` at the bottom of self_supervised.py):

> this is not a very common place for __all__

> codefactor also complains about this, not sure why. I think it should be fine though.

> the point is that it is very related to importing from packages when you do not want to import all functions
> https://stackoverflow.com/questions/44834/can-someone-explain-all-in-python

> I agree that it is not a common place for __all__; normally I would put it at the top of the file, but then codefactor complains about it. But I can move it to the top if that is better.

> yes, it is not very common to have it in files other than __init__, so was there a reason to move it from the init?

> It was changed due to a comment from justus at some point, but let's change it back since it is very uncommon practise.
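For context on what the reviewers are debating: `__all__` controls which names a wildcard `from module import *` binds, regardless of where in the module it is defined. A small runnable demonstration; the module and function names here are made up for illustration:

```python
import sys
import types

# Build an in-memory module that exports only `public` via __all__
mod = types.ModuleType("demo_mod")
exec(
    "__all__ = ['public']\n"
    "def public():\n"
    "    return 'visible'\n"
    "def _private():\n"
    "    return 'hidden'\n",
    mod.__dict__,
)
sys.modules["demo_mod"] = mod

# `import *` is only legal at module scope, so emulate it in a fresh namespace
ns = {}
exec("from demo_mod import *", ns)

print("public" in ns)    # True: listed in __all__
print("_private" in ns)  # False: omitted from __all__
```

Because only membership in `__all__` matters, moving the list from the bottom of a file to the top (or back into `__init__.py`) changes readability and linter output, not behaviour, which is why the thread treats it purely as a style question.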