Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #13067 on branch maint/1.9 ([BUG] Fix taper weighting in computation of TFR multitaper power) #13072

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13067.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_.
11 changes: 9 additions & 2 deletions mne/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def export_raw(

%(export_warning)s

.. warning::
When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the
same as ``raw.annotations.orig_time``. This guarantees that the annotations are
in the same reference frame as the samples. When
:attr:`Raw.first_time <mne.io.Raw.first_time>` is not zero (e.g., after
cropping), the onsets are automatically corrected so that onsets are always
relative to the first sample.

Parameters
----------
%(fname_export_params)s
Expand Down Expand Up @@ -216,7 +224,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats):

supported_str = ", ".join(supported)
raise ValueError(
f"Format '{fmt}' is not supported. "
f"Supported formats are {supported_str}."
f"Format '{fmt}' is not supported. Supported formats are {supported_str}."
)
return fmt
16 changes: 0 additions & 16 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,22 +255,6 @@ def test_tfr_morlet():
# computed within the method.
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)

# test that averaging power across tapers when multitaper with
# output='complex' gives the same as output='power'
epoch_data = epochs.get_data()
multitaper_power = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
)
multitaper_complex = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
)

taper_dim = 2
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
axis=taper_dim
)
assert_allclose(power_from_complex, multitaper_power)

print(itc) # test repr
print(itc.ch_names) # test property
itc += power # test add
Expand Down
52 changes: 33 additions & 19 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _make_dpss(
The wavelets time series.
"""
Ws = list()
Cs = list()

freqs = np.array(freqs)
if np.any(freqs <= 0):
Expand All @@ -281,6 +282,7 @@ def _make_dpss(

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
Expand All @@ -302,12 +304,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])

Wm.append(Wk)
Cm.append(Ck)

Ws.append(Wm)
Cs.append(Cm)
if return_weights:
return Ws, conc
return Ws, Cs
return Ws


Expand Down Expand Up @@ -529,15 +534,18 @@ def _compute_tfr(
if method == "morlet":
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
Ws = [W] # to have same dimensionality as the 'multitaper' case
weights = None # no tapers for Morlet estimates

elif method == "multitaper":
Ws = _make_dpss(
Ws, weights = _make_dpss(
sfreq,
freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
zero_mean=zero_mean,
return_weights=True, # required for converting complex → power
)
weights = np.asarray(weights)

# Check wavelets
if len(Ws[0][0]) > epoch_data.shape[2]:
Expand All @@ -560,7 +568,7 @@ def _compute_tfr(
if ("avg_" in output) or ("itc" in output):
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ["complex", "phase"] and method == "multitaper":
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -571,7 +579,7 @@ def _compute_tfr(

# Parallelization is applied across channels.
tfrs = parallel(
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
for channel in epoch_data.transpose(1, 0, 2)
)

Expand All @@ -581,10 +589,8 @@ def _compute_tfr(

if ("avg_" not in output) and ("itc" not in output):
# This is to enforce that the first dimension is for epochs
if output in ["complex", "phase"] and method == "multitaper":
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)
out = np.moveaxis(out, 1, 0)

return out


Expand Down Expand Up @@ -658,7 +664,7 @@ def _check_tfr_param(
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim


def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
"""Aux. function to _compute_tfr.

Loops time-frequency transform across wavelets and epochs.
Expand All @@ -685,9 +691,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
See numpy.convolve.
decim : slice
The decimation slice: e.g. power[:, decim]
method : str | None
Used only for multitapering to create tapers dimension in the output
if ``output in ['complex', 'phase']``.
weights : array, shape (n_tapers, n_wavelets) | None
Concentration weights for each taper in the wavelets, if present.
"""
# Set output type
dtype = np.float64
Expand All @@ -701,10 +706,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
n_freqs = len(Ws[0])
if ("avg_" in output) or ("itc" in output):
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and method == "multitaper":
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and weights is not None:
tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype)
else:
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
if weights is not None:
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension

# Loops across tapers.
for taper_idx, W in enumerate(Ws):
Expand All @@ -719,6 +726,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Loop across epochs
for epoch_idx, tfr in enumerate(coefs):
# Transform complex values
if output not in ["complex", "phase"] and weights is not None:
tfr = weights[taper_idx] * tfr # weight each taper estimate
if output in ["power", "avg_power"]:
tfr = (tfr * tfr.conj()).real # power
elif output == "phase":
Expand All @@ -734,8 +743,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Stack or add
if ("avg_" in output) or ("itc" in output):
tfrs += tfr
elif output in ["complex", "phase"] and method == "multitaper":
tfrs[taper_idx, epoch_idx] += tfr
elif output in ["complex", "phase"] and weights is not None:
tfrs[epoch_idx, taper_idx] += tfr
else:
tfrs[epoch_idx] += tfr

Expand All @@ -749,9 +758,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
if ("avg_" in output) or ("itc" in output):
tfrs /= n_epochs

# Normalization by number of taper
if n_tapers > 1 and output not in ["complex", "phase"]:
tfrs /= n_tapers
# Normalization by taper weights
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
if "avg_" not in output: # add singleton epochs dimension to weights
weights = np.expand_dims(weights, axis=0)
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
if output == "avg_power_itc": # weight itc by the number of tapers
tfrs.imag = tfrs.imag / n_tapers

return tfrs


Expand Down
13 changes: 8 additions & 5 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):

docdict["export_fmt_support_epochs"] = """\
Supported formats:
- EEGLAB (``.set``, uses :mod:`eeglabio`)

- EEGLAB (``.set``, uses :mod:`eeglabio`)
"""

docdict["export_fmt_support_evoked"] = """\
Supported formats:
- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`)

- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`)
"""

docdict["export_fmt_support_raw"] = """\
Supported formats:
- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv <https://github.com/bids-standard/pybv>`_)
- EEGLAB (``.set``, uses :mod:`eeglabio`)
- EDF (``.edf``, uses `edfio <https://github.com/the-siesta-group/edfio>`_)

- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv <https://github.com/bids-standard/pybv>`_)
- EEGLAB (``.set``, uses :mod:`eeglabio`)
- EDF (``.edf``, uses `edfio <https://github.com/the-siesta-group/edfio>`_)
""" # noqa: E501

docdict["export_warning"] = """\
Expand Down
Loading