diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index e5b007b7..7636f34f 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -39,9 +39,15 @@ The maximum number of calls to the curve fitting function. _error_metric : str The error metric to use for post-hoc measures of model fit error. + +Run Modes +--------- _debug : bool Whether the object is set in debug mode. This should be controlled by using the `set_debug_mode` method. +_check_data : bool + Whether to check added data for NaN or Inf values, and fail out if present. + This should be controlled by using the `set_check_data_mode` method. Code Notes ---------- @@ -184,12 +190,16 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h # The maximum number of calls to the curve fitting function self._maxfev = 5000 # The error metric to calculate, post model fitting. See `_calc_error` for options - # Note: this is used to check error post-hoc, not an objective function for fitting models + # Note: this is for checking error post fitting, not an objective function for fitting self._error_metric = 'MAE' - # Set whether in debug mode, in which an error is raised if a model fit fails + + ## RUN MODES + # Set default debug mode - controls if an error is raised if model fitting is unsuccessful self._debug = False + # Set default check data mode - controls if an error is raised if NaN / Inf data are added + self._check_data = True - # Set internal settings, based on inputs, & initialize data & results attributes + # Set internal settings, based on inputs, and initialize data & results attributes self._reset_internal_settings() self._reset_data_results(True, True, True) @@ -311,7 +321,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None): self._reset_data_results(True, True, True) self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose) + self._prepare_data(freqs, power_spectrum, freq_range, 1) 
def add_settings(self, fooof_settings): @@ -431,6 +441,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # In rare cases, the model fails to fit, and so uses try / except try: + # If not set to fail on NaN or Inf data at add time, check data here + # This serves as a catch all for curve_fits which will fail given NaN or Inf + # Because FitError's are by default caught, this allows fitting to continue + if not self._check_data: + if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)): + raise FitError("Model fitting was skipped because there are NaN or Inf " + "values in the data, which preclude model fitting.") + # Fit the aperiodic component self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) @@ -674,7 +692,7 @@ def copy(self): def set_debug_mode(self, debug): - """Set whether debug mode, wherein an error is raised if fitting is unsuccessful. + """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. Parameters ---------- @@ -685,6 +703,18 @@ def set_debug_mode(self, debug): self._debug = debug + def set_check_data_mode(self, check_data): + """Set check data mode, which controls if an error is raised if NaN or Inf data are added. + + Parameters + ---------- + check_data : bool + Whether to run in check data mode. 
+ """ + + self._check_data = check_data + + def _check_width_limits(self): """Check and warn about peak width limits / frequency resolution interaction.""" @@ -786,7 +816,8 @@ def _robust_ap_fit(self, freqs, power_spectrum): raise FitError("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") except TypeError: - raise FitError("Model fitting failed due to sub-sampling in the robust aperiodic fit.") + raise FitError("Model fitting failed due to sub-sampling " + "in the robust aperiodic fit.") return aperiodic_params @@ -1101,8 +1132,7 @@ def _calc_error(self, metric=None): raise ValueError(msg) - @staticmethod - def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True): + def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): """Prepare input data for adding to current object. Parameters @@ -1116,8 +1146,6 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True Frequency range to restrict power spectrum to. If None, keeps the entire range. spectra_dim : int, optional, default: 1 Dimensionality that the power spectra should have. - verbose : bool, optional - Whether to be verbose in printing out warnings. 
Returns ------- @@ -1172,7 +1200,7 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) - if verbose: + if self.verbose: print("\nFOOOF WARNING: Skipping frequency == 0, " "as this causes a problem with fitting.") @@ -1183,12 +1211,13 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True # Log power values power_spectrum = np.log10(power_spectrum) - # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): - raise DataError("The input power spectra data, after logging, contains NaNs or Infs. " - "This will cause the fitting to fail. " - "One reason this can happen is if inputs are already logged. " - "Inputs data should be in linear spacing, not log.") + if self._check_data: + # Check if there are any infs / nans, and raise an error if so + if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): + raise DataError("The input power spectra data, after logging, contains NaNs or Infs. " + "This will cause the fitting to fail. " + "One reason this can happen is if inputs are already logged. 
" + "Inputs data should be in linear spacing, not log.") return freqs, power_spectrum, freq_range, freq_res diff --git a/fooof/objs/group.py b/fooof/objs/group.py index 8bd293d3..c008aa03 100644 --- a/fooof/objs/group.py +++ b/fooof/objs/group.py @@ -222,7 +222,7 @@ def add_data(self, freqs, power_spectra, freq_range=None): self._reset_group_results() self.freqs, self.power_spectra, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectra, freq_range, 2, self.verbose) + self._prepare_data(freqs, power_spectra, freq_range, 2) def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): @@ -476,8 +476,9 @@ def get_fooof(self, ind, regenerate=True): The FOOOFResults data loaded into a FOOOF object. """ - # Initialize a FOOOF object, with same settings as current FOOOFGroup + # Initialize a FOOOF object, with same settings & check data mode as current FOOOFGroup fm = FOOOF(*self.get_settings(), verbose=self.verbose) + fm.set_check_data_mode(self._check_data) # Add data for specified single power spectrum, if available # The power spectrum is inverted back to linear, as it is re-logged when added to FOOOF diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py index 253a4f13..a2e02938 100644 --- a/fooof/objs/utils.py +++ b/fooof/objs/utils.py @@ -178,6 +178,9 @@ def combine_fooofs(fooofs): if len(fg) == temp_power_spectra.shape[0]: fg.power_spectra = temp_power_spectra + # Set the check data mode, as True if any of the inputs have it on, False otherwise + fg.set_check_data_mode(any([getattr(f_obj, '_check_data') for f_obj in fooofs])) + # Add data information information fg.add_meta_data(fooofs[0].get_meta_data()) diff --git a/fooof/tests/objs/test_fit.py b/fooof/tests/objs/test_fit.py index 430b55ee..a212c311 100644 --- a/fooof/tests/objs/test_fit.py +++ b/fooof/tests/objs/test_fit.py @@ -12,7 +12,7 @@ from fooof.core.items import OBJ_DESC from fooof.core.errors import FitError from fooof.core.utils import group_three 
-from fooof.sim import gen_power_spectrum +from fooof.sim import gen_freqs, gen_power_spectrum from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults from fooof.core.errors import DataError, NoDataError, InconsistentDataError @@ -364,7 +364,7 @@ def raise_runtime_error(*args, **kwargs): assert np.all(np.isnan(getattr(tfm, result))) def test_fooof_debug(): - """Test FOOOF fit failure in debug mode.""" + """Test FOOOF in debug mode, including with fit failures.""" tfm = FOOOF(verbose=False) tfm._maxfev = 5 @@ -374,3 +374,22 @@ def test_fooof_debug(): with raises(FitError): tfm.fit(*gen_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + +def test_fooof_check_data(): + """Test FOOOF with check data mode turned off, including with NaN data.""" + + tfm = FOOOF(verbose=False) + + tfm.set_check_data_mode(False) + assert tfm._check_data is False + + # Add data, with check data turned off + # With check data mode off, adding data with NaN should run + freqs = gen_freqs([3, 50], 0.5) + powers = np.ones_like(freqs) * np.nan + tfm.add_data(freqs, powers) + assert tfm.has_data + + # Model fitting should execute, but return a null model fit, given the NaNs, without failing + tfm.fit() + assert not tfm.has_model