Traceback (most recent call last):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 822, in dispatch_one_batch
    tasks = self._ready_batches.get(block=False)
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/queue.py", line 168, in get
    raise Empty
_queue.Empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-1fb399c6f0bb>", line 1, in <cell line: 1>
    runfile('error.py', wdir='/Users/daniel/Documents/Coding_Projects/GitHub/mne-python')
  File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
  File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "error.py", line 15, in <module>
    mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))
  File "<decorator-gen-447>", line 12, in cross_val_multiscore
  File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 435, in cross_val_multiscore
    scores = parallel(
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 1043, in __call__
    if self.dispatch_one_batch(iterator):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 833, in dispatch_one_batch
    islice = list(itertools.islice(iterator, big_batch_size))
  File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 437, in <genexpr>
    estimator=clone(estimator), X=X, y=y, scorer=scorer, train=train,
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 87, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 96, in clone
    raise RuntimeError(
RuntimeError: Cannot clone object TimeFrequency(None), as the constructor either does not set or modifies parameter n_cycles
Describe the bug
The mne.decoding.TimeFrequency transformer modifies its constructor arguments, which violates scikit-learn's guidance on estimators. This leads to a cloning error when the transformer is used in a pipeline (here, during cross-validation with mne.decoding.cross_val_multiscore). I was able to resolve the issue by moving the _check_tfr_param call to the transform method, in line with the other checks performed at that time. See the changes made to mne/decoding/time_frequency.py.
Steps to reproduce
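The original error.py script is not included in this report; only its line 15, the cross_val_multiscore call, survives in the traceback. As a stand-in, the sketch below triggers the same scikit-learn check using a hypothetical BrokenTFR class (not MNE's code) whose constructor converts n_cycles into a new array. It assumes only numpy and scikit-learn are installed.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class BrokenTFR(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that rewrites a constructor argument."""

    def __init__(self, freqs=None, n_cycles=7.0):
        self.freqs = freqs
        # Anti-pattern: np.array() always returns a new array, so the stored
        # value is never the object that was passed in. clone() rebuilds the
        # estimator from get_params() and then checks each parameter with an
        # identity (`is`) test, so it refuses to clone this object.
        self.n_cycles = np.array(n_cycles, dtype=float)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


def clone_error(estimator):
    """Return the RuntimeError message raised by clone(), or None."""
    try:
        clone(estimator)
    except RuntimeError as err:
        return str(err)
    return None


message = clone_error(BrokenTFR(freqs=np.arange(8.0, 13.0)))
print(message)
```

The printed message has the same form as the last line of the traceback: "Cannot clone object ..., as the constructor either does not set or modifies parameter n_cycles".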
Expected results
Successful completion of cross validation.
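A minimal sketch of the approach described above, under stated assumptions: FixedTFR is a hypothetical class, its flattening step is a placeholder rather than a real time-frequency decomposition, and scikit-learn's cross_val_score stands in for mne.decoding.cross_val_multiscore. The constructor stores n_cycles unchanged and the conversion happens inside transform, so clone() succeeds and cross-validation runs to completion.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline


class FixedTFR(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that defers validation to transform()."""

    def __init__(self, freqs=None, n_cycles=7.0):
        # Store constructor arguments exactly as given; no conversion here.
        self.freqs = freqs
        self.n_cycles = n_cycles

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # Validate and convert into a local variable at call time, leaving
        # the public parameter untouched (stand-in for moving the check here).
        n_cycles = np.array(self.n_cycles, dtype=float)
        if np.any(n_cycles <= 0):
            raise ValueError("n_cycles must be positive")
        # Placeholder feature map: flatten each epoch to a feature vector.
        return np.asarray(X, dtype=float).reshape(len(X), -1)


rng = np.random.default_rng(0)
X = rng.standard_normal((100, 4, 16))  # (n_epochs, n_channels, n_times)
y = rng.binomial(1, 0.5, 100)

pipe = make_pipeline(FixedTFR(freqs=np.arange(8.0, 13.0)), LogisticRegression())
clone(pipe)  # no RuntimeError: every parameter survives the identity check
scores = cross_val_score(pipe, X, y, cv=5)
print(scores)
```

Keeping __init__ to plain attribute assignment is what lets clone() rebuild the estimator from get_params(); validation that converts or normalizes a parameter belongs in fit or transform, stored in a local variable or a trailing-underscore attribute.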
Actual results
Additional information