TimeFrequency Estimator modifies parameters in constructor #10971

Closed
Dod12 opened this issue Jul 28, 2022 · 3 comments · Fixed by #11004

Dod12 (Contributor) commented Jul 28, 2022

Describe the bug

The `mne.decoding.TimeFrequency` transformer modifies its constructor arguments. This violates the scikit-learn convention that `__init__` stores each parameter unchanged and leaves validation to the methods that use it (`fit`, `transform`). As a result, `sklearn.base.clone` raises an error when the transformer is used in a pipeline, for example during cross-validation. I was able to resolve the issue by moving the `_check_tfr_param` call into the `transform` method, alongside the other checks already performed there. See the changes made to `mne/decoding/time_frequency.py`.
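To make the constraint concrete, here is a minimal sketch, assuming scikit-learn and NumPy are installed. The two transformers (`CopiesInInit`, `ValidatesLater`) and the helper `clones_cleanly` are hypothetical toys, not MNE code. `sklearn.base.clone` rebuilds an estimator from `get_params()` and then checks that every parameter on the new object is the very object it passed in, so coercing a parameter in `__init__` fails, while deferring the coercion to `transform` (the approach described above) passes.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class CopiesInInit(TransformerMixin, BaseEstimator):
    """Hypothetical toy transformer that coerces n_cycles in __init__."""

    def __init__(self, n_cycles=7.0):
        # np.array() always returns a new object, so the stored parameter is
        # no longer the object the caller (or clone) passed in.
        self.n_cycles = np.array(n_cycles, dtype=float)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X * self.n_cycles.mean()


class ValidatesLater(TransformerMixin, BaseEstimator):
    """Hypothetical toy transformer that stores n_cycles untouched."""

    def __init__(self, n_cycles=7.0):
        self.n_cycles = n_cycles  # stored exactly as given

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # Coerce into a local variable; the public parameter is left alone.
        n_cycles = np.array(self.n_cycles, dtype=float)
        return X * n_cycles.mean()


def clones_cleanly(estimator):
    """Return True if sklearn.base.clone accepts the estimator."""
    try:
        clone(estimator)
    except RuntimeError:
        return False
    return True


if __name__ == "__main__":
    print(clones_cleanly(CopiesInInit(n_cycles=np.array([1.0]))))    # False
    print(clones_cleanly(ValidatesLater(n_cycles=np.array([1.0]))))  # True
```

Keeping `__init__` to plain attribute assignment is what lets `clone`, `set_params`, and grid searches treat an estimator as a bag of hyperparameters; the trade-off is that invalid values are only reported once `fit` or `transform` runs.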

Steps to reproduce

```python
import mne
import numpy as np
from sklearn import pipeline, linear_model

tfr_data = np.ones((100, 10, 1000))  # epochs x channels x times, constant data

freqs = np.array([5.])

estimator = pipeline.make_pipeline(
    mne.decoding.TimeFrequency(freqs, 10, "morlet", freqs/5., output="power"),  # freqs, sfreq, method, n_cycles
    mne.decoding.Vectorizer(),
    linear_model.LogisticRegression(),
)

mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))  # random binary labels
```

Expected results

Successful completion of cross validation.

Actual results

```
Traceback (most recent call last):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 822, in dispatch_one_batch
    tasks = self._ready_batches.get(block=False)
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/queue.py", line 168, in get
    raise Empty
_queue.Empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-1fb399c6f0bb>", line 1, in <cell line: 1>
    runfile('error.py', wdir='/Users/daniel/Documents/Coding_Projects/GitHub/mne-python')
  File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "error.py", line 15, in <module>
    mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))
  File "<decorator-gen-447>", line 12, in cross_val_multiscore
  File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 435, in cross_val_multiscore
    scores = parallel(
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 1043, in __call__
    if self.dispatch_one_batch(iterator):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 833, in dispatch_one_batch
    islice = list(itertools.islice(iterator, big_batch_size))
  File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 437, in <genexpr>
    estimator=clone(estimator), X=X, y=y, scorer=scorer, train=train,
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 87, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 96, in clone
    raise RuntimeError(
RuntimeError: Cannot clone object TimeFrequency(None), as the constructor either does not set or modifies parameter n_cycles
```
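The final `RuntimeError` comes from scikit-learn's `clone`, which compares each rebuilt parameter with the one it passed to the constructor by identity (`is`), not by value. The standalone NumPy snippet below (an illustration, not MNE code) shows why an array-valued `n_cycles` such as `freqs/5.` can fail that check even when nothing changes numerically: a copying conversion like `np.array` yields an equal but distinct object. Whether `_check_tfr_param` uses exactly this call is an assumption; any conversion that returns a new object has the same effect.

```python
import numpy as np

freqs = np.array([5.0])
n_cycles = freqs / 5.0  # array([1.]), the n_cycles value used in the report

copied = np.array(n_cycles)            # np.array copies by default
passed_through = np.asarray(n_cycles)  # no conversion needed, so no copy

print(np.array_equal(copied, n_cycles))  # True: same values...
print(copied is n_cycles)                # False: ...but a different object
print(passed_through is n_cycles)        # True
```

Switching the conversion to `np.asarray` would not be a robust fix: a list passed as `n_cycles` would still be turned into a new array. Deferring validation to `transform`, as proposed in the report, avoids the problem for every input type.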

Additional information

Platform:      macOS-11.6.6-x86_64-i386-64bit
Python:        3.10.5 | packaged by conda-forge | (main, Jun 14 2022, 07:09:13) [Clang 13.0.1 ]
Executable:    /Users/daniel/miniconda3/envs/mne-python/bin/python
CPU:           i386: 4 cores
Memory:        16.0 GB
mne:           0.23.4
numpy:         1.22.4 {blas=NO_ATLAS_INFO, lapack=lapack}
scipy:         1.8.1
matplotlib:    3.5.2 {backend=module://backend_interagg}
sklearn:       1.1.1
numba:         0.55.2
nibabel:       4.0.1
nilearn:       0.6.2
dipy:          1.5.0
cupy:          Not found
pandas:        1.4.3
mayavi:        4.8.0
pyvista:       0.35.2 {pyvistaqt=0.9.0, OpenGL 4.1 ATI-4.6.21 via AMD Radeon R9 M295X OpenGL Engine}
vtk:           
PyQt5:         5.12.3
Dod12 added the BUG label Jul 28, 2022

welcome bot commented Jul 28, 2022

Hello! 👋 Thanks for opening your first issue here! ❤️ We will try to get back to you soon. 🚴🏽‍♂️

larsoner (Member) commented

@Dod12 Agreed, this seems like a bug. Would you be up for making a PR to fix it? The minimal example above is already a good start for a unit test!

larsoner added this to the 1.2 milestone Jul 28, 2022
Dod12 (Contributor, Author) commented Jul 29, 2022

@larsoner Sure, I'll work on the tests over the weekend.
