[Metrics] class based embedding similarity + tests (#3358)
* embedding similarity class + test

* fix tests

* fix pep8

* add docs

* noindex

* Update docs/source/metrics.rst

* Update pytorch_lightning/metrics/self_supervised.py

Co-authored-by: Rohit Gupta <[email protected]>

* Update pytorch_lightning/metrics/self_supervised.py

Co-authored-by: Rohit Gupta <[email protected]>

* suggestions

* changes to init

* move __all__

* fix imports

* Apply suggestions from code review

* assert typo

* change import

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
6 people authored Sep 11, 2020
1 parent 70af47d commit 93cf6d0
Showing 6 changed files with 130 additions and 9 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -13,6 +13,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/))

- Added `EmbeddingSimilarity` metric:
    * functional interface ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349))
    * class based interface + tests ([#3358](https://github.com/PyTorchLightning/pytorch-lightning/pull/3358))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
@@ -142,7 +146,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

7 changes: 6 additions & 1 deletion docs/source/metrics.rst
@@ -158,6 +158,12 @@ DiceCoefficient
.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
    :noindex:

EmbeddingSimilarity
^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.self_supervised.EmbeddingSimilarity
    :noindex:

F1
^^

@@ -629,4 +635,3 @@ MeanTweedieDeviance (sk)

.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance
    :noindex:

9 changes: 8 additions & 1 deletion pytorch_lightning/metrics/__init__.py
@@ -17,6 +17,7 @@
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.nlp import BLEUScore
from pytorch_lightning.metrics.self_supervised import EmbeddingSimilarity
from pytorch_lightning.metrics.regression import (
    MAE,
    MSE,
@@ -56,4 +57,10 @@
    "SSIM"
]
__sequence_metrics = ["BLEUScore"]
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
__selfsuper_metrics = ["EmbeddingSimilarity"]

__all__ = __regression_metrics \
    + __classification_metrics \
    + __selfsuper_metrics \
    + __sequence_metrics \
    + ["SklearnMetric"]
9 changes: 3 additions & 6 deletions pytorch_lightning/metrics/functional/self_supervised.py
@@ -40,10 +40,7 @@ def embedding_similarity(
    if reduction == 'mean':
        sqr_mtx = sqr_mtx.mean(dim=-1)

    return sqr_mtx

    if reduction == 'sum':
        sqr_mtx = sqr_mtx.sum(dim=-1)

if __name__ == '__main__':
    a = torch.rand(3, 5)

    print(embedding_similarity(a, 'cosine'))
    return sqr_mtx
73 changes: 73 additions & 0 deletions pytorch_lightning/metrics/self_supervised.py
@@ -0,0 +1,73 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity
from pytorch_lightning.metrics.metric import TensorMetric


class EmbeddingSimilarity(TensorMetric):
    """
    Computes similarity between embeddings
    Example:
        >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
        >>> embedding_similarity(embeddings)
        tensor([[0.0000, 1.0000, 0.9759],
                [1.0000, 0.0000, 0.9759],
                [0.9759, 0.9759, 0.0000]])
    """
    def __init__(
            self,
            similarity: str = 'cosine',
            zero_diagonal: bool = True,
            reduction: str = 'mean',
            reduce_group: Any = None
    ):
        """
        Args:
            similarity: 'dot' or 'cosine'
            reduction: 'none', 'sum', 'mean' (all along dim -1)
            zero_diagonal: if True, the diagonals are set to zero
            reduce_group: the process group to reduce metric results from DDP
        """
        super().__init__(name='embedding_similarity',
                         reduce_group=reduce_group)
        assert similarity in ('dot', 'cosine')
        self.similarity = similarity
        assert isinstance(zero_diagonal, bool)
        self.zero_diagonal = zero_diagonal
        assert reduction in ('none', 'sum', 'mean')
        self.reduction = reduction

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation
        Args:
            batch: tensor containing embeddings with shape (batch_size, dim)
        Return:
            A square matrix (batch, batch) with the similarity scores between all elements
            If sum or mean are used, then returns a tensor of shape (batch,) with the reduced value for each row
        """
        return embedding_similarity(batch,
                                    similarity=self.similarity,
                                    zero_diagonal=self.zero_diagonal,
                                    reduction=self.reduction)
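
To make the computation concrete, here is a minimal pure-Python sketch of what the metric appears to compute, going by the docstrings above: pairwise dot or cosine similarity over the rows of a batch, an optionally zeroed diagonal, and an optional reduction along the last dimension. The name `embedding_similarity_sketch` and its list-based inputs are assumptions made for illustration; this is not the library's implementation, which operates on `torch.Tensor`.

```python
import math


def embedding_similarity_sketch(batch, similarity="cosine", zero_diagonal=True, reduction="none"):
    """Hypothetical list-based stand-in for the tensor metric (illustration only)."""
    assert similarity in ("dot", "cosine")
    assert reduction in ("none", "sum", "mean")

    if similarity == "cosine":
        # Scale every row to unit L2 norm so the dot product becomes the cosine.
        batch = [[x / math.sqrt(sum(v * v for v in row)) for x in row] for row in batch]

    n = len(batch)
    # Square (n, n) matrix of pairwise dot products.
    matrix = [[sum(a * b for a, b in zip(batch[i], batch[j])) for j in range(n)] for i in range(n)]

    if zero_diagonal:
        for i in range(n):
            matrix[i][i] = 0.0

    # Reductions run along the last dimension, giving one value per row.
    if reduction == "sum":
        return [sum(row) for row in matrix]
    if reduction == "mean":
        return [sum(row) / n for row in matrix]
    return matrix


embeddings = [[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]
scores = embedding_similarity_sketch(embeddings)
print([[round(v, 4) for v in row] for row in scores])
```

Rounded to four decimals, the result for these embeddings should agree with the matrix shown in the class docstring. With `reduction="mean"` or `"sum"` the sketch returns one value per row instead of the square matrix.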
35 changes: 35 additions & 0 deletions tests/metrics/functional/test_self_supervised.py
@@ -0,0 +1,35 @@
import pytest
import torch
from sklearn.metrics import pairwise

from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity


@pytest.mark.parametrize('similarity', ['cosine', 'dot'])
@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
def test_against_sklearn(similarity, reduction):
    """Compare PL metrics to sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    batch = torch.randn(5, 10, device=device)  # 5 samples in 10 dimensions

    pl_dist = embedding_similarity(batch, similarity=similarity,
                                   reduction=reduction, zero_diagonal=False)

    def sklearn_embedding_distance(batch, similarity, reduction):

        metric_func = {'cosine': pairwise.cosine_similarity,
                       'dot': pairwise.linear_kernel}[similarity]

        dist = metric_func(batch, batch)
        if reduction == 'mean':
            return dist.mean(axis=-1)
        if reduction == 'sum':
            return dist.sum(axis=-1)
        return dist

    sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(),
                                         similarity=similarity, reduction=reduction)
    sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)

    assert torch.allclose(sk_dist, pl_dist)
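
The test passes `zero_diagonal=False` because the sklearn reference kernels keep each sample's self-similarity on the diagonal, so the two results are only comparable when the diagonal is left in place. The pure-Python sketch below shows how much a row-wise mean shifts once the diagonal is zeroed. The helper `cosine_matrix` is hypothetical and stands in for either implementation.

```python
import math


def cosine_matrix(rows):
    """Hypothetical helper: pairwise cosine similarity of the given rows."""
    unit = [[x / math.sqrt(sum(v * v for v in row)) for x in row] for row in rows]
    return [[sum(a * b for a, b in zip(u, w)) for w in unit] for u in unit]


rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
full = cosine_matrix(rows)

# Row-wise mean with the diagonal kept (what a plain pairwise kernel yields).
mean_with_diag = [sum(r) / len(r) for r in full]

# Row-wise mean after zeroing the diagonal (the zero_diagonal=True behaviour).
zeroed = [[0.0 if i == j else v for j, v in enumerate(r)] for i, r in enumerate(full)]
mean_zeroed = [sum(r) / len(r) for r in zeroed]

# Each self-similarity is 1, so the two means differ by about 1 / n in every row.
gaps = [a - b for a, b in zip(mean_with_diag, mean_zeroed)]
print(gaps)
```

For cosine similarity every gap is roughly `1 / n`, which is well outside the default tolerance of an `allclose`-style comparison for a batch of five samples.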
