Skip to content

Commit

Permalink
MRG, BUG: Fix combine evokeds (#7869)
Browse files Browse the repository at this point in the history
* simplify/clarify tests

* fix weights computation and docstring

* simplify grandaverage

* better warning

* remove redundant test; clarify existing test

* clarify tutorial

* fix subtractions

* clarify tutorial

* update what's new

* revert squeeze

* some tutorial fixes
  • Loading branch information
drammock authored Jun 7, 2020
1 parent 27341bb commit 4bfeb85
Show file tree
Hide file tree
Showing 14 changed files with 97 additions and 113 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Bug
- Fix bug with :func:`mne.preprocessing.compute_current_source_density` where values were not properly computed; maps should now be more focal, by `Alex Rockhill`_ and `Eric Larson`_
- Fix bug with :func:`mne.combine_evoked` where equal-weighted averages were wrongly computed as equal-weighted sums, by `Daniel McCloy`_
- Fix to enable interactive plotting with no colorbar with :func:`mne.viz.plot_evoked_topomap` by `Daniel McCloy`_
- Fix plotting with :func:`mne.viz.plot_evoked_topomap` to pre-existing axes by `Daniel McCloy`_
Expand Down
4 changes: 2 additions & 2 deletions examples/datasets/plot_limo_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@

# Face A minus Face B
difference_wave = combine_evoked([limo_epochs['Face/A'].average(),
-limo_epochs['Face/B'].average()],
weights='equal')
limo_epochs['Face/B'].average()],
weights=[1, -1])

