Skip to content

Commit

Permalink
Fix single update in pearson corrcoef (#2019)
Browse files Browse the repository at this point in the history
* fix
* changelog
  • Loading branch information
SkafteNicki authored Aug 23, 2023
1 parent 42c748b commit 96b9439
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)


- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)


Expand Down
5 changes: 2 additions & 3 deletions src/torchmetrics/functional/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def _pearson_corrcoef_update(
# Data checking
_check_same_shape(preds, target)
_check_data_shape_to_num_outputs(preds, target, num_outputs)
cond = n_prior.mean() > 0

n_obs = preds.shape[0]
cond = n_prior.mean() > 0 or n_obs == 1

if cond:
mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)
my_new = (n_prior * mean_y + target.sum(0)) / (n_prior + n_obs)
Expand All @@ -67,7 +67,6 @@ def _pearson_corrcoef_update(
if cond:
var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
var_y += ((target - my_new) * (target - mean_y)).sum(0)

else:
var_x += preds.var(0) * (n_obs - 1)
var_y += target.var(0) * (n_obs - 1)
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/regression/test_pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,20 @@ def test_pearsons_warning_on_small_input(dtype, scale):
target = scale * torch.randn(100, dtype=dtype)
with pytest.warns(UserWarning, match="The variance of predictions or target is close to zero.*"):
pearson_corrcoef(preds, target)


def test_single_sample_update():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/2014."""
metric = PearsonCorrCoef()

# Works
metric(torch.tensor([3.0, -0.5, 2.0, 7.0]), torch.tensor([2.5, 0.0, 2.0, 8.0]))
res1 = metric.compute()
metric.reset()

metric(torch.tensor([3.0]), torch.tensor([2.5]))
metric(torch.tensor([-0.5]), torch.tensor([0.0]))
metric(torch.tensor([2.0]), torch.tensor([2.0]))
metric(torch.tensor([7.0]), torch.tensor([8.0]))
res2 = metric.compute()
assert torch.allclose(res1, res2)

0 comments on commit 96b9439

Please sign in to comment.