Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix taper weighting in computation of TFR multitaper power #13067

Merged
merged 4 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13067.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_.
19 changes: 12 additions & 7 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,25 @@ def test_tfr_morlet():
# computed within the method.
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)

# test that averaging power across tapers when multitaper with
# test that aggregating power across tapers when multitaper with
# output='complex' gives the same as output='power'
epoch_data = epochs.get_data()
multitaper_power = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
)
multitaper_complex = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
multitaper_complex, weights = tfr_array_multitaper(
epoch_data,
epochs.info["sfreq"],
freqs,
n_cycles,
output="complex",
return_weights=True,
)

taper_dim = 2
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
axis=taper_dim
)
weights = np.expand_dims(weights, axis=(0, 1, -1)) # match shape of complex data
tfr = weights * multitaper_complex
tfr = (tfr * tfr.conj()).real.sum(axis=2)
power_from_complex = tfr * (2 / (weights * weights.conj()).real.sum(axis=2))
assert_allclose(power_from_complex, multitaper_power)

print(itc) # test repr
Expand Down
40 changes: 22 additions & 18 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,20 +545,18 @@ def _compute_tfr(
if method == "morlet":
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
Ws = [W] # to have same dimensionality as the 'multitaper' case
weights = None # no tapers for Morlet estimates

elif method == "multitaper":
out = _make_dpss(
Ws, weights = _make_dpss(
sfreq,
freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
zero_mean=zero_mean,
return_weights=return_weights,
return_weights=True, # required for converting complex → power
)
if return_weights:
Ws, weights = out
else:
Ws = out
weights = np.asarray(weights)

# Check wavelets
if len(Ws[0][0]) > epoch_data.shape[2]:
Expand All @@ -582,8 +580,6 @@ def _compute_tfr(
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ["complex", "phase"] and method == "multitaper":
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
if return_weights:
weights = np.array(weights)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -594,7 +590,7 @@ def _compute_tfr(

# Parallelization is applied across channels.
tfrs = parallel(
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
for channel in epoch_data.transpose(1, 0, 2)
)

Expand Down Expand Up @@ -683,7 +679,7 @@ def _check_tfr_param(
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim


def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
"""Aux. function to _compute_tfr.

Loops time-frequency transform across wavelets and epochs.
Expand All @@ -710,9 +706,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
See numpy.convolve.
decim : slice
The decimation slice: e.g. power[:, decim]
method : str | None
Used only for multitapering to create tapers dimension in the output
if ``output in ['complex', 'phase']``.
weights : array, shape (n_tapers, n_wavelets) | None
Concentration weights for each taper in the wavelets, if present.
"""
# Set output type
dtype = np.float64
Expand All @@ -726,10 +721,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
n_freqs = len(Ws[0])
if ("avg_" in output) or ("itc" in output):
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and method == "multitaper":
elif output in ["complex", "phase"] and weights is not None:
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
else:
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
if weights is not None:
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension

# Loops across tapers.
for taper_idx, W in enumerate(Ws):
Expand All @@ -744,6 +741,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Loop across epochs
for epoch_idx, tfr in enumerate(coefs):
# Transform complex values
if output not in ["complex", "phase"] and weights is not None:
tfr = weights[taper_idx] * tfr # weight each taper estimate
if output in ["power", "avg_power"]:
tfr = (tfr * tfr.conj()).real # power
elif output == "phase":
Expand All @@ -759,7 +758,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Stack or add
if ("avg_" in output) or ("itc" in output):
tfrs += tfr
elif output in ["complex", "phase"] and method == "multitaper":
elif output in ["complex", "phase"] and weights is not None:
tfrs[taper_idx, epoch_idx] += tfr
else:
tfrs[epoch_idx] += tfr
Expand All @@ -774,9 +773,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
if ("avg_" in output) or ("itc" in output):
tfrs /= n_epochs

# Normalization by number of taper
if n_tapers > 1 and output not in ["complex", "phase"]:
tfrs /= n_tapers
# Normalization by taper weights
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
if "avg_" not in output: # add singleton epochs dimension to weights
weights = np.expand_dims(weights, axis=0)
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
if output == "avg_power_itc": # weight itc by the number of tapers
tfrs.imag = tfrs.imag / n_tapers

return tfrs


Expand Down
Loading