Skip to content

Commit

Permalink
[BUG] Fix taper weighting in computation of TFR multitaper power (mne…
Browse files Browse the repository at this point in the history
…-tools#13067)

Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
2 people authored and qian-chu committed Jan 20, 2025
1 parent 24e43c4 commit 388a426
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/13067.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_.
19 changes: 12 additions & 7 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,25 @@ def test_tfr_morlet():
# computed within the method.
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)

# test that averaging power across tapers when multitaper with
# test that aggregating power across tapers when multitaper with
# output='complex' gives the same as output='power'
epoch_data = epochs.get_data()
multitaper_power = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
)
multitaper_complex = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
multitaper_complex, weights = tfr_array_multitaper(
epoch_data,
epochs.info["sfreq"],
freqs,
n_cycles,
output="complex",
return_weights=True,
)

taper_dim = 2
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
axis=taper_dim
)
weights = np.expand_dims(weights, axis=(0, 1, -1)) # match shape of complex data
tfr = weights * multitaper_complex
tfr = (tfr * tfr.conj()).real.sum(axis=2)
power_from_complex = tfr * (2 / (weights * weights.conj()).real.sum(axis=2))
assert_allclose(power_from_complex, multitaper_power)

print(itc) # test repr
Expand Down
51 changes: 26 additions & 25 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,20 +545,18 @@ def _compute_tfr(
if method == "morlet":
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
Ws = [W] # to have same dimensionality as the 'multitaper' case
weights = None # no tapers for Morlet estimates

elif method == "multitaper":
out = _make_dpss(
Ws, weights = _make_dpss(
sfreq,
freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
zero_mean=zero_mean,
return_weights=return_weights,
return_weights=True, # required for converting complex → power
)
if return_weights:
Ws, weights = out
else:
Ws = out
weights = np.asarray(weights)

# Check wavelets
if len(Ws[0][0]) > epoch_data.shape[2]:
Expand All @@ -581,9 +579,7 @@ def _compute_tfr(
if ("avg_" in output) or ("itc" in output):
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ["complex", "phase"] and method == "multitaper":
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
if return_weights:
weights = np.array(weights)
out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -594,7 +590,7 @@ def _compute_tfr(

# Parallelization is applied across channels.
tfrs = parallel(
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
for channel in epoch_data.transpose(1, 0, 2)
)

Expand All @@ -604,10 +600,7 @@ def _compute_tfr(

if ("avg_" not in output) and ("itc" not in output):
# This is to enforce that the first dimension is for epochs
if output in ["complex", "phase"] and method == "multitaper":
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)
out = np.moveaxis(out, 1, 0)

if return_weights:
return out, weights
Expand Down Expand Up @@ -683,7 +676,7 @@ def _check_tfr_param(
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim


def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
"""Aux. function to _compute_tfr.
Loops time-frequency transform across wavelets and epochs.
Expand All @@ -710,9 +703,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
See numpy.convolve.
decim : slice
The decimation slice: e.g. power[:, decim]
method : str | None
Used only for multitapering to create tapers dimension in the output
if ``output in ['complex', 'phase']``.
weights : array, shape (n_tapers, n_wavelets) | None
Concentration weights for each taper in the wavelets, if present.
"""
# Set output type
dtype = np.float64
Expand All @@ -726,10 +718,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
n_freqs = len(Ws[0])
if ("avg_" in output) or ("itc" in output):
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and method == "multitaper":
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and weights is not None:
tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype)
else:
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
if weights is not None:
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension

# Loops across tapers.
for taper_idx, W in enumerate(Ws):
Expand All @@ -744,6 +738,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Loop across epochs
for epoch_idx, tfr in enumerate(coefs):
# Transform complex values
if output not in ["complex", "phase"] and weights is not None:
tfr = weights[taper_idx] * tfr # weight each taper estimate
if output in ["power", "avg_power"]:
tfr = (tfr * tfr.conj()).real # power
elif output == "phase":
Expand All @@ -759,8 +755,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Stack or add
if ("avg_" in output) or ("itc" in output):
tfrs += tfr
elif output in ["complex", "phase"] and method == "multitaper":
tfrs[taper_idx, epoch_idx] += tfr
elif output in ["complex", "phase"] and weights is not None:
tfrs[epoch_idx, taper_idx] += tfr
else:
tfrs[epoch_idx] += tfr

Expand All @@ -774,9 +770,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
if ("avg_" in output) or ("itc" in output):
tfrs /= n_epochs

# Normalization by number of taper
if n_tapers > 1 and output not in ["complex", "phase"]:
tfrs /= n_tapers
# Normalization by taper weights
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
if "avg_" not in output: # add singleton epochs dimension to weights
weights = np.expand_dims(weights, axis=0)
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
if output == "avg_power_itc": # weight itc by the number of tapers
tfrs.imag = tfrs.imag / n_tapers

return tfrs


Expand Down

0 comments on commit 388a426

Please sign in to comment.