Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add check_data mode to deal with running with/without nan values #182

Merged
merged 5 commits into from
Aug 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 46 additions & 17 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@
The maximum number of calls to the curve fitting function.
_error_metric : str
The error metric to use for post-hoc measures of model fit error.

Run Modes
---------
_debug : bool
Whether the object is set in debug mode.
This should be controlled by using the `set_debug_mode` method.
_check_data : bool
Whether to check added data for NaN or Inf values, and fail out if present.
This should be controlled by using the `set_check_data_mode` method.

Code Notes
----------
Expand Down Expand Up @@ -184,12 +190,16 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
# The maximum number of calls to the curve fitting function
self._maxfev = 5000
# The error metric to calculate, post model fitting. See `_calc_error` for options
# Note: this is used to check error post-hoc, not an objective function for fitting models
# Note: this is for checking error post fitting, not an objective function for fitting
self._error_metric = 'MAE'
# Set whether in debug mode, in which an error is raised if a model fit fails

## RUN MODES
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
self._debug = False
# Set default check data mode - controls if an error is raised if NaN / Inf data are added
self._check_data = True

# Set internal settings, based on inputs, & initialize data & results attributes
# Set internal settings, based on inputs, and initialize data & results attributes
self._reset_internal_settings()
self._reset_data_results(True, True, True)

Expand Down Expand Up @@ -311,7 +321,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None):
self._reset_data_results(True, True, True)

self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose)
self._prepare_data(freqs, power_spectrum, freq_range, 1)


def add_settings(self, fooof_settings):
Expand Down Expand Up @@ -431,6 +441,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
# In rare cases, the model fails to fit, and so uses try / except
try:

# If not set to fail on NaN or Inf data at add time, check data here
# This serves as a catch all for curve_fits which will fail given NaN or Inf
# Because FitError's are by default caught, this allows fitting to continue
if not self._check_data:
if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)):
raise FitError("Model fitting was skipped because there are NaN or Inf "
"values in the data, which preclude model fitting.")

# Fit the aperiodic component
self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
Expand Down Expand Up @@ -674,7 +692,7 @@ def copy(self):


def set_debug_mode(self, debug):
"""Set whether debug mode, wherein an error is raised if fitting is unsuccessful.
"""Set debug mode, which controls if an error is raised if model fitting is unsuccessful.

Parameters
----------
Expand All @@ -685,6 +703,18 @@ def set_debug_mode(self, debug):
self._debug = debug


def set_check_data_mode(self, check_data):
"""Set check data mode, which controls if an error is raised if NaN or Inf data are added.

Parameters
----------
check_data : bool
Whether to run in check data mode.
"""

self._check_data = check_data


def _check_width_limits(self):
"""Check and warn about peak width limits / frequency resolution interaction."""

Expand Down Expand Up @@ -786,7 +816,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
raise FitError("Model fitting failed due to not finding "
"parameters in the robust aperiodic fit.")
except TypeError:
raise FitError("Model fitting failed due to sub-sampling in the robust aperiodic fit.")
raise FitError("Model fitting failed due to sub-sampling "
"in the robust aperiodic fit.")

return aperiodic_params

Expand Down Expand Up @@ -1101,8 +1132,7 @@ def _calc_error(self, metric=None):
raise ValueError(msg)


@staticmethod
def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True):
def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
"""Prepare input data for adding to current object.

Parameters
Expand All @@ -1116,8 +1146,6 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
Frequency range to restrict power spectrum to. If None, keeps the entire range.
spectra_dim : int, optional, default: 1
Dimensionality that the power spectra should have.
verbose : bool, optional
Whether to be verbose in printing out warnings.

Returns
-------
Expand Down Expand Up @@ -1172,7 +1200,7 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
# Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
if freqs[0] == 0.0:
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
if verbose:
if self.verbose:
print("\nFOOOF WARNING: Skipping frequency == 0, "
"as this causes a problem with fitting.")

Expand All @@ -1183,12 +1211,13 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
# Log power values
power_spectrum = np.log10(power_spectrum)

# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")
if self._check_data:
# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")

return freqs, power_spectrum, freq_range, freq_res

Expand Down
5 changes: 3 additions & 2 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def add_data(self, freqs, power_spectra, freq_range=None):
self._reset_group_results()

self.freqs, self.power_spectra, self.freq_range, self.freq_res = \
self._prepare_data(freqs, power_spectra, freq_range, 2, self.verbose)
self._prepare_data(freqs, power_spectra, freq_range, 2)


def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
Expand Down Expand Up @@ -476,8 +476,9 @@ def get_fooof(self, ind, regenerate=True):
The FOOOFResults data loaded into a FOOOF object.
"""

# Initialize a FOOOF object, with same settings as current FOOOFGroup
# Initialize a FOOOF object, with same settings & check data mode as current FOOOFGroup
fm = FOOOF(*self.get_settings(), verbose=self.verbose)
fm.set_check_data_mode(self._check_data)

# Add data for specified single power spectrum, if available
# The power spectrum is inverted back to linear, as it is re-logged when added to FOOOF
Expand Down
3 changes: 3 additions & 0 deletions fooof/objs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def combine_fooofs(fooofs):
if len(fg) == temp_power_spectra.shape[0]:
fg.power_spectra = temp_power_spectra

# Set the check data mode, as True if any of the inputs have it on, False otherwise
fg.set_check_data_mode(any([getattr(f_obj, '_check_data') for f_obj in fooofs]))

# Add data information information
fg.add_meta_data(fooofs[0].get_meta_data())

Expand Down
23 changes: 21 additions & 2 deletions fooof/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fooof.core.items import OBJ_DESC
from fooof.core.errors import FitError
from fooof.core.utils import group_three
from fooof.sim import gen_power_spectrum
from fooof.sim import gen_freqs, gen_power_spectrum
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
from fooof.core.errors import DataError, NoDataError, InconsistentDataError

Expand Down Expand Up @@ -364,7 +364,7 @@ def raise_runtime_error(*args, **kwargs):
assert np.all(np.isnan(getattr(tfm, result)))

def test_fooof_debug():
"""Test FOOOF fit failure in debug mode."""
"""Test FOOOF in debug mode, including with fit failures."""

tfm = FOOOF(verbose=False)
tfm._maxfev = 5
Expand All @@ -374,3 +374,22 @@ def test_fooof_debug():

with raises(FitError):
tfm.fit(*gen_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4]))

def test_fooof_check_data():
"""Test FOOOF in with check data mode turned off, including with NaN data."""

tfm = FOOOF(verbose=False)

tfm.set_check_data_mode(False)
assert tfm._check_data is False

# Add data, with check data turned off
# In check data mode, adding data with NaN should run
freqs = gen_freqs([3, 50], 0.5)
powers = np.ones_like(freqs) * np.nan
tfm.add_data(freqs, powers)
assert tfm.has_data

# Model fitting should execute, but return a null model fit, given the NaNs, without failing
tfm.fit()
assert not tfm.has_model