Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dropna #178

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions s2spy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,16 @@ def __init__( # noqa: PLR0913
self._trend: dict
self._is_fit = False

def fit(self, data: Union[xr.DataArray, xr.Dataset]) -> None:
def fit(self, data: Union[xr.DataArray, xr.Dataset], dropna=False) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def fit(self, data: Union[xr.DataArray, xr.Dataset], dropna=False) -> None:
def fit(self, data: Union[xr.DataArray, xr.Dataset], dropna: bool = False) -> None:

Also add a type annotation for `dropna` here to satisfy mypy.

"""Fit this Preprocessor to input data.

Args:
data: Input data for fitting.
dropna: If True, drop all NaN values from the data before preprocessing.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dropna: If True, drop all NaN values from the data before preprocessing.
dropna: If True, drop all NaN values from the data along the time axis before preprocessing.

Point out that this dropna only operates along the time axis.

"""
_check_input_data(data)
if dropna:
data = data.dropna("time")
if self._window_size not in [None, 1]:
data_rolling = data.rolling(
dim={"time": self._window_size}, # type: ignore
Expand All @@ -228,7 +231,6 @@ def fit(self, data: Union[xr.DataArray, xr.Dataset]) -> None:

if self._subtract_climatology:
self._climatology = _get_climatology(data_rolling, self._timescale)

if self._detrend is not None:
if self._subtract_climatology:
deseasonalized = _subtract_climatology(
Expand All @@ -241,16 +243,20 @@ def fit(self, data: Union[xr.DataArray, xr.Dataset]) -> None:
self._is_fit = True

def transform(
self, data: Union[xr.DataArray, xr.Dataset]
self, data: Union[xr.DataArray, xr.Dataset], dropna=False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self, data: Union[xr.DataArray, xr.Dataset], dropna=False
self, data: Union[xr.DataArray, xr.Dataset], dropna: bool = False

) -> Union[xr.DataArray, xr.Dataset]:
"""Apply the preprocessing steps to the input data.

Args:
data: Input data to perform preprocessing.
dropna: If True, drop all NaN values from the data before preprocessing.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dropna: If True, drop all NaN values from the data before preprocessing.
dropna: If True, drop all NaN values from the data along the time axis before preprocessing.


Returns:
Preprocessed data.
"""
if dropna:
data = data.dropna("time")

if not self._is_fit:
raise ValueError(
"The preprocessor has to be fit to data before a transform"
Expand Down
Loading