# plot difference wave
difference_wave.plot_joint(times=[0.15], title='Difference Face A - Face B')
Expand Down
4 changes: 2 additions & 2 deletions examples/stats/plot_linear_regression_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
time_unit='s')
epochs[cond].average().plot(axes=ax1, **params)
evokeds[cond].plot(axes=ax2, **params)
contrast = mne.combine_evoked([evokeds[cond], -epochs[cond].average()],
weights='equal')
contrast = mne.combine_evoked([evokeds[cond], epochs[cond].average()],
weights=[1, -1])
contrast.plot(axes=ax3, **params)
ax1.set_title("Traditional averaging")
ax2.set_title("rERF")
Expand Down
35 changes: 18 additions & 17 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,23 +837,26 @@ def _check_evokeds_ch_names_times(all_evoked):
def combine_evoked(all_evoked, weights):
"""Merge evoked data by weighted addition or subtraction.
Data should have the same channels and the same time instants.
Subtraction can be performed by calling
``combine_evoked([evoked1, -evoked2], 'equal')``
Each `~mne.Evoked` in ``all_evoked`` should have the same channels and the
same time instants. Subtraction can be performed by passing
``weights=[1, -1]``.
.. Warning::
If you provide an array of weights instead of using ``'equal'`` or
``'nave'``, strange things may happen with your resulting signal
amplitude and/or ``.nave`` attribute.
Other than cases like simple subtraction mentioned above (where all
weights are -1 or 1), if you provide numeric weights instead of using
``'equal'`` or ``'nave'``, the resulting `~mne.Evoked` object's
``.nave`` attribute (which is used to scale noise covariance when
applying the inverse operator) may not be suitable for inverse imaging.
Parameters
----------
all_evoked : list of Evoked
The evoked datasets.
weights : list of float | str
The weights to apply to the data of each evoked instance.
Can also be ``'nave'`` to weight according to evoked.nave,
or ``"equal"`` to use equal weighting (each weighted as ``1/N``).
weights : list of float | 'equal' | 'nave'
The weights to apply to the data of each evoked instance, or a string
describing the weighting strategy to apply: ``'nave'`` computes
sum-to-one weights proportional to each object's ``nave`` attribute;
``'equal'`` weights each `~mne.Evoked` by ``1 / len(all_evoked)``.
Returns
-------
Expand All @@ -870,16 +873,15 @@ def combine_evoked(all_evoked, weights):
if weights == 'nave':
weights = naves / naves.sum()
else:
weights = np.ones_like(naves)
weights = np.ones_like(naves) / len(naves)
else:
weights = np.array(weights, float)

if weights.ndim != 1 or weights.size != len(all_evoked):
raise ValueError('weights must be the same size as all_evoked')

# cf. https://en.wikipedia.org/wiki/Weighted_arithmetic_mean, section on
# how variances change when summing Gaussian random variables. The variance
# of a weighted sample mean is:
# "weighted sample variance". The variance of a weighted sample mean is:
#
# σ² = w₁² σ₁² + w₂² σ₂² + ... + wₙ² σₙ²
#
Expand All @@ -892,18 +894,17 @@ def combine_evoked(all_evoked, weights):
# This general formula is equivalent to formulae in Matti's manual
# (pp 128-129), where:
# new_nave = sum(naves) when weights='nave' and
# new_nave = 1. / sum(1. / naves) when weights='equal'
# new_nave = 1. / sum(1. / naves) when weights are all 1.

all_evoked = _check_evokeds_ch_names_times(all_evoked)
evoked = all_evoked[0].copy()

# use union of bad channels
bads = list(set(evoked.info['bads']).union(*(ev.info['bads']
for ev in all_evoked[1:])))
bads = list(set(b for e in all_evoked for b in e.info['bads']))
evoked.info['bads'] = bads
evoked.data = sum(w * e.data for w, e in zip(weights, all_evoked))
evoked.nave = new_nave
evoked.comment = ' + '.join('%0.3f * %s' % (w, e.comment or 'unknown')
evoked.comment = ' + '.join(f'{w:0.3f} × {e.comment or "unknown"}'
for w, e in zip(weights, all_evoked))
return evoked

Expand Down
10 changes: 4 additions & 6 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,12 +1158,10 @@ def test_evoked_arithmetic():
evoked2 = epochs2.average()
epochs = Epochs(raw, events[:8], event_id, tmin, tmax, picks=picks)
evoked = epochs.average()
evoked_sum = combine_evoked([evoked1, evoked2], weights='nave')
assert_array_equal(evoked.data, evoked_sum.data)
assert_array_equal(evoked.times, evoked_sum.times)
assert_equal(evoked_sum.nave, evoked1.nave + evoked2.nave)
evoked_diff = combine_evoked([evoked1, evoked1], weights=[1, -1])
assert_array_equal(np.zeros_like(evoked.data), evoked_diff.data)
evoked_avg = combine_evoked([evoked1, evoked2], weights='nave')
assert_array_equal(evoked.data, evoked_avg.data)
assert_array_equal(evoked.times, evoked_avg.times)
assert_equal(evoked_avg.nave, evoked1.nave + evoked2.nave)


def test_evoked_io_from_epochs(tmpdir):
Expand Down
99 changes: 42 additions & 57 deletions mne/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,64 +528,48 @@ def test_equalize_channels():
def test_arithmetic():
"""Test evoked arithmetic."""
ev = read_evokeds(fname, condition=0)
ev1 = EvokedArray(np.ones_like(ev.data), ev.info, ev.times[0], nave=20)
ev2 = EvokedArray(-np.ones_like(ev.data), ev.info, ev.times[0], nave=10)

# combine_evoked([ev1, ev2]) should be the same as ev1 + ev2:
# data should be added according to their `nave` weights
# nave = ev1.nave + ev2.nave
ev = combine_evoked([ev1, ev2], weights='nave')
assert_allclose(ev.nave, ev1.nave + ev2.nave)
assert_allclose(ev.data, 1. / 3. * np.ones_like(ev.data))

# with same trial counts, a bunch of things should be equivalent
for weights in ('nave', [0.5, 0.5]):
ev = combine_evoked([ev1, ev1], weights=weights)
assert_allclose(ev.data, ev1.data)
assert_allclose(ev.nave, 2 * ev1.nave)
ev = combine_evoked([ev1, -ev1], weights=weights)
assert_allclose(ev.data, 0., atol=1e-20)
assert_allclose(ev.nave, 2 * ev1.nave)
# adding evoked to itself
ev = combine_evoked([ev1, ev1], weights='equal')
assert_allclose(ev.data, 2 * ev1.data)
assert_allclose(ev.nave, ev1.nave / 2)
# subtracting evoked from itself
ev = combine_evoked([ev1, -ev1], weights='equal')
assert_allclose(ev.data, 0., atol=1e-20)
assert_allclose(ev.nave, ev1.nave / 2)
# subtracting different evokeds
ev = combine_evoked([ev1, -ev2], weights='equal')
assert_allclose(ev.data, 2., atol=1e-20)
expected_nave = 1. / (1. / ev1.nave + 1. / ev2.nave)
assert_allclose(ev.nave, expected_nave)
ev20 = EvokedArray(np.ones_like(ev.data), ev.info, ev.times[0], nave=20)
ev30 = EvokedArray(np.ones_like(ev.data), ev.info, ev.times[0], nave=30)

tol = dict(rtol=1e-9, atol=0)
# test subtraction
sub1 = combine_evoked([ev, ev], weights=[1, -1])
sub2 = combine_evoked([ev, -ev], weights=[1, 1])
assert np.allclose(sub1.data, np.zeros_like(sub1.data), atol=1e-20)
assert np.allclose(sub2.data, np.zeros_like(sub2.data), atol=1e-20)
# test nave weighting. Expect signal ampl.: 1*(20/50) + 1*(30/50) == 1
# and expect nave == ev1.nave + ev2.nave
ev = combine_evoked([ev20, ev30], weights='nave')
assert np.allclose(ev.nave, ev20.nave + ev30.nave)
assert np.allclose(ev.data, np.ones_like(ev.data), **tol)
# test equal-weighted sum. Expect signal ampl. == 2
# and expect nave == 1/sum(1/naves) == 1/(1/20 + 1/30) == 12
ev = combine_evoked([ev20, ev30], weights=[1, 1])
assert np.allclose(ev.nave, 12.)
assert np.allclose(ev.data, ev20.data + ev30.data, **tol)
# test equal-weighted average. Expect signal ampl. == 1
# and expect nave == 1/sum(weights²/naves) == 1/(0.5²/20 + 0.5²/30) == 48
ev = combine_evoked([ev20, ev30], weights='equal')
assert np.allclose(ev.nave, 48.)
assert np.allclose(ev.data, np.mean([ev20.data, ev30.data], axis=0), **tol)
# test zero weights
ev = combine_evoked([ev20, ev30], weights=[1, 0])
assert ev.nave == ev20.nave
assert np.allclose(ev.data, ev20.data, **tol)

# default comment behavior if evoked.comment is None
old_comment1 = ev1.comment
old_comment2 = ev2.comment
ev1.comment = None
ev = combine_evoked([ev1, -ev2], weights=[1, -1])
old_comment1 = ev20.comment
ev20.comment = None
ev = combine_evoked([ev20, -ev30], weights=[1, -1])
assert_equal(ev.comment.count('unknown'), 2)
assert ('-unknown' in ev.comment)
assert (' + ' in ev.comment)
ev1.comment = old_comment1
ev2.comment = old_comment2
ev20.comment = old_comment1

# equal weighting
ev = combine_evoked([ev1, ev2], weights='equal')
assert_allclose(ev.data, np.zeros_like(ev1.data))

# combine_evoked([ev1, ev2], weights=[1, 0]) should yield the same as ev1
ev = combine_evoked([ev1, ev2], weights=[1, 0])
assert_allclose(ev.nave, ev1.nave)
assert_allclose(ev.data, ev1.data)

# simple subtraction (like in oddball)
ev = combine_evoked([ev1, ev2], weights=[1, -1])
assert_allclose(ev.data, 2 * np.ones_like(ev1.data))

pytest.raises(ValueError, combine_evoked, [ev1, ev2], weights='foo')
pytest.raises(ValueError, combine_evoked, [ev1, ev2], weights=[1])
with pytest.raises(ValueError, match="Invalid value for the 'weights'"):
combine_evoked([ev20, ev30], weights='foo')
with pytest.raises(ValueError, match='weights must be the same size as'):
combine_evoked([ev20, ev30], weights=[1])

# grand average
evoked1, evoked2 = read_evokeds(fname, condition=[0, 1], proj=True)
Expand All @@ -597,8 +581,9 @@ def test_arithmetic():
assert_equal(gave.data.shape, [len(ch_names), evoked1.data.shape[1]])
assert_equal(ch_names, gave.ch_names)
assert_equal(gave.nave, 2)
pytest.raises(TypeError, grand_average, [1, evoked1])
gave = grand_average([ev1, ev1, ev2]) # (1 + 1 + -1) / 3 = 1/3
with pytest.raises(TypeError, match='All elements must be an instance of'):
grand_average([1, evoked1])
gave = grand_average([ev20, ev20, -ev30]) # (1 + 1 + -1) / 3 = 1/3
assert_allclose(gave.data, np.full_like(gave.data, 1. / 3.))

# test channel (re)ordering
Expand All @@ -608,10 +593,10 @@ def test_arithmetic():
evoked2.reorder_channels(evoked2.ch_names[::-1])
assert not np.allclose(data2, evoked2.data)
with pytest.warns(RuntimeWarning, match='reordering'):
ev3 = combine_evoked([evoked1, evoked2], weights=[0.5, 0.5])
assert np.allclose(ev3.data, data)
evoked3 = combine_evoked([evoked1, evoked2], weights=[0.5, 0.5])
assert np.allclose(evoked3.data, data)
assert evoked1.ch_names != evoked2.ch_names
assert evoked1.ch_names == ev3.ch_names
assert evoked1.ch_names == evoked3.ch_names


def test_array_epochs():
Expand Down
4 changes: 1 addition & 3 deletions mne/utils/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,10 +590,8 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
all_inst = [inst.interpolate_bads() if len(inst.info['bads']) > 0
else inst for inst in all_inst]
from ..evoked import combine_evoked as combine
weights = [1. / len(all_inst)] * len(all_inst)
else: # isinstance(all_inst[0], AverageTFR):
from ..time_frequency.tfr import combine_tfr as combine
weights = 'equal'

if drop_bads:
bads = list({b for inst in all_inst for b in inst.info['bads']})
Expand All @@ -603,7 +601,7 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):

equalize_channels(all_inst, copy=False)
# make grand_average object using combine_[evoked/tfr]
grand_average = combine(all_inst, weights=weights)
grand_average = combine(all_inst, weights='equal')
# change the grand_average.nave to the number of Evokeds
grand_average.nave = len(all_inst)
# change comment field
Expand Down
9 changes: 4 additions & 5 deletions tutorials/evoked/plot_10_evoked_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,10 @@
# This can be accomplished by the function :func:`mne.combine_evoked`, which
# computes a weighted sum of the :class:`~mne.Evoked` objects given to it. The
# weights can be manually specified as a list or array of float values, or can
# be specified using the keyword ``'equal'`` (weight each :class:`~mne.Evoked`
# object by :math:`\frac{1}{N}`, where :math:`N` is the number of
# :class:`~mne.Evoked` objects given) or the keyword ``'nave'`` (weight each
# :class:`~mne.Evoked` object by the number of epochs that were averaged
# together to create it):
# be specified using the keyword ``'equal'`` (weight each `~mne.Evoked` object
# by :math:`\frac{1}{N}`, where :math:`N` is the number of `~mne.Evoked`
# objects given) or the keyword ``'nave'`` (weight each `~mne.Evoked` object
# proportional to the number of epochs averaged together to create it):

left_right_aud = mne.combine_evoked([left_aud, right_aud], weights='nave')
assert left_right_aud.nave == left_aud.nave + right_aud.nave
Expand Down
24 changes: 13 additions & 11 deletions tutorials/evoked/plot_eeg_erp.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,29 @@

###############################################################################
# Next, we create averages of stimulation-left vs stimulation-right trials.
# We can use basic arithmetic to, for example, construct and plot
# difference ERPs.
# We can use negative weights in `mne.combine_evoked` to construct difference
# ERPs.

left, right = epochs["left"].average(), epochs["right"].average()

# create and plot difference ERP
joint_kwargs = dict(ts_args=dict(time_unit='s'),
topomap_args=dict(time_unit='s'))
mne.combine_evoked([left, -right], weights='equal').plot_joint(**joint_kwargs)
mne.combine_evoked([left, right], weights=[1, -1]).plot_joint(**joint_kwargs)

###############################################################################
# This is an equal-weighting difference. If you have imbalanced trial numbers,
# you could also consider either equalizing the number of events per
# condition (using
# :meth:`epochs.equalize_event_counts <mne.Epochs.equalize_event_counts>`).
# `epochs.equalize_event_counts <mne.Epochs.equalize_event_counts>`) or
# use weights proportional to the number of trials averaged together to create
# each `~mne.Evoked` (by passing ``weights='nave'`` to `~mne.combine_evoked`).
# As an example, first, we create individual ERPs for each condition.

aud_l = epochs["auditory", "left"].average()
aud_r = epochs["auditory", "right"].average()
vis_l = epochs["visual", "left"].average()
vis_r = epochs["visual", "right"].average()
aud_l = epochs["auditory/left"].average()
aud_r = epochs["auditory/right"].average()
vis_l = epochs["visual/left"].average()
vis_r = epochs["visual/right"].average()

all_evokeds = [aud_l, aud_r, vis_l, vis_r]
print(all_evokeds)
Expand All @@ -155,10 +157,10 @@
all_evokeds = [epochs[cond].average() for cond in sorted(event_id.keys())]
print(all_evokeds)

# Then, we construct and plot an unweighted average of left vs. right trials
# this way, too:
# Then, we can construct and plot an unweighted average of left vs. right
# trials this way, too:
mne.combine_evoked(
[aud_l, -aud_r, vis_l, -vis_r], weights='equal').plot_joint(**joint_kwargs)
all_evokeds, weights=[0.5, 0.5, -0.5, -0.5]).plot_joint(**joint_kwargs)

###############################################################################
# Often, it makes sense to store Evoked objects in a dictionary or a list -
Expand Down
9 changes: 4 additions & 5 deletions tutorials/intro/plot_10_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,11 @@

##############################################################################
# Evoked objects can also be combined to show contrasts between conditions,
# using the :func:`mne.combine_evoked` function. A simple difference can be
# generated by negating one of the :class:`~mne.Evoked` objects passed into the
# function. We'll then plot the difference wave at each sensor using
# :meth:`~mne.Evoked.plot_topo`:
# using the `mne.combine_evoked` function. A simple difference can be
# generated by passing ``weights=[1, -1]``. We'll then plot the difference wave
# at each sensor using `~mne.Evoked.plot_topo`:

evoked_diff = mne.combine_evoked([aud_evoked, -vis_evoked], weights='equal')
evoked_diff = mne.combine_evoked([aud_evoked, vis_evoked], weights=[1, -1])
evoked_diff.pick_types(meg='mag').plot_topo(color='r', legend=False)

##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion tutorials/preprocessing/plot_70_fnirs_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
vmin=vmin, vmax=vmax, colorbar=False,
**topomap_args)

evoked_diff = mne.combine_evoked([evoked_left, -evoked_right], weights='equal')
evoked_diff = mne.combine_evoked([evoked_left, evoked_right], weights=[1, -1])

evoked_diff.plot_topomap(ch_type='hbo', times=ts, axes=axes[0, 2:],
vmin=vmin, vmax=vmax, colorbar=True,
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sample-datasets/plot_brainstorm_auditory.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
# We can see the MMN effect more clearly by looking at the difference between
# the two conditions. P50 and N100 are no longer visible, but MMN/P200 and
# P300 are emphasised.
evoked_difference = combine_evoked([evoked_dev, -evoked_std], weights='equal')
evoked_difference = combine_evoked([evoked_dev, evoked_std], weights=[1, -1])
evoked_difference.plot(window_title='Difference', gfp=True, time_unit='s')

###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion tutorials/source-modeling/plot_dipole_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
**plot_params)

# Subtract predicted from measured data (apply equal weights)
diff = combine_evoked([evoked, -pred_evoked], weights='equal')
diff = combine_evoked([evoked, pred_evoked], weights=[1, -1])
plot_params['colorbar'] = True
diff.plot_topomap(time_format='Difference', axes=axes[2:], **plot_params)
fig.suptitle('Comparison of measured and predicted fields '
Expand Down
4 changes: 2 additions & 2 deletions tutorials/stats-sensor-space/plot_stats_cluster_erp.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@
# effects on the head.

# We need an evoked object to plot the image to be masked
evoked = mne.combine_evoked([long_words.average(), -short_words.average()],
weights='equal') # calculate difference wave
evoked = mne.combine_evoked([long_words.average(), short_words.average()],
weights=[1, -1]) # calculate difference wave
time_unit = dict(time_unit="s")
evoked.plot_joint(title="Long vs. short words", ts_args=time_unit,
topomap_args=time_unit) # show difference wave
Expand Down

0 comments on commit 4bfeb85

Please sign in to comment.