Skip to content

Commit

Permalink
Merge pull request #182 from fooof-tools/nans
Browse files Browse the repository at this point in the history
[ENH] Add `check_data` mode to deal with running with/without nan values
  • Loading branch information
TomDonoghue authored Aug 9, 2020
2 parents 1c3b030 + 815ae18 commit 7a520c4
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 21 deletions.
63 changes: 46 additions & 17 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@
The maximum number of calls to the curve fitting function.
_error_metric : str
The error metric to use for post-hoc measures of model fit error.
Run Modes
---------
_debug : bool
Whether the object is set in debug mode.
This should be controlled by using the `set_debug_mode` method.
_check_data : bool
Whether to check added data for NaN or Inf values, and fail out if present.
This should be controlled by using the `set_check_data_mode` method.
Code Notes
----------
Expand Down Expand Up @@ -184,12 +190,16 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
# The maximum number of calls to the curve fitting function
self._maxfev = 5000
# The error metric to calculate, post model fitting. See `_calc_error` for options
# Note: this is used to check error post-hoc, not an objective function for fitting models
# Note: this is for checking error post fitting, not an objective function for fitting
self._error_metric = 'MAE'
# Set whether in debug mode, in which an error is raised if a model fit fails

## RUN MODES
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
self._debug = False
# Set default check data mode - controls if an error is raised if NaN / Inf data are added
self._check_data = True

# Set internal settings, based on inputs, & initialize data & results attributes
# Set internal settings, based on inputs, and initialize data & results attributes
self._reset_internal_settings()
self._reset_data_results(True, True, True)

Expand Down Expand Up @@ -312,7 +322,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
clear_results=self.has_model and clear_results)

self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose)
self._prepare_data(freqs, power_spectrum, freq_range, 1)


def add_settings(self, fooof_settings):
Expand Down Expand Up @@ -432,6 +442,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
# In rare cases, the model fails to fit, and so uses try / except
try:

# If not set to fail on NaN or Inf data at add time, check data here
# This serves as a catch all for curve_fits which will fail given NaN or Inf
# Because FitError's are by default caught, this allows fitting to continue
if not self._check_data:
if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)):
raise FitError("Model fitting was skipped because there are NaN or Inf "
"values in the data, which preclude model fitting.")

# Fit the aperiodic component
self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
Expand Down Expand Up @@ -675,7 +693,7 @@ def copy(self):


def set_debug_mode(self, debug):
"""Set whether debug mode, wherein an error is raised if fitting is unsuccessful.
"""Set debug mode, which controls if an error is raised if model fitting is unsuccessful.
Parameters
----------
Expand All @@ -686,6 +704,18 @@ def set_debug_mode(self, debug):
self._debug = debug


def set_check_data_mode(self, check_data):
"""Set check data mode, which controls if an error is raised if NaN or Inf data are added.
Parameters
----------
check_data : bool
Whether to run in check data mode.
"""

self._check_data = check_data


def _check_width_limits(self):
"""Check and warn about peak width limits / frequency resolution interaction."""

Expand Down Expand Up @@ -795,7 +825,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
raise FitError("Model fitting failed due to not finding "
"parameters in the robust aperiodic fit.")
except TypeError:
raise FitError("Model fitting failed due to sub-sampling in the robust aperiodic fit.")
raise FitError("Model fitting failed due to sub-sampling "
"in the robust aperiodic fit.")

return aperiodic_params

Expand Down Expand Up @@ -1110,8 +1141,7 @@ def _calc_error(self, metric=None):
raise ValueError(msg)


@staticmethod
def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True):
def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
"""Prepare input data for adding to current object.
Parameters
Expand All @@ -1125,8 +1155,6 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
Frequency range to restrict power spectrum to. If None, keeps the entire range.
spectra_dim : int, optional, default: 1
Dimensionality that the power spectra should have.
verbose : bool, optional
Whether to be verbose in printing out warnings.
Returns
-------
Expand Down Expand Up @@ -1181,7 +1209,7 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
# Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
if freqs[0] == 0.0:
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
if verbose:
if self.verbose:
print("\nFOOOF WARNING: Skipping frequency == 0, "
"as this causes a problem with fitting.")

Expand All @@ -1192,12 +1220,13 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
# Log power values
power_spectrum = np.log10(power_spectrum)

# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")
if self._check_data:
# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")

return freqs, power_spectrum, freq_range, freq_res

Expand Down
5 changes: 3 additions & 2 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def add_data(self, freqs, power_spectra, freq_range=None):
self._reset_group_results()

self.freqs, self.power_spectra, self.freq_range, self.freq_res = \
self._prepare_data(freqs, power_spectra, freq_range, 2, self.verbose)
self._prepare_data(freqs, power_spectra, freq_range, 2)


def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
Expand Down Expand Up @@ -476,8 +476,9 @@ def get_fooof(self, ind, regenerate=True):
The FOOOFResults data loaded into a FOOOF object.
"""

# Initialize a FOOOF object, with same settings as current FOOOFGroup
# Initialize a FOOOF object, with same settings & check data mode as current FOOOFGroup
fm = FOOOF(*self.get_settings(), verbose=self.verbose)
fm.set_check_data_mode(self._check_data)

# Add data for specified single power spectrum, if available
# The power spectrum is inverted back to linear, as it is re-logged when added to FOOOF
Expand Down
3 changes: 3 additions & 0 deletions fooof/objs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def combine_fooofs(fooofs):
if len(fg) == temp_power_spectra.shape[0]:
fg.power_spectra = temp_power_spectra

# Set the check data mode, as True if any of the inputs have it on, False otherwise
fg.set_check_data_mode(any([getattr(f_obj, '_check_data') for f_obj in fooofs]))

# Add data information information
fg.add_meta_data(fooofs[0].get_meta_data())

Expand Down
23 changes: 21 additions & 2 deletions fooof/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fooof.core.items import OBJ_DESC
from fooof.core.errors import FitError
from fooof.core.utils import group_three
from fooof.sim import gen_power_spectrum
from fooof.sim import gen_freqs, gen_power_spectrum
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
from fooof.core.errors import DataError, NoDataError, InconsistentDataError

Expand Down Expand Up @@ -396,7 +396,7 @@ def raise_runtime_error(*args, **kwargs):
assert np.all(np.isnan(getattr(tfm, result)))

def test_fooof_debug():
"""Test FOOOF fit failure in debug mode."""
"""Test FOOOF in debug mode, including with fit failures."""

tfm = FOOOF(verbose=False)
tfm._maxfev = 5
Expand All @@ -406,3 +406,22 @@ def test_fooof_debug():

with raises(FitError):
tfm.fit(*gen_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4]))

def test_fooof_check_data():
"""Test FOOOF in with check data mode turned off, including with NaN data."""

tfm = FOOOF(verbose=False)

tfm.set_check_data_mode(False)
assert tfm._check_data is False

# Add data, with check data turned off
# In check data mode, adding data with NaN should run
freqs = gen_freqs([3, 50], 0.5)
powers = np.ones_like(freqs) * np.nan
tfm.add_data(freqs, powers)
assert tfm.has_data

# Model fitting should execute, but return a null model fit, given the NaNs, without failing
tfm.fit()
assert not tfm.has_model

0 comments on commit 7a520c4

Please sign in to comment.