Skip to content

Commit

Permalink
add nan_strategy "disable" to disable nan checks
Browse files Browse the repository at this point in the history
  • Loading branch information
crand-mbe committed Feb 4, 2025
1 parent 767026e commit 11868d8
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BaseAggregator(Metric):
Raises:
ValueError:
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
"""

Expand All @@ -62,7 +62,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
allowed_nan_strategy = ("error", "warn", "ignore")
allowed_nan_strategy = ("error", "warn", "ignore", "disable")
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
raise ValueError(
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
Expand All @@ -81,25 +81,26 @@ def _cast_and_nan_check_input(
if weight is not None and not isinstance(weight, Tensor):
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)

nans = torch.isnan(x)
if weight is not None:
nans_weight = torch.isnan(weight)
else:
nans_weight = torch.zeros_like(nans).bool()
weight = torch.ones_like(x)
if nans.any() or nans_weight.any():
if self.nan_strategy == "error":
raise RuntimeError("Encountered `nan` values in tensor")
if self.nan_strategy in ("ignore", "warn"):
if self.nan_strategy == "warn":
rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
x = x[~(nans | nans_weight)]
weight = weight[~(nans | nans_weight)]
if self.nan_strategy != "disable":
nans = torch.isnan(x)
if weight is not None:
nans_weight = torch.isnan(weight)
else:
if not isinstance(self.nan_strategy, float):
raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
x[nans | nans_weight] = self.nan_strategy
weight[nans | nans_weight] = self.nan_strategy
nans_weight = torch.zeros_like(nans).bool()
weight = torch.ones_like(x)
if nans.any() or nans_weight.any():
if self.nan_strategy == "error":
raise RuntimeError("Encountered `nan` values in tensor")
if self.nan_strategy in ("ignore", "warn"):
if self.nan_strategy == "warn":
rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
x = x[~(nans | nans_weight)]
weight = weight[~(nans | nans_weight)]
else:
if not isinstance(self.nan_strategy, float):
raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
x[nans | nans_weight] = self.nan_strategy
weight[nans | nans_weight] = self.nan_strategy

return x.to(self.dtype), weight.to(self.dtype)

Expand Down Expand Up @@ -543,22 +544,24 @@ def __init__(
)
self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")

def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None:
def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = None) -> None:
"""Update state with data.
Args:
value: Either a float or tensor containing data. Additional tensor
dimensions will be flattened
weight: Either a float or tensor containing weights for calculating
the average. Shape of weight should be able to broadcast with
the shape of `value`. Default to `1.0` corresponding to simple
the shape of `value`. Default to None corresponding to simple
harmonic average.
"""
# broadcast weight to value shape
if not isinstance(value, Tensor):
value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
if weight is not None and not isinstance(weight, Tensor):
if weight is None:
weight = torch.ones_like(value)
elif not isinstance(weight, Tensor):
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
weight = torch.broadcast_to(weight, value.shape)
value, weight = self._cast_and_nan_check_input(value, weight)
Expand Down

0 comments on commit 11868d8

Please sign in to comment.