Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add check_data mode to deal with running with/without nan values #182

Merged
merged 5 commits into from
Aug 9, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@
The maximum number of calls to the curve fitting function.
_error_metric : str
The error metric to use for post-hoc measures of model fit error.

Run Modes
---------
_debug : bool
Whether the object is set in debug mode.
This should be controlled by using the `set_debug_mode` method.
_check_data : bool
Whether to check added data for NaN or Inf values, and fail out if present.
This should be controlled by using the `set_check_data_mode` method.

Code Notes
----------
Expand Down Expand Up @@ -184,12 +190,16 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
# The maximum number of calls to the curve fitting function
self._maxfev = 5000
# The error metric to calculate, post model fitting. See `_calc_error` for options
# Note: this is used to check error post-hoc, not an objective function for fitting models
# Note: this is for checking error post fitting, not an objective function for fitting
self._error_metric = 'MAE'
# Set whether in debug mode, in which an error is raised if a model fit fails

## RUN MODES
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
self._debug = False
# Set default check data mode - controls if an error is raised if NaN / Inf data are added
self._check_data = True

# Set internal settings, based on inputs, & initialize data & results attributes
# Set internal settings, based on inputs, and initialize data & results attributes
self._reset_internal_settings()
self._reset_data_results(True, True, True)

Expand Down Expand Up @@ -311,7 +321,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None):
self._reset_data_results(True, True, True)

self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose)
self._prepare_data(freqs, power_spectrum, freq_range, 1)


def add_settings(self, fooof_settings):
Expand Down Expand Up @@ -431,6 +441,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
# In rare cases, the model fails to fit, and so uses try / except
try:

# If not set to fail on NaN or Inf data at add time, check data here
# This serves as a catch all for curve_fits which will fail given NaN or Inf
# Because FitErrors are caught by default, this allows fitting to continue
if not self._check_data:
if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)):
raise FitError("There are NaN or Inf values in the data, "
ryanhammonds marked this conversation as resolved.
Show resolved Hide resolved
"which preclude model fitting.")

# Fit the aperiodic component
self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
Expand Down Expand Up @@ -674,7 +692,7 @@ def copy(self):


def set_debug_mode(self, debug):
"""Set whether debug mode, wherein an error is raised if fitting is unsuccessful.
"""Set debug mode, which controls if an error is raised if model fitting is unsuccessful.

Parameters
----------
Expand All @@ -685,6 +703,18 @@ def set_debug_mode(self, debug):
self._debug = debug


def set_check_data_mode(self, check_data):
    """Turn check data mode on or off.

    Check data mode controls whether an error is raised if NaN or Inf data are added.

    Parameters
    ----------
    check_data : bool
        Whether to run in check data mode.
    """

    # Only the flag is stored here; it is consulted wherever data are added or fit
    self._check_data = check_data


def _check_width_limits(self):
"""Check and warn about peak width limits / frequency resolution interaction."""

Expand Down Expand Up @@ -1101,8 +1131,7 @@ def _calc_error(self, metric=None):
raise ValueError(msg)


@staticmethod
def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True):
def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
"""Prepare input data for adding to current object.

Parameters
Expand All @@ -1116,8 +1145,6 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
Frequency range to restrict power spectrum to. If None, keeps the entire range.
spectra_dim : int, optional, default: 1
Dimensionality that the power spectra should have.
verbose : bool, optional
Whether to be verbose in printing out warnings.

Returns
-------
Expand Down Expand Up @@ -1172,7 +1199,7 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
# Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
if freqs[0] == 0.0:
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
if verbose:
if self.verbose:
print("\nFOOOF WARNING: Skipping frequency == 0, "
"as this causes a problem with fitting.")

Expand All @@ -1183,12 +1210,13 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
# Log power values
power_spectrum = np.log10(power_spectrum)

# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")
if self._check_data:
# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")

return freqs, power_spectrum, freq_range, freq_res

Expand Down
2 changes: 1 addition & 1 deletion fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def add_data(self, freqs, power_spectra, freq_range=None):
self._reset_group_results()

self.freqs, self.power_spectra, self.freq_range, self.freq_res = \
self._prepare_data(freqs, power_spectra, freq_range, 2, self.verbose)
self._prepare_data(freqs, power_spectra, freq_range, 2)


def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
Expand Down
23 changes: 21 additions & 2 deletions fooof/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fooof.core.items import OBJ_DESC
from fooof.core.errors import FitError
from fooof.core.utils import group_three
from fooof.sim import gen_power_spectrum
from fooof.sim import gen_freqs, gen_power_spectrum
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
from fooof.core.errors import DataError, NoDataError, InconsistentDataError

Expand Down Expand Up @@ -364,7 +364,7 @@ def raise_runtime_error(*args, **kwargs):
assert np.all(np.isnan(getattr(tfm, result)))

def test_fooof_debug():
"""Test FOOOF fit failure in debug mode."""
"""Test FOOOF in debug mode, including with fit failures."""

tfm = FOOOF(verbose=False)
tfm._maxfev = 5
Expand All @@ -374,3 +374,22 @@ def test_fooof_debug():

with raises(FitError):
tfm.fit(*gen_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4]))

def test_fooof_check_data():
    """Test FOOOF with check data mode turned off, including with NaN data."""

    tfm = FOOOF(verbose=False)

    # Disable the data check, and confirm that the setting took effect
    tfm.set_check_data_mode(False)
    assert tfm._check_data is False

    # With check data mode turned off, adding data containing NaN should run without error
    nan_freqs = gen_freqs([3, 50], 0.5)
    nan_powers = np.ones_like(nan_freqs) * np.nan
    tfm.add_data(nan_freqs, nan_powers)
    assert tfm.has_data

    # Model fitting should execute without failing, but give a null model fit, given the NaNs
    tfm.fit()
    assert not tfm.has_model