Skip to content

Commit

Permalink
compatibility TM v0.4 (#8206)
Browse files Browse the repository at this point in the history
* tm compatible

* format

* chlog

* ..

* destroy distributed workaround

Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
Borda and awaelchli authored Jun 30, 2021
1 parent e5fcc3d commit afc69e4
Show file tree
Hide file tree
Showing 31 changed files with 228 additions and 130 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))

## [1.3.6] - 2021-06-DD

- Fixed compatibility with TorchMetrics v0.4 ([#8206](https://github.com/PyTorchLightning/pytorch-lightning/pull/8206))


## [1.3.5] - 2021-06-08

Expand Down
15 changes: 11 additions & 4 deletions pytorch_lightning/metrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
from torchmetrics import F1 as _F1
from torchmetrics import FBeta as _FBeta

from pytorch_lightning.metrics.utils import deprecated_metrics
from pytorch_lightning.metrics.utils import (
_TORCHMETRICS_GREATER_EQUAL_0_4,
_TORCHMETRICS_LOWER_THAN_0_4,
deprecated_metrics,
void,
)


class FBeta(_FBeta):

@deprecated_metrics(target=_FBeta)
@deprecated_metrics(target=_FBeta, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(target=_FBeta, args_mapping={"multilabel": None}, skip_if=_TORCHMETRICS_LOWER_THAN_0_4)
def __init__(
self,
num_classes: int,
Expand All @@ -44,7 +50,8 @@ def __init__(

class F1(_F1):

@deprecated_metrics(target=_F1)
@deprecated_metrics(target=_F1, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(target=_F1, args_mapping={"multilabel": None}, skip_if=_TORCHMETRICS_LOWER_THAN_0_4)
def __init__(
self,
num_classes: int,
Expand All @@ -61,4 +68,4 @@ def __init__(
.. deprecated::
Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0.
"""
_ = num_classes, threshold, average, multilabel, compute_on_step, dist_sync_on_step, process_group
void(num_classes, threshold, average, multilabel, compute_on_step, dist_sync_on_step, process_group)
29 changes: 26 additions & 3 deletions pytorch_lightning/metrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,25 @@
from torchmetrics import Precision as _Precision
from torchmetrics import Recall as _Recall

from pytorch_lightning.metrics.utils import deprecated_metrics
from pytorch_lightning.metrics.utils import (
_TORCHMETRICS_GREATER_EQUAL_0_4,
_TORCHMETRICS_LOWER_THAN_0_4,
deprecated_metrics,
void,
)


class Precision(_Precision):

@deprecated_metrics(target=_Precision)
@deprecated_metrics(target=_Precision, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(
target=_Precision,
args_mapping={
"multilabel": None,
"is_multiclass": None
},
skip_if=_TORCHMETRICS_LOWER_THAN_0_4
)
def __init__(
self,
num_classes: Optional[int] = None,
Expand Down Expand Up @@ -49,7 +62,13 @@ def __init__(

class Recall(_Recall):

@deprecated_metrics(target=_Recall)
@deprecated_metrics(target=_Recall, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(
target=_Recall, args_mapping={
"multilabel": None,
"is_multiclass": None
}, skip_if=_TORCHMETRICS_LOWER_THAN_0_4
)
def __init__(
self,
num_classes: Optional[int] = None,
Expand All @@ -71,3 +90,7 @@ def __init__(
.. deprecated::
Use :class:`~torchmetrics.Recall`. Will be removed in v1.5.0.
"""
void(
num_classes, threshold, average, multilabel, mdmc_average, ignore_index, top_k, is_multiclass,
compute_on_step, dist_sync_on_step, process_group, dist_sync_fn
)
21 changes: 18 additions & 3 deletions pytorch_lightning/metrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,25 @@

from torchmetrics import StatScores as _StatScores

from pytorch_lightning.metrics.utils import deprecated_metrics
from pytorch_lightning.metrics.utils import (
_TORCHMETRICS_GREATER_EQUAL_0_4,
_TORCHMETRICS_LOWER_THAN_0_4,
deprecated_metrics,
void,
)


class StatScores(_StatScores):

@deprecated_metrics(target=_StatScores)
@deprecated_metrics(target=_StatScores, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(
target=_StatScores,
args_mapping={
"multilabel": None,
"is_multiclass": "multiclass"
},
skip_if=_TORCHMETRICS_LOWER_THAN_0_4
)
def __init__(
self,
threshold: float = 0.5,
Expand All @@ -41,5 +54,7 @@ def __init__(
.. deprecated::
Use :class:`~torchmetrics.StatScores`. Will be removed in v1.5.0.
"""
_ = threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step, \
void(
threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step,
dist_sync_on_step, process_group, dist_sync_fn
)
6 changes: 3 additions & 3 deletions pytorch_lightning/metrics/compositional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Callable, Union

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.metric import CompositionalMetric as _CompositionalMetric

Expand All @@ -26,8 +26,8 @@ class CompositionalMetric(_CompositionalMetric):
def __init__(
self,
operator: Callable,
metric_a: Union[Metric, int, float, torch.Tensor],
metric_b: Union[Metric, int, float, torch.Tensor, None],
metric_a: Union[Metric, int, float, Tensor],
metric_b: Union[Metric, int, float, Tensor, None],
):
"""
.. deprecated::
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/metrics/functional/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor
from torchmetrics.functional import accuracy as _accuracy

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_accuracy)
def accuracy(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
threshold: float = 0.5,
top_k: Optional[int] = None,
subset_accuracy: bool = False,
) -> torch.Tensor:
) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0.
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/functional/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from torch import Tensor
from torchmetrics.functional import auc as _auc

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_auc)
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.auc`. Will be removed in v1.5.0.
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/metrics/functional/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@
# limitations under the License.
from typing import Optional, Sequence

import torch
from torch import Tensor
from torchmetrics.functional import auroc as _auroc

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_auroc)
def auroc(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = 'macro',
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.5.0.
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/metrics/functional/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@
# limitations under the License.
from typing import List, Optional, Sequence, Union

import torch
from torch import Tensor
from torchmetrics.functional import average_precision as _average_precision

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_average_precision)
def average_precision(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
sample_weights: Optional[Sequence] = None,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[List[Tensor], Tensor]:
"""
.. deprecated::
Use :func:`torchmetrics.functional.average_precision`. Will be removed in v1.5.0.
Expand Down
10 changes: 3 additions & 7 deletions pytorch_lightning/metrics/functional/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor
from torchmetrics.functional import confusion_matrix as _confusion_matrix

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_confusion_matrix)
def confusion_matrix(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5
) -> torch.Tensor:
preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5
) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.confusion_matrix`. Will be removed in v1.5.0.
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/metrics/functional/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@
# limitations under the License.
from typing import Sequence, Union

import torch
from torch import Tensor
from torchmetrics.functional import explained_variance as _explained_variance

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_explained_variance)
def explained_variance(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
) -> Union[Tensor, Sequence[Tensor]]:
"""
.. deprecated::
Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0.
Expand Down
28 changes: 17 additions & 11 deletions pytorch_lightning/metrics/functional/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,44 @@
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor
from torchmetrics.functional import f1 as _f1
from torchmetrics.functional import fbeta as _fbeta

from pytorch_lightning.metrics.utils import deprecated_metrics
from pytorch_lightning.metrics.utils import (
_TORCHMETRICS_GREATER_EQUAL_0_4,
_TORCHMETRICS_LOWER_THAN_0_4,
deprecated_metrics,
)


@deprecated_metrics(target=_fbeta)
@deprecated_metrics(target=_fbeta, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(target=_fbeta, args_mapping={"multilabel": None}, skip_if=_TORCHMETRICS_LOWER_THAN_0_4)
def fbeta(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
beta: float = 1.0,
threshold: float = 0.5,
average: str = "micro",
multilabel: Optional[bool] = None
) -> torch.Tensor:
) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0.
Use :func:`torchmetrics.functional.fbeta`. Will be removed in v1.5.0.
"""


@deprecated_metrics(target=_f1)
@deprecated_metrics(target=_f1, skip_if=_TORCHMETRICS_GREATER_EQUAL_0_4)
@deprecated_metrics(target=_f1, args_mapping={"multilabel": None}, skip_if=_TORCHMETRICS_LOWER_THAN_0_4)
def f1(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
threshold: float = 0.5,
average: str = "micro",
multilabel: Optional[bool] = None
) -> torch.Tensor:
) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.f1`. Will be removed in v1.5.0.
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from torchmetrics.functional import hamming_distance as _hamming_distance

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_hamming_distance)
def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.functional.hamming_distance`. Will be removed in v1.5.0.
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/image_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor
from torchmetrics.functional import image_gradients as _image_gradients

from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated_metrics(target=_image_gradients)
def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
"""
.. deprecated::
Use :func:`torchmetrics.functional.image_gradients`. Will be removed in v1.5.0.
Expand Down
Loading

0 comments on commit afc69e4

Please sign in to comment.