Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Remove _check_tfr_params from Transformer constructor #11004

Merged
merged 12 commits into from
Aug 15, 2022
1 change: 1 addition & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Enhancements

Bugs
~~~~
- Fix bug in :func:`mne.decoding.TimeFrequency` that prevented cloning if constructor arguments were modified (:gh:`11004` by :newcontrib:`Daniel Carlström Schad`)
- Document ``height`` and ``weight`` keys of ``subject_info`` entry in :class:`mne.Info` (:gh:`11019` by :newcontrib:`Sena Er`)

API changes
Expand Down
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@

.. _Dan Wakeman: https://github.com/dgwakeman

.. _Daniel Carlström Schad: https://github.com/Dod12

.. _Daniel McCloy: http://dan.mccloy.info

.. _Daniel Strohmeier: https://github.com/joewalter
Expand Down
5 changes: 5 additions & 0 deletions mne/decoding/tests/test_time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def test_timefrequency():
pytest.raises(ValueError, TimeFrequency, freqs, output=output)
tf = clone(tf)

# Clone estimator
freqs_array = np.array(np.asarray(freqs))
tf = TimeFrequency(freqs_array, 100, "morlet", freqs_array / 5.)
clone(tf)

# Fit
n_epochs, n_chans, n_times = 10, 2, 100
X = np.random.rand(n_epochs, n_chans, n_times)
Expand Down
19 changes: 11 additions & 8 deletions mne/decoding/time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from .mixin import TransformerMixin
from .base import BaseEstimator
from ..time_frequency.tfr import _compute_tfr, _check_tfr_param
from ..time_frequency.tfr import _compute_tfr
from ..utils import fill_doc, _check_option, verbose


Expand Down Expand Up @@ -62,10 +62,12 @@ class TimeFrequency(TransformerMixin, BaseEstimator):
@verbose
def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0,
time_bandwidth=None, use_fft=True, decim=1, output='complex',
n_jobs=None, *, verbose=None): # noqa: D102
freqs, sfreq, _, n_cycles, time_bandwidth, decim = \
_check_tfr_param(freqs, sfreq, method, True, n_cycles,
time_bandwidth, use_fft, decim, output)
n_jobs=1, verbose=None): # noqa: D102
"""Init TimeFrequency transformer."""
# Check non-average output
output = _check_option('output', output,
['complex', 'power', 'phase'])

self.freqs = freqs
self.sfreq = sfreq
self.method = method
Expand All @@ -74,9 +76,9 @@ def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0,
self.use_fft = use_fft
self.decim = decim
# Check that output is not an average metric (e.g. ITC)
self.output = _check_option('output', output,
['complex', 'power', 'phase'])
self.output = output
self.n_jobs = n_jobs
self.verbose = verbose

def fit_transform(self, X, y=None):
"""Time-frequency transform of times series along the last axis.
Expand Down Expand Up @@ -137,7 +139,8 @@ def transform(self, X):
# Compute time-frequency
Xt = _compute_tfr(X, self.freqs, self.sfreq, self.method,
self.n_cycles, True, self.time_bandwidth,
self.use_fft, self.decim, self.output, self.n_jobs)
self.use_fft, self.decim, self.output, self.n_jobs,
self.verbose)

# Back to original shape
if not shape:
Expand Down